{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cfb832c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/wenet\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'/workspace/wenet'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%cd /workspace/wenet/\n",
    "%pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "62277538",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import argparse\n",
    "import copy\n",
    "import logging\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "import torch.optim as optim\n",
    "import yaml\n",
    "from tensorboardX import SummaryWriter\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from wenet.dataset.dataset import AudioDataset, CollateFunc\n",
    "from wenet.transformer.asr_model import init_asr_model\n",
    "from wenet.utils.checkpoint import load_checkpoint, save_checkpoint\n",
    "from wenet.utils.executor import Executor\n",
    "from wenet.utils.scheduler import WarmupLR\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2f6ea33a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'config': 'examples/aishell/s0/conf/train_conformer.yaml', 'train_data': 'examples/aishell/s0/raw_wav/train/format.data', 'cv_data': 'examples/aishell/s0/raw_wav/dev/format.data', 'gpu': -1, 'model_dir': None, 'checkpoint': None, 'tensorboard_dir': 'tensorboard', 'rank': 0, 'world_size': -1, 'dist_backend': 'nccl', 'init_method': None, 'num_workers': 0, 'pin_memory': False, 'cmvn': 'examples/aishell/s0/raw_wav/train/global_cmvn'}\n"
     ]
    }
   ],
   "source": [
    "# Build the argument parser. No option is required, and the path defaults\n",
    "# point at the AISHELL s0 example files.\n",
    "parser = argparse.ArgumentParser(description='training your network')\n",
    "\n",
    "# Config file and data lists.\n",
    "parser.add_argument(\n",
    "    '--config',\n",
    "    default='examples/aishell/s0/conf/train_conformer.yaml',\n",
    "    help='config file')\n",
    "parser.add_argument(\n",
    "    '--train_data',\n",
    "    default='examples/aishell/s0/raw_wav/train/format.data',\n",
    "    help='train data file')\n",
    "parser.add_argument(\n",
    "    '--cv_data',\n",
    "    default='examples/aishell/s0/raw_wav/dev/format.data',\n",
    "    help='cv data file')\n",
    "\n",
    "# Device, checkpoint and logging locations.\n",
    "parser.add_argument(\n",
    "    '--gpu', type=int, default=-1,\n",
    "    help='gpu id for this local rank, -1 for cpu')\n",
    "parser.add_argument('--model_dir', help='save model dir')\n",
    "parser.add_argument('--checkpoint', help='checkpoint model')\n",
    "parser.add_argument(\n",
    "    '--tensorboard_dir', default='tensorboard', help='tensorboard log dir')\n",
    "\n",
    "# Distributed options: `dest` drops the 'ddp.' prefix from the attribute\n",
    "# names on the parsed namespace.\n",
    "parser.add_argument(\n",
    "    '--ddp.rank', dest='rank', type=int, default=0,\n",
    "    help='global rank for distributed training')\n",
    "parser.add_argument(\n",
    "    '--ddp.world_size', dest='world_size', type=int, default=-1,\n",
    "    help='''number of total processes/gpus for\n",
    "                    distributed training''')\n",
    "parser.add_argument(\n",
    "    '--ddp.dist_backend', dest='dist_backend', default='nccl',\n",
    "    choices=['nccl', 'gloo'], help='distributed backend')\n",
    "parser.add_argument(\n",
    "    '--ddp.init_method', dest='init_method', default=None,\n",
    "    help='ddp init method')\n",
    "\n",
    "# Data loading and feature normalisation.\n",
    "parser.add_argument(\n",
    "    '--num_workers', type=int, default=0,\n",
    "    help='num of subprocess workers for reading')\n",
    "parser.add_argument(\n",
    "    '--pin_memory', action='store_true', default=False,\n",
    "    help='Use pinned memory buffers used for reading')\n",
    "parser.add_argument(\n",
    "    '--cmvn',\n",
    "    default='examples/aishell/s0/raw_wav/train/global_cmvn',\n",
    "    help='global cmvn file')\n",
    "\n",
    "# Parse an empty argument list so that only the defaults above take effect.\n",
    "args = parser.parse_args([])\n",
    "print(vars(args))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f5d6af9b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Namespace(checkpoint=None, cmvn='examples/aishell/s0/raw_wav/train/global_cmvn', config='examples/aishell/s0/conf/train_conformer.yaml', cv_data='examples/aishell/s0/raw_wav/dev/format.data', dist_backend='nccl', gpu=-1, init_method=None, model_dir=None, num_workers=0, pin_memory=False, rank=0, tensorboard_dir='tensorboard', train_data='examples/aishell/s0/raw_wav/train/format.data', world_size=-1)\n"
     ]
    }
   ],
   "source": [
    "# Seed torch's random number generator with a fixed value (777)\n",
    "torch.manual_seed(777)\n",
    "print(args)\n",
    "# Read the YAML file named by args.config into `configs`.\n",
    "# NOTE(review): yaml.FullLoader is used here; yaml.safe_load may be enough if\n",
    "# the config only holds plain mappings/lists/scalars -- confirm before switching.\n",
    "with open(args.config, 'r') as fin:\n",
    "    configs = yaml.load(fin, Loader=yaml.FullLoader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "264bd353",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7507 batches\n",
      "896\n"
     ]
    }
   ],
   "source": [
    "raw_wav = configs['raw_wav']\n",
    "\n",
    "train_collate_func = CollateFunc(**configs['collate_conf'],\n",
    "                                 raw_wav=raw_wav)\n",
    "\n",
    "cv_collate_conf = copy.deepcopy(configs['collate_conf'])\n",
    "# no augmentation on cv set\n",
    "cv_collate_conf['spec_aug'] = False\n",
    "cv_collate_conf['spec_sub'] = False\n",
    "if raw_wav:\n",
    "    cv_collate_conf['feature_dither'] = 0.0\n",
    "    cv_collate_conf['speed_perturb'] = False\n",
    "    cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0\n",
    "cv_collate_func = CollateFunc(**cv_collate_conf, raw_wav=raw_wav)\n",
    "\n",
    "dataset_conf = configs.get('dataset_conf', {})\n",
    "train_dataset = AudioDataset(args.train_data,\n",
    "                             **dataset_conf,\n",
    "                             raw_wav=raw_wav)\n",
    "cv_dataset = AudioDataset(args.cv_data, **dataset_conf, raw_wav=raw_wav)\n",
    "# data/train/wav.scp lists 120098 utterances; len(train_dataset) is in batches\n",
    "print(len(train_dataset), 'batches')\n",
    "# data/dev/wav.scp lists 14326 utterances\n",
    "print(len(cv_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "88863d3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "896\n"
     ]
    }
   ],
   "source": [
    "# No explicit samplers; shuffling is switched off for both loaders as well.\n",
    "train_sampler = cv_sampler = None\n",
    "\n",
    "# batch_size=1 -- presumably every dataset item is already a full minibatch\n",
    "# (the training set length is printed in 'batches'); TODO confirm.\n",
    "train_data_loader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=1,\n",
    "    # shuffle=(train_sampler is None) would turn shuffling back on\n",
    "    shuffle=False,\n",
    "    sampler=train_sampler,\n",
    "    collate_fn=train_collate_func,\n",
    "    num_workers=args.num_workers,\n",
    "    pin_memory=args.pin_memory,\n",
    ")\n",
    "cv_data_loader = DataLoader(\n",
    "    cv_dataset,\n",
    "    batch_size=1,\n",
    "    shuffle=False,\n",
    "    sampler=cv_sampler,\n",
    "    collate_fn=cv_collate_func,\n",
    "    num_workers=args.num_workers,\n",
    "    pin_memory=args.pin_memory,\n",
    ")\n",
    "print(len(cv_data_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "10d5acd4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4233 vocab\n",
      "80 feat dim\n"
     ]
    }
   ],
   "source": [
    "# Feature dimension: the 'mel_bins' entry of the feature extraction config\n",
    "# when working from raw wav, otherwise whatever the dataset reports.\n",
    "input_dim = (\n",
    "    configs['collate_conf']['feature_extraction_conf']['mel_bins']\n",
    "    if raw_wav else train_dataset.input_dim)\n",
    "# Output dimension as reported by the training dataset.\n",
    "vocab_size = train_dataset.output_dim\n",
    "print(vocab_size, 'vocab')\n",
    "print(input_dim, 'feat dim')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0380ef5a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "examples/aishell/s0/raw_wav/train/global_cmvn\n"
     ]
    }
   ],
   "source": [
    "# Add the derived settings to configs (the values a saved train.yaml would\n",
    "# need for inference and export). Nothing is written to disk in this cell.\n",
    "configs.update(\n",
    "    input_dim=input_dim,\n",
    "    output_dim=vocab_size,\n",
    "    cmvn_file=args.cmvn,\n",
    "    is_json_cmvn=raw_wav,\n",
    ")\n",
    "print(args.cmvn)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "15ebf2bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(80,)\n",
      "(80,)\n",
      "[ 9.87176362  9.93891555 10.23818678 10.85971412 11.68652649 12.2548801\n",
      " 12.65768161 12.86138996 12.80733912 12.56625574 12.32007066 12.13879205\n",
      " 12.31318868 12.55255216 12.61223855 12.56974526 12.38972728 12.14383338\n",
      " 12.09285066 11.79395822 11.62259065 11.9263303  11.8154422  11.95122567\n",
      " 11.83180553 11.88788759 11.79014437 11.88072035 11.90005711 11.97348142\n",
      " 12.00982189 12.00881339 12.02619706 12.10479646 12.21555081 12.34399304\n",
      " 12.45014401 12.4966879  12.48653775 12.3550783  12.39291732 12.2553737\n",
      " 12.26496277 12.25314244 12.32545763 12.43359839 12.54867439 12.6763342\n",
      " 12.80920698 12.92934681 12.96115138 12.96883353 12.99593057 13.04728142\n",
      " 13.0588804  13.05737948 12.99921175 12.93402238 12.87429219 12.71652995\n",
      " 12.48942004 12.27478385 12.26163069 12.28631891 12.31956049 12.4229073\n",
      " 12.51480191 12.5785164  12.64719411 12.73762568 12.80017069 12.86872766\n",
      " 12.96666856 13.06478583 13.15915908 13.27284306 13.31081821 13.23904279\n",
      " 12.87936075 11.18310185]\n",
      "[0.61219383 0.49700994 0.33439025 0.31503119 0.29640823 0.28411759\n",
      " 0.26972922 0.25610475 0.24632936 0.24610228 0.24733299 0.24426536\n",
      " 0.23751781 0.22987273 0.22659963 0.2268427  0.23059031 0.23420722\n",
      " 0.23771761 0.2411352  0.24404673 0.24557175 0.24724932 0.25055198\n",
      " 0.25482755 0.2602407  0.26363878 0.26503898 0.2648467  0.26435072\n",
      " 0.26353625 0.26364794 0.26411054 0.26339948 0.26212082 0.26146597\n",
      " 0.26196556 0.26365859 0.26592959 0.26963884 0.27392766 0.27818809\n",
      " 0.28313664 0.2863325  0.28713431 0.28649323 0.28636648 0.2867843\n",
      " 0.28635904 0.28562022 0.28492711 0.28429201 0.28402977 0.28401045\n",
      " 0.28560797 0.28728033 0.28969549 0.29351627 0.29826453 0.30572631\n",
      " 0.31811682 0.32887739 0.33288219 0.33326245 0.33014147 0.32403202\n",
      " 0.31903576 0.31316258 0.30741037 0.30370692 0.30204833 0.30049064\n",
      " 0.29901079 0.29824511 0.29812308 0.29753329 0.29779342 0.30175296\n",
      " 0.30955538 0.32904205]\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import math\n",
    "import numpy as np\n",
    "def _load_json_cmvn(json_cmvn_file):\n",
    "    \"\"\" Read accumulated cmvn statistics from a json file and normalise them\n",
    "\n",
    "    Args:\n",
    "        json_cmvn_file: cmvn stats file in json format, holding the\n",
    "            'mean_stat', 'var_stat' and 'frame_num' entries\n",
    "\n",
    "    Returns:\n",
    "        a numpy array of [means, inverse standard deviations]\n",
    "    \"\"\"\n",
    "    with open(json_cmvn_file) as stats_file:\n",
    "        stats = json.load(stats_file)\n",
    "\n",
    "    mean_acc = stats['mean_stat']\n",
    "    square_acc = stats['var_stat']\n",
    "    frame_count = stats['frame_num']\n",
    "    # Overwrite the accumulators in place: sum -> mean, and\n",
    "    # sum of squares -> 1 / standard deviation.\n",
    "    for idx, total in enumerate(mean_acc):\n",
    "        mu = total / frame_count\n",
    "        var = square_acc[idx] / frame_count - mu * mu\n",
    "        # Floor the variance so the reciprocal square root stays finite.\n",
    "        var = 1.0e-20 if var < 1.0e-20 else var\n",
    "        mean_acc[idx] = mu\n",
    "        square_acc[idx] = 1.0 / math.sqrt(var)\n",
    "    return np.array([mean_acc, square_acc])\n",
    "\n",
    "\n",
    "def _load_kaldi_cmvn(kaldi_cmvn_file):\n",
    "    \"\"\" Load the kaldi format cmvn stats file and calculate cmvn\n",
    "\n",
    "    Args:\n",
    "        kaldi_cmvn_file:  kaldi text style global cmvn file, which\n",
    "           is generated by:\n",
    "           compute-cmvn-stats --binary=false scp:feats.scp global_cmvn\n",
    "\n",
    "    Returns:\n",
    "        a numpy array of [means, inverse standard deviations]\n",
    "\n",
    "    Raises:\n",
    "        SystemExit: (status 1) if the file is in kaldi binary format.\n",
    "        AssertionError: if the tokens are not laid out as '[ ... 0 ]'.\n",
    "    \"\"\"\n",
    "    # Check the header in binary mode: a kaldi binary file starts with\n",
    "    # '\\0B', and decoding its payload as text may fail before the marker\n",
    "    # can be compared.\n",
    "    with open(kaldi_cmvn_file, 'rb') as fid:\n",
    "        is_binary = fid.read(2) == b'\\0B'\n",
    "    if is_binary:\n",
    "        logging.error('kaldi cmvn binary file is not supported, please '\n",
    "                      'recompute it by: compute-cmvn-stats --binary=false '\n",
    "                      ' scp:feats.scp global_cmvn')\n",
    "        # Same effect as sys.exit(1), without needing `sys` in scope.\n",
    "        raise SystemExit(1)\n",
    "\n",
    "    with open(kaldi_cmvn_file, 'r') as fid:\n",
    "        arr = fid.read().split()\n",
    "\n",
    "    # Token layout: [ <D sums> <count> <D sums of squares> 0 ]\n",
    "    assert (arr[0] == '[')\n",
    "    assert (arr[-2] == '0')\n",
    "    assert (arr[-1] == ']')\n",
    "    feat_dim = int((len(arr) - 2 - 2) / 2)\n",
    "    means = [float(tok) for tok in arr[1:feat_dim + 1]]\n",
    "    count = float(arr[feat_dim + 1])\n",
    "    variance = [float(tok) for tok in arr[feat_dim + 2:2 * feat_dim + 2]]\n",
    "\n",
    "    for i in range(len(means)):\n",
    "        means[i] /= count\n",
    "        variance[i] = variance[i] / count - means[i] * means[i]\n",
    "        # Floor the variance so the reciprocal square root stays finite.\n",
    "        if variance[i] < 1.0e-20:\n",
    "            variance[i] = 1.0e-20\n",
    "        variance[i] = 1.0 / math.sqrt(variance[i])\n",
    "    cmvn = np.array([means, variance])\n",
    "    return cmvn\n",
    "\n",
    "\n",
    "def _load_npz_cmvn(npz_cmvn_file, eps=1e-20):\n",
    "    \"\"\" Load cmvn statistics from a numpy npz archive\n",
    "\n",
    "    Args:\n",
    "        npz_cmvn_file: archive providing the \"mean\" and \"std\" arrays\n",
    "        eps: lower bound applied to std before it is inverted\n",
    "\n",
    "    Returns:\n",
    "        a numpy array of [means, inverse standard deviations]\n",
    "    \"\"\"\n",
    "    archive = np.load(npz_cmvn_file)\n",
    "    # Both arrays are presumably shaped (1, D) -- TODO confirm against the writer.\n",
    "    mean_arr = archive[\"mean\"]\n",
    "    std_arr = archive[\"std\"]\n",
    "    # Clip before inverting so a zero std cannot cause a division by zero.\n",
    "    inv_std = 1.0 / np.clip(std_arr, eps, None)\n",
    "    return np.array([mean_arr, inv_std])\n",
    "\n",
    "\n",
    "def load_cmvn(cmvn_file: str, filetype: str):\n",
    "    \"\"\"load cmvn from file.\n",
    "\n",
    "    Args:\n",
    "        cmvn_file (str): cmvn path.\n",
    "        filetype (str): file type, one of npz, json, kaldi\n",
    "            (matched case-insensitively).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: file type not support.\n",
    "\n",
    "    Returns:\n",
    "        Tuple[np.ndarray, np.ndarray]: mean, istd\n",
    "    \"\"\"\n",
    "    # Lower-case before dispatching so that e.g. 'JSON' is accepted; anything\n",
    "    # unrecognised falls through to the documented ValueError.\n",
    "    filetype = filetype.lower()\n",
    "    if filetype == \"json\":\n",
    "        cmvn = _load_json_cmvn(cmvn_file)\n",
    "    elif filetype == \"kaldi\":\n",
    "        cmvn = _load_kaldi_cmvn(cmvn_file)\n",
    "    elif filetype == \"npz\":\n",
    "        cmvn = _load_npz_cmvn(cmvn_file)\n",
    "    else:\n",
    "        raise ValueError(f\"cmvn file type no support: {filetype}\")\n",
    "    # Row 0 holds the means, row 1 the inverse standard deviations.\n",
    "    return cmvn[0], cmvn[1]\n",
    "\n",
    "# Load the json stats named by args.cmvn, then show both shapes followed\n",
    "# by both arrays, one item per line.\n",
    "mean, istd = load_cmvn(args.cmvn, 'json')\n",
    "print(mean.shape, istd.shape, mean, istd, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3cfa5e23",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ASRModel(\n",
      "  (encoder): ConformerEncoder(\n",
      "    (global_cmvn): GlobalCMVN()\n",
      "    (embed): Conv2dSubsampling4(\n",
      "      (conv): Sequential(\n",
      "        (0): Conv2d(1, 256, kernel_size=(3, 3), stride=(2, 2))\n",
      "        (1): ReLU()\n",
      "        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n",
      "        (3): ReLU()\n",
      "      )\n",
      "      (out): Sequential(\n",
      "        (0): Linear(in_features=4864, out_features=256, bias=True)\n",
      "      )\n",
      "      (pos_enc): RelPositionalEncoding(\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "      )\n",
      "    )\n",
      "    (after_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "    (encoders): ModuleList(\n",
      "      (0): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
      "          (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
      "          (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (1): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
      "          (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
      "          (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (2): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
      "          (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
      "          (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (3): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
      "          (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
      "          (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (4): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
      "          (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
      "          (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (5): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
      "          (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
      "          (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (6): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
      "          (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
      "          (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (7): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
      "          (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
      "          (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (8): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
      "          (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
      "          (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (9): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
      "          (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
      "          (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (10): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
      "          (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
      "          (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (11): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
      "          (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
      "          (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "          (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (decoder): TransformerDecoder(\n",
      "    (embed): Sequential(\n",
      "      (0): Embedding(4233, 256)\n",
      "      (1): PositionalEncoding(\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "      )\n",
      "    )\n",
      "    (after_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "    (output_layer): Linear(in_features=256, out_features=4233, bias=True)\n",
      "    (decoders): ModuleList(\n",
      "      (0): DecoderLayer(\n",
      "        (self_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "        )\n",
      "        (src_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): ReLU()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n",
      "        (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (1): DecoderLayer(\n",
      "        (self_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "        )\n",
      "        (src_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): ReLU()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n",
      "        (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (2): DecoderLayer(\n",
      "        (self_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "        )\n",
      "        (src_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): ReLU()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n",
      "        (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (3): DecoderLayer(\n",
      "        (self_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "        )\n",
      "        (src_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): ReLU()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n",
      "        (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (4): DecoderLayer(\n",
      "        (self_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "        )\n",
      "        (src_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): ReLU()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n",
      "        (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "      (5): DecoderLayer(\n",
      "        (self_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "        )\n",
      "        (src_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
      "          (activation): ReLU()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
      "        )\n",
      "        (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "        (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n",
      "        (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (ctc): CTC(\n",
      "    (ctc_lo): Linear(in_features=256, out_features=4233, bias=True)\n",
      "    (ctc_loss): CTCLoss()\n",
      "  )\n",
      "  (criterion_att): LabelSmoothingLoss(\n",
      "    (criterion): KLDivLoss()\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Initialize the ASR model from configs\n",
    "model = init_asr_model(configs)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3c780af5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def summary(layer, print_func=print):\n",
    "    num_params = num_elements = 0\n",
    "    for name, param in layer.state_dict().items():\n",
    "        if print_func:\n",
    "            print_func(\n",
    "                \"{} | {} | {}\".format(name, param.shape, np.prod(param.shape)))\n",
    "        num_elements += np.prod(param.shape)\n",
    "        num_params += 1\n",
    "    if print_func:\n",
    "        print_func(\n",
    "            f\"Total parameters: {num_params}, {num_elements} elements.\"\n",
    "        )\n",
    "        \n",
    "def print_params(model, print_func=print):\n",
    "    if print_func is None:\n",
    "        return\n",
    "    total = 0.0\n",
    "    num_params = 0.0\n",
    "    for n, p in model.named_parameters():\n",
    "        msg = f\"{n} | {p.shape} | {np.prod(p.shape)} | {p.requires_grad}\"\n",
    "        total += np.prod(p.shape)\n",
    "        num_params += 1\n",
    "        if print_func:\n",
    "            print_func(msg)\n",
    "    if print_func:\n",
    "        print_func(f\"Total parameters: {num_params}, {total} elements.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e159a200",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoder.global_cmvn.mean | torch.Size([80]) | 80\n",
      "encoder.global_cmvn.istd | torch.Size([80]) | 80\n",
      "encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304\n",
      "encoder.embed.conv.0.bias | torch.Size([256]) | 256\n",
      "encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824\n",
      "encoder.embed.conv.2.bias | torch.Size([256]) | 256\n",
      "encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184\n",
      "encoder.embed.out.0.bias | torch.Size([256]) | 256\n",
      "encoder.after_norm.weight | torch.Size([256]) | 256\n",
      "encoder.after_norm.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.0.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.0.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.0.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
      "encoder.encoders.0.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
      "encoder.encoders.0.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
      "encoder.encoders.0.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.conv_module.norm.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.0.conv_module.norm.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
      "encoder.encoders.0.conv_module.norm.running_var | torch.Size([256]) | 256\n",
      "encoder.encoders.0.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
      "encoder.encoders.0.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
      "encoder.encoders.0.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.norm_ff.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.0.norm_ff.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.norm_mha.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.0.norm_mha.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.0.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.norm_conv.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.0.norm_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.norm_final.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.0.norm_final.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.0.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
      "encoder.encoders.0.concat_linear.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.1.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.1.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.1.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.1.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
      "encoder.encoders.1.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
      "encoder.encoders.1.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
      "encoder.encoders.1.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.conv_module.norm.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.1.conv_module.norm.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
      "encoder.encoders.1.conv_module.norm.running_var | torch.Size([256]) | 256\n",
      "encoder.encoders.1.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
      "encoder.encoders.1.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
      "encoder.encoders.1.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.norm_ff.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.1.norm_ff.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.norm_mha.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.1.norm_mha.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.1.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.norm_conv.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.1.norm_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.norm_final.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.1.norm_final.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.1.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
      "encoder.encoders.1.concat_linear.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.2.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.2.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.2.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.2.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
      "encoder.encoders.2.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
      "encoder.encoders.2.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
      "encoder.encoders.2.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.conv_module.norm.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.2.conv_module.norm.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
      "encoder.encoders.2.conv_module.norm.running_var | torch.Size([256]) | 256\n",
      "encoder.encoders.2.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
      "encoder.encoders.2.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
      "encoder.encoders.2.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.norm_ff.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.2.norm_ff.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.norm_mha.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.2.norm_mha.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.2.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.norm_conv.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.2.norm_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.norm_final.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.2.norm_final.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.2.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
      "encoder.encoders.2.concat_linear.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.3.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.3.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.3.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.3.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
      "encoder.encoders.3.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
      "encoder.encoders.3.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
      "encoder.encoders.3.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.conv_module.norm.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.3.conv_module.norm.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
      "encoder.encoders.3.conv_module.norm.running_var | torch.Size([256]) | 256\n",
      "encoder.encoders.3.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
      "encoder.encoders.3.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
      "encoder.encoders.3.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.norm_ff.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.3.norm_ff.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.norm_mha.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.3.norm_mha.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.3.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.norm_conv.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.3.norm_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.norm_final.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.3.norm_final.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.3.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
      "encoder.encoders.3.concat_linear.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.4.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.4.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.4.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.4.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
      "encoder.encoders.4.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
      "encoder.encoders.4.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
      "encoder.encoders.4.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.conv_module.norm.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.4.conv_module.norm.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
      "encoder.encoders.4.conv_module.norm.running_var | torch.Size([256]) | 256\n",
      "encoder.encoders.4.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
      "encoder.encoders.4.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
      "encoder.encoders.4.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.norm_ff.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.4.norm_ff.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.norm_mha.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.4.norm_mha.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.4.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.norm_conv.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.4.norm_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.norm_final.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.4.norm_final.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.4.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
      "encoder.encoders.4.concat_linear.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.5.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.5.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.5.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.5.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
      "encoder.encoders.5.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
      "encoder.encoders.5.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
      "encoder.encoders.5.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.conv_module.norm.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.5.conv_module.norm.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
      "encoder.encoders.5.conv_module.norm.running_var | torch.Size([256]) | 256\n",
      "encoder.encoders.5.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
      "encoder.encoders.5.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
      "encoder.encoders.5.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.norm_ff.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.5.norm_ff.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.norm_mha.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.5.norm_mha.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.5.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.norm_conv.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.5.norm_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.norm_final.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.5.norm_final.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.5.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
      "encoder.encoders.5.concat_linear.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.6.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.6.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.6.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.6.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.6.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.6.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.6.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.6.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.6.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.6.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.6.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.6.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.6.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
      "encoder.encoders.6.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
      "encoder.encoders.6.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
      "encoder.encoders.6.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.conv_module.norm.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.6.conv_module.norm.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
      "encoder.encoders.6.conv_module.norm.running_var | torch.Size([256]) | 256\n",
      "encoder.encoders.6.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
      "encoder.encoders.6.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
      "encoder.encoders.6.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.norm_ff.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.6.norm_ff.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.norm_mha.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.6.norm_mha.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.6.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.norm_conv.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.6.norm_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.norm_final.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.6.norm_final.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.6.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
      "encoder.encoders.6.concat_linear.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.7.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.7.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.7.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.7.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.7.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.7.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.7.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.7.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.7.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.7.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.7.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.7.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.7.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
      "encoder.encoders.7.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
      "encoder.encoders.7.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
      "encoder.encoders.7.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.conv_module.norm.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.7.conv_module.norm.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
      "encoder.encoders.7.conv_module.norm.running_var | torch.Size([256]) | 256\n",
      "encoder.encoders.7.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
      "encoder.encoders.7.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
      "encoder.encoders.7.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.norm_ff.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.7.norm_ff.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.norm_mha.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.7.norm_mha.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.7.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.norm_conv.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.7.norm_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.norm_final.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.7.norm_final.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.7.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
      "encoder.encoders.7.concat_linear.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.8.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.8.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.8.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.8.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.8.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.8.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.8.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.8.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.8.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.8.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.8.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.8.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.8.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
      "encoder.encoders.8.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
      "encoder.encoders.8.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
      "encoder.encoders.8.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.conv_module.norm.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.8.conv_module.norm.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
      "encoder.encoders.8.conv_module.norm.running_var | torch.Size([256]) | 256\n",
      "encoder.encoders.8.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
      "encoder.encoders.8.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
      "encoder.encoders.8.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.norm_ff.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.8.norm_ff.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.norm_mha.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.8.norm_mha.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.8.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.norm_conv.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.8.norm_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.norm_final.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.8.norm_final.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.8.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
      "encoder.encoders.8.concat_linear.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.9.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.9.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.9.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.9.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.9.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.9.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.9.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.9.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.9.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.9.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.9.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.9.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.9.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
      "encoder.encoders.9.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
      "encoder.encoders.9.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
      "encoder.encoders.9.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.conv_module.norm.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.9.conv_module.norm.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
      "encoder.encoders.9.conv_module.norm.running_var | torch.Size([256]) | 256\n",
      "encoder.encoders.9.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
      "encoder.encoders.9.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
      "encoder.encoders.9.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.norm_ff.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.9.norm_ff.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.norm_mha.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.9.norm_mha.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.9.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.norm_conv.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.9.norm_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.norm_final.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.9.norm_final.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.9.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
      "encoder.encoders.9.concat_linear.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.10.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.10.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.10.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.10.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.10.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.10.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.10.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.10.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.10.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.10.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.10.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.10.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.10.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
      "encoder.encoders.10.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
      "encoder.encoders.10.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
      "encoder.encoders.10.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.conv_module.norm.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.10.conv_module.norm.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
      "encoder.encoders.10.conv_module.norm.running_var | torch.Size([256]) | 256\n",
      "encoder.encoders.10.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
      "encoder.encoders.10.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
      "encoder.encoders.10.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.norm_ff.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.10.norm_ff.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.norm_mha.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.10.norm_mha.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.10.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.norm_conv.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.10.norm_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.norm_final.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.10.norm_final.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.10.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
      "encoder.encoders.10.concat_linear.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.11.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
      "encoder.encoders.11.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.11.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.11.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.11.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.11.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
      "encoder.encoders.11.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.11.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.11.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.11.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "encoder.encoders.11.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
      "encoder.encoders.11.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "encoder.encoders.11.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
      "encoder.encoders.11.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
      "encoder.encoders.11.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
      "encoder.encoders.11.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.conv_module.norm.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.11.conv_module.norm.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
      "encoder.encoders.11.conv_module.norm.running_var | torch.Size([256]) | 256\n",
      "encoder.encoders.11.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
      "encoder.encoders.11.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
      "encoder.encoders.11.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.norm_ff.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.11.norm_ff.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.norm_mha.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.11.norm_mha.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.11.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.norm_conv.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.11.norm_conv.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.norm_final.weight | torch.Size([256]) | 256\n",
      "encoder.encoders.11.norm_final.bias | torch.Size([256]) | 256\n",
      "encoder.encoders.11.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
      "encoder.encoders.11.concat_linear.bias | torch.Size([256]) | 256\n",
      "decoder.embed.0.weight | torch.Size([4233, 256]) | 1083648\n",
      "decoder.after_norm.weight | torch.Size([256]) | 256\n",
      "decoder.after_norm.bias | torch.Size([256]) | 256\n",
      "decoder.output_layer.weight | torch.Size([4233, 256]) | 1083648\n",
      "decoder.output_layer.bias | torch.Size([4233]) | 4233\n",
      "decoder.decoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.0.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.0.src_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.0.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.0.src_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.0.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.0.src_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.0.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.0.src_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "decoder.decoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "decoder.decoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "decoder.decoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.0.norm1.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.0.norm1.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.0.norm2.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.0.norm2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.0.norm3.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.0.norm3.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.0.concat_linear1.weight | torch.Size([256, 512]) | 131072\n",
      "decoder.decoders.0.concat_linear1.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.0.concat_linear2.weight | torch.Size([256, 512]) | 131072\n",
      "decoder.decoders.0.concat_linear2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.1.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.1.src_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.1.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.1.src_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.1.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.1.src_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.1.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.1.src_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "decoder.decoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "decoder.decoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "decoder.decoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.1.norm1.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.1.norm1.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.1.norm2.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.1.norm2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.1.norm3.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.1.norm3.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.1.concat_linear1.weight | torch.Size([256, 512]) | 131072\n",
      "decoder.decoders.1.concat_linear1.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.1.concat_linear2.weight | torch.Size([256, 512]) | 131072\n",
      "decoder.decoders.1.concat_linear2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.2.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.2.src_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.2.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.2.src_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.2.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.2.src_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.2.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.2.src_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "decoder.decoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "decoder.decoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "decoder.decoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.2.norm1.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.2.norm1.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.2.norm2.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.2.norm2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.2.norm3.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.2.norm3.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.2.concat_linear1.weight | torch.Size([256, 512]) | 131072\n",
      "decoder.decoders.2.concat_linear1.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.2.concat_linear2.weight | torch.Size([256, 512]) | 131072\n",
      "decoder.decoders.2.concat_linear2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.3.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.3.src_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.3.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.3.src_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.3.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.3.src_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.3.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.3.src_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "decoder.decoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "decoder.decoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "decoder.decoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.3.norm1.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.3.norm1.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.3.norm2.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.3.norm2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.3.norm3.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.3.norm3.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.3.concat_linear1.weight | torch.Size([256, 512]) | 131072\n",
      "decoder.decoders.3.concat_linear1.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.3.concat_linear2.weight | torch.Size([256, 512]) | 131072\n",
      "decoder.decoders.3.concat_linear2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.4.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.4.src_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.4.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.4.src_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.4.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.4.src_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.4.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.4.src_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "decoder.decoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "decoder.decoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "decoder.decoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.4.norm1.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.4.norm1.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.4.norm2.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.4.norm2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.4.norm3.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.4.norm3.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.4.concat_linear1.weight | torch.Size([256, 512]) | 131072\n",
      "decoder.decoders.4.concat_linear1.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.4.concat_linear2.weight | torch.Size([256, 512]) | 131072\n",
      "decoder.decoders.4.concat_linear2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.5.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.5.src_attn.linear_q.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.5.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.5.src_attn.linear_k.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.5.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.5.src_attn.linear_v.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.5.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
      "decoder.decoders.5.src_attn.linear_out.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
      "decoder.decoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
      "decoder.decoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
      "decoder.decoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.5.norm1.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.5.norm1.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.5.norm2.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.5.norm2.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.5.norm3.weight | torch.Size([256]) | 256\n",
      "decoder.decoders.5.norm3.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.5.concat_linear1.weight | torch.Size([256, 512]) | 131072\n",
      "decoder.decoders.5.concat_linear1.bias | torch.Size([256]) | 256\n",
      "decoder.decoders.5.concat_linear2.weight | torch.Size([256, 512]) | 131072\n",
      "decoder.decoders.5.concat_linear2.bias | torch.Size([256]) | 256\n",
      "ctc.ctc_lo.weight | torch.Size([4233, 256]) | 1083648\n",
      "ctc.ctc_lo.bias | torch.Size([4233]) | 4233\n",
      "Total parameters: 701, 49355454.0 elements.\n"
     ]
    }
   ],
   "source": [
    "# Prints one `name | shape | element count` row per model tensor, followed by a\n",
    "# `Total parameters: ...` line. The rows include BatchNorm buffers such as\n",
    "# `running_mean` / `num_batches_tracked`, so the total presumably covers more\n",
    "# than trainable weights.\n",
    "# NOTE(review): `summary` is defined outside this cell -- confirm what it iterates over.\n",
    "summary(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8494c6ab",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "0648a969",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304 | True\n",
      "encoder.embed.conv.0.bias | torch.Size([256]) | 256 | True\n",
      "encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824 | True\n",
      "encoder.embed.conv.2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184 | True\n",
      "encoder.embed.out.0.bias | torch.Size([256]) | 256 | True\n",
      "encoder.after_norm.weight | torch.Size([256]) | 256 | True\n",
      "encoder.after_norm.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.0.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.0.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.0.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
      "encoder.encoders.0.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
      "encoder.encoders.0.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
      "encoder.encoders.0.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
      "encoder.encoders.0.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.norm_ff.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.norm_ff.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.norm_mha.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.norm_mha.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.norm_conv.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.norm_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.norm_final.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.norm_final.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.0.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
      "encoder.encoders.0.concat_linear.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.1.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.1.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.1.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.1.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
      "encoder.encoders.1.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
      "encoder.encoders.1.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
      "encoder.encoders.1.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
      "encoder.encoders.1.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.norm_ff.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.norm_ff.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.norm_mha.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.norm_mha.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.norm_conv.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.norm_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.norm_final.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.norm_final.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.1.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
      "encoder.encoders.1.concat_linear.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.2.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.2.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.2.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.2.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
      "encoder.encoders.2.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
      "encoder.encoders.2.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
      "encoder.encoders.2.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
      "encoder.encoders.2.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.norm_ff.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.norm_ff.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.norm_mha.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.norm_mha.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.norm_conv.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.norm_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.norm_final.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.norm_final.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.2.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
      "encoder.encoders.2.concat_linear.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.3.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.3.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.3.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.3.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
      "encoder.encoders.3.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
      "encoder.encoders.3.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
      "encoder.encoders.3.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
      "encoder.encoders.3.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.norm_ff.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.norm_ff.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.norm_mha.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.norm_mha.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.norm_conv.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.norm_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.norm_final.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.norm_final.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.3.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
      "encoder.encoders.3.concat_linear.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.4.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.4.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.4.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.4.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
      "encoder.encoders.4.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
      "encoder.encoders.4.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
      "encoder.encoders.4.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
      "encoder.encoders.4.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.norm_ff.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.norm_ff.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.norm_mha.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.norm_mha.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.norm_conv.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.norm_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.norm_final.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.norm_final.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.4.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
      "encoder.encoders.4.concat_linear.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.5.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.5.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.5.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.5.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
      "encoder.encoders.5.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
      "encoder.encoders.5.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
      "encoder.encoders.5.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
      "encoder.encoders.5.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.norm_ff.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.norm_ff.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.norm_mha.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.norm_mha.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.norm_conv.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.norm_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.norm_final.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.norm_final.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.5.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
      "encoder.encoders.5.concat_linear.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.6.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.6.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.6.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.6.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.6.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.6.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.6.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.6.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.6.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.6.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.6.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.6.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.6.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
      "encoder.encoders.6.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
      "encoder.encoders.6.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
      "encoder.encoders.6.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
      "encoder.encoders.6.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.norm_ff.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.norm_ff.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.norm_mha.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.norm_mha.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.norm_conv.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.norm_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.norm_final.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.norm_final.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.6.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
      "encoder.encoders.6.concat_linear.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.7.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.7.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.7.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.7.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.7.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.7.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.7.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.7.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.7.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.7.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.7.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.7.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.7.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
      "encoder.encoders.7.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
      "encoder.encoders.7.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
      "encoder.encoders.7.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
      "encoder.encoders.7.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.norm_ff.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.norm_ff.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.norm_mha.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.norm_mha.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.norm_conv.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.norm_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.norm_final.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.norm_final.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.7.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
      "encoder.encoders.7.concat_linear.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.8.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.8.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.8.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.8.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.8.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.8.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.8.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.8.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.8.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.8.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.8.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.8.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.8.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
      "encoder.encoders.8.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
      "encoder.encoders.8.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
      "encoder.encoders.8.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
      "encoder.encoders.8.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.norm_ff.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.norm_ff.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.norm_mha.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.norm_mha.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.norm_conv.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.norm_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.norm_final.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.norm_final.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.8.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
      "encoder.encoders.8.concat_linear.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.9.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.9.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.9.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.9.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.9.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.9.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.9.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.9.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.9.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.9.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.9.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.9.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.9.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
      "encoder.encoders.9.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
      "encoder.encoders.9.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
      "encoder.encoders.9.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
      "encoder.encoders.9.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.norm_ff.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.norm_ff.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.norm_mha.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.norm_mha.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.norm_conv.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.norm_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.norm_final.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.norm_final.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.9.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
      "encoder.encoders.9.concat_linear.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.10.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.10.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.10.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.10.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.10.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.10.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.10.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.10.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.10.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.10.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.10.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.10.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.10.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
      "encoder.encoders.10.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
      "encoder.encoders.10.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
      "encoder.encoders.10.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
      "encoder.encoders.10.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.norm_ff.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.norm_ff.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.norm_mha.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.norm_mha.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.norm_conv.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.norm_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.norm_final.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.norm_final.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.10.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
      "encoder.encoders.10.concat_linear.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.11.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
      "encoder.encoders.11.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.11.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.11.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.11.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.11.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
      "encoder.encoders.11.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.11.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.11.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.11.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "encoder.encoders.11.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "encoder.encoders.11.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "encoder.encoders.11.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
      "encoder.encoders.11.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
      "encoder.encoders.11.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
      "encoder.encoders.11.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
      "encoder.encoders.11.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.norm_ff.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.norm_ff.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.norm_mha.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.norm_mha.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.norm_conv.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.norm_conv.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.norm_final.weight | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.norm_final.bias | torch.Size([256]) | 256 | True\n",
      "encoder.encoders.11.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
      "encoder.encoders.11.concat_linear.bias | torch.Size([256]) | 256 | True\n",
      "decoder.embed.0.weight | torch.Size([4233, 256]) | 1083648 | True\n",
      "decoder.after_norm.weight | torch.Size([256]) | 256 | True\n",
      "decoder.after_norm.bias | torch.Size([256]) | 256 | True\n",
      "decoder.output_layer.weight | torch.Size([4233, 256]) | 1083648 | True\n",
      "decoder.output_layer.bias | torch.Size([4233]) | 4233 | True\n",
      "decoder.decoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.0.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.0.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.0.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.0.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "decoder.decoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "decoder.decoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "decoder.decoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.norm1.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.norm1.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.norm2.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.norm2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.norm3.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.norm3.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n",
      "decoder.decoders.0.concat_linear1.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.0.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n",
      "decoder.decoders.0.concat_linear2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.1.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.1.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.1.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.1.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "decoder.decoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "decoder.decoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "decoder.decoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.norm1.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.norm1.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.norm2.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.norm2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.norm3.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.norm3.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n",
      "decoder.decoders.1.concat_linear1.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.1.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n",
      "decoder.decoders.1.concat_linear2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.2.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.2.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.2.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.2.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "decoder.decoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "decoder.decoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "decoder.decoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.norm1.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.norm1.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.norm2.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.norm2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.norm3.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.norm3.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n",
      "decoder.decoders.2.concat_linear1.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.2.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n",
      "decoder.decoders.2.concat_linear2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.3.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.3.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.3.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.3.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "decoder.decoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "decoder.decoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "decoder.decoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.norm1.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.norm1.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.norm2.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.norm2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.norm3.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.norm3.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n",
      "decoder.decoders.3.concat_linear1.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.3.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n",
      "decoder.decoders.3.concat_linear2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.4.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.4.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.4.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.4.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "decoder.decoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "decoder.decoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "decoder.decoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.norm1.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.norm1.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.norm2.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.norm2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.norm3.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.norm3.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n",
      "decoder.decoders.4.concat_linear1.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.4.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n",
      "decoder.decoders.4.concat_linear2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.5.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.5.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.5.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
      "decoder.decoders.5.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
      "decoder.decoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
      "decoder.decoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
      "decoder.decoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.norm1.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.norm1.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.norm2.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.norm2.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.norm3.weight | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.norm3.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n",
      "decoder.decoders.5.concat_linear1.bias | torch.Size([256]) | 256 | True\n",
      "decoder.decoders.5.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n",
      "decoder.decoders.5.concat_linear2.bias | torch.Size([256]) | 256 | True\n",
      "ctc.ctc_lo.weight | torch.Size([4233, 256]) | 1083648 | True\n",
      "ctc.ctc_lo.bias | torch.Size([4233]) | 4233 | True\n",
      "Total parameters: 663.0, 49349138.0 elements.\n"
     ]
    }
   ],
   "source": [
    "print_params(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5ad6de2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['BAC009S0739W0246', 'BAC009S0727W0424', 'BAC009S0753W0412', 'BAC009S0756W0206', 'BAC009S0740W0414', 'BAC009S0728W0426', 'BAC009S0739W0214', 'BAC009S0753W0423', 'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']\n",
      "torch.Size([16, 207, 80])\n",
      "tensor([[[ 8.9946,  9.5383,  9.1916,  ..., 10.5074,  9.5633,  8.2564],\n",
      "         [ 9.7988, 10.4052,  9.2651,  ..., 10.2512,  9.5440,  8.8738],\n",
      "         [10.6891, 10.3955,  8.0535,  ...,  9.9067, 10.0649,  8.0509],\n",
      "         ...,\n",
      "         [ 9.2180,  9.6507,  8.5053,  ...,  9.6872,  8.7425,  7.9865],\n",
      "         [10.1291,  9.9352,  9.3798,  ...,  9.5639,  9.8260,  8.9795],\n",
      "         [ 9.0955,  7.1338,  9.4680,  ...,  9.4727,  9.0212,  7.4479]],\n",
      "\n",
      "        [[11.4310, 10.6719,  6.0841,  ...,  9.3827,  8.7297,  7.5316],\n",
      "         [ 9.7317,  7.8105,  7.5715,  ..., 10.0430,  9.2436,  7.3541],\n",
      "         [10.6502, 10.6006,  8.4678,  ...,  9.2814,  9.1869,  8.0703],\n",
      "         ...,\n",
      "         [ 9.0970,  9.2637,  8.0753,  ...,  8.4318,  8.3705,  8.0029],\n",
      "         [10.4617, 10.1478,  6.7693,  ...,  9.7794,  9.5775,  8.0807],\n",
      "         [ 7.7944,  5.6211,  7.9751,  ...,  9.9972,  9.8497,  8.0313]],\n",
      "\n",
      "        [[ 7.3456,  7.8964,  7.5796,  ..., 11.6310, 10.4513,  9.1236],\n",
      "         [ 8.6287,  8.4631,  7.4992,  ..., 12.4160, 10.9757,  8.9426],\n",
      "         [ 9.8314, 10.2813,  8.9724,  ..., 12.1387, 10.4017,  9.0055],\n",
      "         ...,\n",
      "         [ 7.0896,  7.4055,  6.8143,  ...,  9.3252,  9.2732,  8.3534],\n",
      "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
      "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],\n",
      "\n",
      "        ...,\n",
      "\n",
      "        [[10.9332, 10.4644,  7.7203,  ..., 10.3488,  9.3023,  7.1553],\n",
      "         [10.4499,  9.9070,  9.0293,  ...,  9.9525,  9.4141,  7.5593],\n",
      "         [10.4877,  9.8126,  9.8952,  ...,  9.5866,  9.3413,  7.7849],\n",
      "         ...,\n",
      "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
      "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
      "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],\n",
      "\n",
      "        [[ 9.9444,  9.5859,  8.2203,  ..., 11.5886, 11.0450,  8.8171],\n",
      "         [ 7.6784,  8.3224,  7.5330,  ..., 11.0551, 10.5357,  9.2746],\n",
      "         [ 8.6262,  9.6759,  9.8410,  ..., 11.3788, 10.9221,  8.9914],\n",
      "         ...,\n",
      "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
      "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
      "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],\n",
      "\n",
      "        [[ 8.1079,  7.7590,  6.7103,  ..., 12.6506, 11.4662, 11.0615],\n",
      "         [11.3803, 11.2220,  8.6589,  ..., 12.8106, 12.2222, 11.6893],\n",
      "         [10.6777,  9.9206,  8.0461,  ..., 13.5729, 12.5624, 11.1550],\n",
      "         ...,\n",
      "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
      "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
      "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])\n",
      "tensor([207, 207, 205, 205, 203, 203, 198, 197, 195, 188, 186, 186, 185, 180,\n",
      "        166, 163], dtype=torch.int32)\n",
      "tensor([[2995, 3116, 1209,  565,   -1,   -1],\n",
      "        [ 236, 1176,  331,   66, 3925, 4077],\n",
      "        [2693,  524,  234, 1145,  366,   -1],\n",
      "        [3875, 4211, 3062,  700,   -1,   -1],\n",
      "        [ 272,  987, 1134,  494, 2959,   -1],\n",
      "        [1936, 3715,  120, 2553, 2695, 2710],\n",
      "        [  25, 1149, 3930,   -1,   -1,   -1],\n",
      "        [1753, 1778, 1237,  482, 3925,  110],\n",
      "        [3703,    2,  565, 3827,   -1,   -1],\n",
      "        [1150, 2734,   10, 2478, 3490,   -1],\n",
      "        [ 426,  811,   95,  489,  144,   -1],\n",
      "        [2313, 2006,  489,  975,   -1,   -1],\n",
      "        [3702, 3414,  205, 1488, 2966, 1347],\n",
      "        [  70, 1741,  702, 1666,   -1,   -1],\n",
      "        [ 703, 1778, 1030,  849,   -1,   -1],\n",
      "        [ 814, 1674,  115, 3827,   -1,   -1]], dtype=torch.int32)\n",
      "tensor([4, 6, 5, 4, 5, 6, 3, 6, 4, 5, 5, 4, 6, 4, 4, 4], dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "for batch in cv_data_loader:\n",
    "    keys, feat, text, feat_len, text_len = batch\n",
    "    print(keys)\n",
    "    print(feat.shape)\n",
    "    print(feat)\n",
    "    print(feat_len)\n",
    "    print(text)\n",
    "    print(text_len)\n",
    "    np.savez('data.npz', keys=keys, feat=feat.numpy(), feat_len=feat_len.numpy(), text=text.numpy(), text_len=text_len.numpy())\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "852a9c95",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CODE_OF_CONDUCT.md  data.npz  install.sh  README.md\t    tools\r\n",
      "CONTRIBUTING.md     docs      LICENSE\t  requirements.txt  venv\r\n",
      "CPPLINT.cfg\t    examples  Makefile\t  runtime\t    wenet\r\n"
     ]
    }
   ],
   "source": [
    "!ls\n",
    "!cp data.npz /workspace/DeepSpeech-2.x/.notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "cde24c4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(111.9988)\n",
      "tensor(830.9634, grad_fn=<SumBackward0>)\n",
      "tensor([False, False, False, False, False,  True,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "         True, False, False, False, False, False,  True,  True, False, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True,  True,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True,  True, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False,  True, False, False, False,\n",
      "        False, False,  True,  True, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True,  True, False, False,\n",
      "        False, False, False,  True,  True, False, False, False, False, False,\n",
      "         True,  True])\n",
      "tensor(669.4633, grad_fn=<SumBackward0>)\n",
      "tensor(142.4888, grad_fn=<AddBackward0>) tensor(41.8415, grad_fn=<DivBackward0>) tensor(377.3326, grad_fn=<DivBackward0>)\n"
     ]
    }
   ],
   "source": [
    "model.cpu().eval()\n",
    "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n",
    "                                         text, text_len)\n",
    "print(total_loss, attention_loss, ctc_loss )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "be5b2a2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cpu\n"
     ]
    }
   ],
   "source": [
    "print(total_loss.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "5b791771",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(112., device='cuda:0')\n",
      "tensor(830.9634, device='cuda:0', grad_fn=<SumBackward0>)\n",
      "tensor([False, False, False, False, False,  True,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "         True, False, False, False, False, False,  True,  True, False, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True,  True,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True,  True, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False,  True, False, False, False,\n",
      "        False, False,  True,  True, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True,  True, False, False,\n",
      "        False, False, False,  True,  True, False, False, False, False, False,\n",
      "         True,  True], device='cuda:0')\n",
      "tensor(669.4634, device='cuda:0', grad_fn=<SumBackward0>)\n",
      "cuda:0\n",
      "142.4888 41.84146 377.33258\n"
     ]
    }
   ],
   "source": [
    "model.cuda().eval()\n",
    "feat=feat.cuda()\n",
    "feat_len=feat_len.cuda()\n",
    "text=text.cuda()\n",
    "text_len=text_len.cuda()\n",
    "\n",
    "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n",
    "                                         text, text_len)\n",
    "print(total_loss.device)\n",
    "print(total_loss.cpu().data.numpy(), attention_loss.cpu().data.numpy(), ctc_loss.cpu().data.numpy() )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1baef537",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 51, 256])\n",
      "torch.Size([16, 1, 51])\n",
      "tensor([[-0.7019,  0.5625,  0.6880,  ...,  1.1237,  0.7804,  1.1369],\n",
      "        [-0.7788,  0.3913,  0.7189,  ...,  1.2519,  0.8862,  1.3173],\n",
      "        [-0.9591,  0.6346,  0.8767,  ...,  0.9818,  0.7440,  1.2903],\n",
      "        ...,\n",
      "        [-1.0732,  0.6724,  0.9230,  ...,  0.9075,  0.8177,  1.3240],\n",
      "        [-1.1654,  0.6820,  0.6939,  ...,  1.2238,  0.8028,  1.4507],\n",
      "        [-1.2732,  0.7146,  0.7582,  ...,  0.9415,  0.8775,  1.2623]],\n",
      "       device='cuda:0', grad_fn=<SelectBackward>)\n"
     ]
    }
   ],
   "source": [
    "encoder_out, encoder_mask = model.encoder(feat, feat_len)\n",
    "print(encoder_out.shape)\n",
    "print(encoder_mask.shape)\n",
    "print(encoder_out[0])\n",
    "\n",
    "np.savez('/workspace/DeepSpeech-2.x/.notebook/encoder.npz',\n",
    "         mask=encoder_mask.cpu().detach().numpy(),  \n",
    "         out=encoder_out.cpu().detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e22c782",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "30b6b946",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 9.871763   9.938915  10.238187  10.8597145 11.686526  12.25488\n",
      " 12.657681  12.86139   12.807339  12.566256  12.32007   12.138792\n",
      " 12.313189  12.552552  12.612239  12.569745  12.389728  12.143833\n",
      " 12.092851  11.793959  11.622591  11.926331  11.815442  11.951225\n",
      " 11.831805  11.887888  11.790144  11.88072   11.900057  11.973481\n",
      " 12.009822  12.008814  12.026197  12.104796  12.21555   12.343993\n",
      " 12.450144  12.496688  12.486538  12.355079  12.392918  12.255374\n",
      " 12.264963  12.253142  12.325458  12.4335985 12.548675  12.676334\n",
      " 12.809207  12.929347  12.961151  12.968834  12.995931  13.047281\n",
      " 13.058881  13.05738   12.999211  12.934022  12.874292  12.71653\n",
      " 12.48942   12.274784  12.261631  12.286319  12.31956   12.422907\n",
      " 12.514802  12.578516  12.647194  12.737626  12.800171  12.868728\n",
      " 12.966668  13.064786  13.159159  13.272843  13.310819  13.239043\n",
      " 12.879361  11.183102 ] float32\n",
      "encoder.embed.out.0.weight: (256, 4864) -> (4864, 256)\n",
      "encoder.encoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.0.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.0.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.0.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.0.conv_module.norm.running_mean -> encoder.encoders.0.conv_module.norm._mean\n",
      "encoder.encoders.0.conv_module.norm.running_var -> encoder.encoders.0.conv_module.norm._variance\n",
      "encoder.encoders.0.concat_linear.weight: (256, 512) -> (512, 256)\n",
      "encoder.encoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.1.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.1.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.1.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.1.conv_module.norm.running_mean -> encoder.encoders.1.conv_module.norm._mean\n",
      "encoder.encoders.1.conv_module.norm.running_var -> encoder.encoders.1.conv_module.norm._variance\n",
      "encoder.encoders.1.concat_linear.weight: (256, 512) -> (512, 256)\n",
      "encoder.encoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.2.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.2.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.2.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.2.conv_module.norm.running_mean -> encoder.encoders.2.conv_module.norm._mean\n",
      "encoder.encoders.2.conv_module.norm.running_var -> encoder.encoders.2.conv_module.norm._variance\n",
      "encoder.encoders.2.concat_linear.weight: (256, 512) -> (512, 256)\n",
      "encoder.encoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.3.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.3.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.3.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.3.conv_module.norm.running_mean -> encoder.encoders.3.conv_module.norm._mean\n",
      "encoder.encoders.3.conv_module.norm.running_var -> encoder.encoders.3.conv_module.norm._variance\n",
      "encoder.encoders.3.concat_linear.weight: (256, 512) -> (512, 256)\n",
      "encoder.encoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.4.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.4.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.4.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.4.conv_module.norm.running_mean -> encoder.encoders.4.conv_module.norm._mean\n",
      "encoder.encoders.4.conv_module.norm.running_var -> encoder.encoders.4.conv_module.norm._variance\n",
      "encoder.encoders.4.concat_linear.weight: (256, 512) -> (512, 256)\n",
      "encoder.encoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.5.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.5.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.5.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.5.conv_module.norm.running_mean -> encoder.encoders.5.conv_module.norm._mean\n",
      "encoder.encoders.5.conv_module.norm.running_var -> encoder.encoders.5.conv_module.norm._variance\n",
      "encoder.encoders.5.concat_linear.weight: (256, 512) -> (512, 256)\n",
      "encoder.encoders.6.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.6.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.6.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.6.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.6.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.6.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.6.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.6.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.6.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.6.conv_module.norm.running_mean -> encoder.encoders.6.conv_module.norm._mean\n",
      "encoder.encoders.6.conv_module.norm.running_var -> encoder.encoders.6.conv_module.norm._variance\n",
      "encoder.encoders.6.concat_linear.weight: (256, 512) -> (512, 256)\n",
      "encoder.encoders.7.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.7.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.7.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.7.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.7.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.7.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.7.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.7.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.7.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.7.conv_module.norm.running_mean -> encoder.encoders.7.conv_module.norm._mean\n",
      "encoder.encoders.7.conv_module.norm.running_var -> encoder.encoders.7.conv_module.norm._variance\n",
      "encoder.encoders.7.concat_linear.weight: (256, 512) -> (512, 256)\n",
      "encoder.encoders.8.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.8.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.8.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.8.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.8.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.8.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.8.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.8.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.8.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.8.conv_module.norm.running_mean -> encoder.encoders.8.conv_module.norm._mean\n",
      "encoder.encoders.8.conv_module.norm.running_var -> encoder.encoders.8.conv_module.norm._variance\n",
      "encoder.encoders.8.concat_linear.weight: (256, 512) -> (512, 256)\n",
      "encoder.encoders.9.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.9.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.9.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.9.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.9.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.9.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.9.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.9.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.9.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.9.conv_module.norm.running_mean -> encoder.encoders.9.conv_module.norm._mean\n",
      "encoder.encoders.9.conv_module.norm.running_var -> encoder.encoders.9.conv_module.norm._variance\n",
      "encoder.encoders.9.concat_linear.weight: (256, 512) -> (512, 256)\n",
      "encoder.encoders.10.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.10.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.10.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.10.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.10.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.10.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.10.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.10.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.10.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.10.conv_module.norm.running_mean -> encoder.encoders.10.conv_module.norm._mean\n",
      "encoder.encoders.10.conv_module.norm.running_var -> encoder.encoders.10.conv_module.norm._variance\n",
      "encoder.encoders.10.concat_linear.weight: (256, 512) -> (512, 256)\n",
      "encoder.encoders.11.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.11.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.11.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.11.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.11.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
      "encoder.encoders.11.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.11.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.11.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "encoder.encoders.11.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "encoder.encoders.11.conv_module.norm.running_mean -> encoder.encoders.11.conv_module.norm._mean\n",
      "encoder.encoders.11.conv_module.norm.running_var -> encoder.encoders.11.conv_module.norm._variance\n",
      "encoder.encoders.11.concat_linear.weight: (256, 512) -> (512, 256)\n",
      "decoder.output_layer.weight: (4233, 256) -> (256, 4233)\n",
      "decoder.decoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.0.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.0.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.0.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.0.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "decoder.decoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "decoder.decoders.0.concat_linear1.weight: (256, 512) -> (512, 256)\n",
      "decoder.decoders.0.concat_linear2.weight: (256, 512) -> (512, 256)\n",
      "decoder.decoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.1.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.1.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.1.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.1.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "decoder.decoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "decoder.decoders.1.concat_linear1.weight: (256, 512) -> (512, 256)\n",
      "decoder.decoders.1.concat_linear2.weight: (256, 512) -> (512, 256)\n",
      "decoder.decoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.2.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.2.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.2.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.2.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "decoder.decoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "decoder.decoders.2.concat_linear1.weight: (256, 512) -> (512, 256)\n",
      "decoder.decoders.2.concat_linear2.weight: (256, 512) -> (512, 256)\n",
      "decoder.decoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.3.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.3.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.3.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.3.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "decoder.decoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "decoder.decoders.3.concat_linear1.weight: (256, 512) -> (512, 256)\n",
      "decoder.decoders.3.concat_linear2.weight: (256, 512) -> (512, 256)\n",
      "decoder.decoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.4.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.4.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.4.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.4.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "decoder.decoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "decoder.decoders.4.concat_linear1.weight: (256, 512) -> (512, 256)\n",
      "decoder.decoders.4.concat_linear2.weight: (256, 512) -> (512, 256)\n",
      "decoder.decoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.5.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.5.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.5.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.5.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
      "decoder.decoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
      "decoder.decoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
      "decoder.decoders.5.concat_linear1.weight: (256, 512) -> (512, 256)\n",
      "decoder.decoders.5.concat_linear2.weight: (256, 512) -> (512, 256)\n",
      "ctc.ctc_lo.weight: (4233, 256) -> (256, 4233)\n"
     ]
    }
   ],
   "source": [
    "# dump torch model to paddle\n",
    "import numpy as np\n",
    "state_dict = model.state_dict()\n",
    "paddle_state_dict = {}\n",
    "\n",
    "for n, p in state_dict.items():\n",
    "    name_change=True\n",
    "\n",
    "    if 'norm.running_mean' in n:\n",
    "        new_n = n.replace('norm.running_', 'norm._')\n",
    "    elif 'norm.running_var' in n:\n",
    "        new_n = n.replace('norm.running_var', 'norm._variance')\n",
    "    else:\n",
    "        name_change=False\n",
    "        new_n = n\n",
    "    \n",
    "    if name_change:\n",
    "        print(f\"{n} -> {new_n}\")\n",
    "        \n",
    "    p = p.cpu().detach().numpy()\n",
    "    if n.endswith('weight') and p.ndim == 2 and 'embed.0.weight' not in n:\n",
    "        new_p = p.T\n",
    "        print(f\"{n}: {p.shape} -> {new_p.shape}\")\n",
    "    else:\n",
    "        new_p = p\n",
    "        \n",
    "    if 'global_cmvn.mean' in n:\n",
    "        print(p, p.dtype)\n",
    "        \n",
    "    paddle_state_dict[new_n] = new_p\n",
    "    \n",
    "np.savez('/workspace/DeepSpeech-2.x/.notebook/model',\n",
    "       state=paddle_state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7307dc5b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "d99b29bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(377.3326, device='cuda:0', grad_fn=<DivBackward0>)\n",
      "None\n",
      "[[ 3.16902351e+00 -1.51765049e-02  4.91097234e-02 ... -2.47973716e-03\n",
      "  -5.93366381e-03 -7.26613170e-03]\n",
      " [-1.74185038e+00  7.75875803e-03 -4.49435972e-02 ...  9.92415240e-04\n",
      "   2.46338220e-03  2.31891591e-03]\n",
      " [-2.33343077e+00  1.30476682e-02 -2.66557615e-02 ...  2.27533933e-03\n",
      "   5.76929189e-03  7.48792710e-03]\n",
      " ...\n",
      " [-4.30356789e+00  2.46056803e-02 -9.00955945e-02 ...  4.43160534e-03\n",
      "   1.16123557e-02  1.44716976e-02]\n",
      " [-3.36919212e+00  1.73155665e-02 -6.36875406e-02 ...  3.28367390e-03\n",
      "   8.58021621e-03  1.07796099e-02]\n",
      " [-6.62039661e+00  3.49958315e-02 -1.23963736e-01 ...  6.36674836e-03\n",
      "   1.60815325e-02  2.03892551e-02]]\n",
      "[-4.3777566e+00  2.3245990e-02 -9.3339972e-02 ...  4.2569702e-03\n",
      "  1.0920014e-02  1.3787906e-02]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-21-622603015bd4>:6: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.\n",
      "  print(loss_ctc.grad)\n"
     ]
    }
   ],
   "source": [
    "encoder_out_lens = encoder_mask.squeeze(1).sum(1)\n",
    "loss_ctc = model.ctc(encoder_out, encoder_out_lens, text, text_len)\n",
    "print(loss_ctc)\n",
    "dir(loss_ctc)\n",
    "loss_ctc.backward()\n",
    "print(loss_ctc.grad)\n",
    "#print(model.ctc.ctc_lo.weight.grad)\n",
    "print(model.ctc.ctc_lo.weight.grad.T.cpu().data.numpy())\n",
    "print(model.ctc.ctc_lo.bias.grad.cpu().data.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "49b05d6d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(112., device='cuda:0')\n",
      "tensor(830.9634, device='cuda:0', grad_fn=<SumBackward0>)\n",
      "tensor([False, False, False, False, False,  True,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "         True, False, False, False, False, False,  True,  True, False, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True,  True,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True,  True, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False,  True, False, False, False,\n",
      "        False, False,  True,  True, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True,  True, False, False,\n",
      "        False, False, False,  True,  True, False, False, False, False, False,\n",
      "         True,  True], device='cuda:0')\n",
      "tensor(669.4634, device='cuda:0', grad_fn=<SumBackward0>)\n",
      "tensor(41.8415, device='cuda:0', grad_fn=<DivBackward0>) 0.0\n"
     ]
    }
   ],
   "source": [
    "loss_att, acc_att = model._calc_att_loss(encoder_out, encoder_mask,\n",
    "                                                    text, text_len)\n",
    "print(loss_att, acc_att)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "413b413f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_list(xs, pad_value: int):\n",
    "    n_batch = len(xs)\n",
    "    max_len = max([x.size(0) for x in xs])\n",
    "    pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)\n",
    "    pad = pad.fill_(pad_value)\n",
    "    for i in range(n_batch):\n",
    "        pad[i, :xs[i].size(0)] = xs[i]\n",
    "\n",
    "    return pad\n",
    "\n",
    "def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,\n",
    "                ignore_id: int):\n",
    "\n",
    "    _sos = torch.tensor([sos],\n",
    "                        dtype=torch.long,\n",
    "                        requires_grad=False,\n",
    "                        device=ys_pad.device)\n",
    "    _eos = torch.tensor([eos],\n",
    "                        dtype=torch.long,\n",
    "                        requires_grad=False,\n",
    "                        device=ys_pad.device)\n",
    "    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys\n",
    "    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]\n",
    "    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]\n",
    "    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "ff0c2400",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[4232, 2995, 3116, 1209,  565, 4232, 4232],\n",
      "        [4232,  236, 1176,  331,   66, 3925, 4077],\n",
      "        [4232, 2693,  524,  234, 1145,  366, 4232],\n",
      "        [4232, 3875, 4211, 3062,  700, 4232, 4232],\n",
      "        [4232,  272,  987, 1134,  494, 2959, 4232],\n",
      "        [4232, 1936, 3715,  120, 2553, 2695, 2710],\n",
      "        [4232,   25, 1149, 3930, 4232, 4232, 4232],\n",
      "        [4232, 1753, 1778, 1237,  482, 3925,  110],\n",
      "        [4232, 3703,    2,  565, 3827, 4232, 4232],\n",
      "        [4232, 1150, 2734,   10, 2478, 3490, 4232],\n",
      "        [4232,  426,  811,   95,  489,  144, 4232],\n",
      "        [4232, 2313, 2006,  489,  975, 4232, 4232],\n",
      "        [4232, 3702, 3414,  205, 1488, 2966, 1347],\n",
      "        [4232,   70, 1741,  702, 1666, 4232, 4232],\n",
      "        [4232,  703, 1778, 1030,  849, 4232, 4232],\n",
      "        [4232,  814, 1674,  115, 3827, 4232, 4232]], device='cuda:0')\n",
      "tensor([[2995, 3116, 1209,  565, 4232,   -1,   -1],\n",
      "        [ 236, 1176,  331,   66, 3925, 4077, 4232],\n",
      "        [2693,  524,  234, 1145,  366, 4232,   -1],\n",
      "        [3875, 4211, 3062,  700, 4232,   -1,   -1],\n",
      "        [ 272,  987, 1134,  494, 2959, 4232,   -1],\n",
      "        [1936, 3715,  120, 2553, 2695, 2710, 4232],\n",
      "        [  25, 1149, 3930, 4232,   -1,   -1,   -1],\n",
      "        [1753, 1778, 1237,  482, 3925,  110, 4232],\n",
      "        [3703,    2,  565, 3827, 4232,   -1,   -1],\n",
      "        [1150, 2734,   10, 2478, 3490, 4232,   -1],\n",
      "        [ 426,  811,   95,  489,  144, 4232,   -1],\n",
      "        [2313, 2006,  489,  975, 4232,   -1,   -1],\n",
      "        [3702, 3414,  205, 1488, 2966, 1347, 4232],\n",
      "        [  70, 1741,  702, 1666, 4232,   -1,   -1],\n",
      "        [ 703, 1778, 1030,  849, 4232,   -1,   -1],\n",
      "        [ 814, 1674,  115, 3827, 4232,   -1,   -1]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "ys_pad = text\n",
    "ys_pad_lens = text_len\n",
    "ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,\n",
    "                                            model.ignore_id)\n",
    "ys_in_lens = ys_pad_lens + 1\n",
    "print(ys_in_pad)\n",
    "print(ys_out_pad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "3e84da38",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 7, 4233])\n",
      "tensor([[-3.7639e-01, -8.2272e-01,  7.4276e-01,  ...,  3.4201e-01,\n",
      "          1.5035e-02,  4.0337e-01],\n",
      "        [-8.7386e-01, -3.1389e-01,  4.1988e-01,  ...,  3.7724e-01,\n",
      "         -1.4353e-01, -1.0024e+00],\n",
      "        [-4.3505e-01,  3.4505e-02, -2.8710e-01,  ...,  7.7274e-02,\n",
      "         -1.1672e+00, -2.6849e-01],\n",
      "        ...,\n",
      "        [ 4.2471e-01,  5.8886e-01,  2.0204e-02,  ...,  3.7405e-01,\n",
      "          4.5470e-02, -3.7139e-01],\n",
      "        [-3.7978e-01, -8.1084e-01,  7.5725e-01,  ...,  2.6039e-01,\n",
      "         -7.9347e-04,  4.2538e-01],\n",
      "        [-3.8280e-01, -8.1207e-01,  7.4943e-01,  ...,  2.6173e-01,\n",
      "         -1.0499e-03,  4.2679e-01]], device='cuda:0', grad_fn=<SelectBackward>)\n"
     ]
    }
   ],
   "source": [
    "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n",
    "                                      ys_in_lens)\n",
    "print(decoder_out.shape)\n",
    "print(decoder_out[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aac441ea",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "5ddbca73",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.float32\n",
      "torch.int64\n",
      "tensor(112., device='cuda:0')\n",
      "tensor(830.9634, device='cuda:0', grad_fn=<SumBackward0>)\n",
      "tensor([False, False, False, False, False,  True,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "         True, False, False, False, False, False,  True,  True, False, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True,  True,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True,  True, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False,  True, False, False, False,\n",
      "        False, False,  True,  True, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True,  True, False, False,\n",
      "        False, False, False,  True,  True, False, False, False, False, False,\n",
      "         True,  True], device='cuda:0')\n",
      "tensor(669.4634, device='cuda:0', grad_fn=<SumBackward0>)\n",
      "tensor(41.8415, device='cuda:0', grad_fn=<DivBackward0>)\n",
      "tensor([[2995, 3116, 1209,  565, 4232,   -1,   -1],\n",
      "        [ 236, 1176,  331,   66, 3925, 4077, 4232],\n",
      "        [2693,  524,  234, 1145,  366, 4232,   -1],\n",
      "        [3875, 4211, 3062,  700, 4232,   -1,   -1],\n",
      "        [ 272,  987, 1134,  494, 2959, 4232,   -1],\n",
      "        [1936, 3715,  120, 2553, 2695, 2710, 4232],\n",
      "        [  25, 1149, 3930, 4232,   -1,   -1,   -1],\n",
      "        [1753, 1778, 1237,  482, 3925,  110, 4232],\n",
      "        [3703,    2,  565, 3827, 4232,   -1,   -1],\n",
      "        [1150, 2734,   10, 2478, 3490, 4232,   -1],\n",
      "        [ 426,  811,   95,  489,  144, 4232,   -1],\n",
      "        [2313, 2006,  489,  975, 4232,   -1,   -1],\n",
      "        [3702, 3414,  205, 1488, 2966, 1347, 4232],\n",
      "        [  70, 1741,  702, 1666, 4232,   -1,   -1],\n",
      "        [ 703, 1778, 1030,  849, 4232,   -1,   -1],\n",
      "        [ 814, 1674,  115, 3827, 4232,   -1,   -1]], device='cuda:0')\n",
      "tensor([[-3.7639e-01, -8.2272e-01,  7.4276e-01,  ...,  3.4201e-01,\n",
      "          1.5035e-02,  4.0337e-01],\n",
      "        [-8.7386e-01, -3.1389e-01,  4.1988e-01,  ...,  3.7724e-01,\n",
      "         -1.4353e-01, -1.0024e+00],\n",
      "        [-4.3505e-01,  3.4505e-02, -2.8710e-01,  ...,  7.7274e-02,\n",
      "         -1.1672e+00, -2.6849e-01],\n",
      "        ...,\n",
      "        [ 4.2471e-01,  5.8886e-01,  2.0204e-02,  ...,  3.7405e-01,\n",
      "          4.5470e-02, -3.7139e-01],\n",
      "        [-3.7978e-01, -8.1084e-01,  7.5725e-01,  ...,  2.6039e-01,\n",
      "         -7.9347e-04,  4.2538e-01],\n",
      "        [-3.8280e-01, -8.1207e-01,  7.4943e-01,  ...,  2.6173e-01,\n",
      "         -1.0499e-03,  4.2679e-01]], device='cuda:0', grad_fn=<SelectBackward>)\n"
     ]
    }
   ],
   "source": [
    "print(decoder_out.dtype)\n",
    "print(ys_out_pad.dtype)\n",
    "loss_att = model.criterion_att(decoder_out, ys_out_pad)\n",
    "print(loss_att)\n",
    "print(ys_out_pad)\n",
    "print(decoder_out[0])\n",
    "np.savez('/workspace/DeepSpeech-2.x/.notebook/decoder',\n",
    "       decoder_out=decoder_out.cpu().detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78f98c0b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "8d968cd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class LabelSmoothingLoss(nn.Module):\n",
    "    def __init__(self,\n",
    "                 size: int,\n",
    "                 padding_idx: int,\n",
    "                 smoothing: float,\n",
    "                 normalize_length: bool = False):\n",
    "        \"\"\"Construct an LabelSmoothingLoss object.\"\"\"\n",
    "        super(LabelSmoothingLoss, self).__init__()\n",
    "        self.criterion = nn.KLDivLoss(reduction=\"none\")\n",
    "        self.padding_idx = padding_idx\n",
    "        self.confidence = 1.0 - smoothing\n",
    "        self.smoothing = smoothing\n",
    "        self.size = size\n",
    "        self.normalize_length = normalize_length\n",
    "\n",
    "    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Compute loss between x and target.\n",
    "\n",
    "        The model outputs and data labels tensors are flatten to\n",
    "        (batch*seqlen, class) shape and a mask is applied to the\n",
    "        padding part which should not be calculated for loss.\n",
    "\n",
    "        Args:\n",
    "            x (torch.Tensor): prediction (batch, seqlen, class)\n",
    "            target (torch.Tensor):\n",
    "                target signal masked with self.padding_id (batch, seqlen)\n",
    "        Returns:\n",
    "            loss (torch.Tensor) : The KL loss, scalar float value\n",
    "        \"\"\"\n",
    "        assert x.size(2) == self.size\n",
    "        batch_size = x.size(0)\n",
    "        x = x.view(-1, self.size)\n",
    "        target = target.view(-1)\n",
    "        # use zeros_like instead of torch.no_grad() for true_dist,\n",
    "        # since no_grad() can not be exported by JIT\n",
    "        true_dist = torch.zeros_like(x)\n",
    "        true_dist.fill_(self.smoothing / (self.size - 1))\n",
    "        ignore = target == self.padding_idx  # (B,)\n",
    "        print(self.smoothing / (self.size - 1))\n",
    "        print(true_dist)\n",
    "        total = len(target) - ignore.sum().item()\n",
    "        target = target.masked_fill(ignore, 0)  # avoid -1 index\n",
    "        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)\n",
    "        print(true_dist.dtype)\n",
    "        print(true_dist.square().sum())\n",
    "        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)\n",
    "        print(kl.sum())\n",
    "        denom = total if self.normalize_length else batch_size\n",
    "        print(ignore)\n",
    "        numer= kl.masked_fill(ignore.unsqueeze(1), 0).sum()\n",
    "        print(numer)\n",
    "        return numer /denom"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "3df340ec",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.3629489603024576e-05\n",
      "tensor([[2.3629e-05, 2.3629e-05, 2.3629e-05,  ..., 2.3629e-05, 2.3629e-05,\n",
      "         2.3629e-05],\n",
      "        [2.3629e-05, 2.3629e-05, 2.3629e-05,  ..., 2.3629e-05, 2.3629e-05,\n",
      "         2.3629e-05],\n",
      "        [2.3629e-05, 2.3629e-05, 2.3629e-05,  ..., 2.3629e-05, 2.3629e-05,\n",
      "         2.3629e-05],\n",
      "        ...,\n",
      "        [2.3629e-05, 2.3629e-05, 2.3629e-05,  ..., 2.3629e-05, 2.3629e-05,\n",
      "         2.3629e-05],\n",
      "        [2.3629e-05, 2.3629e-05, 2.3629e-05,  ..., 2.3629e-05, 2.3629e-05,\n",
      "         2.3629e-05],\n",
      "        [2.3629e-05, 2.3629e-05, 2.3629e-05,  ..., 2.3629e-05, 2.3629e-05,\n",
      "         2.3629e-05]], device='cuda:0')\n",
      "torch.float32\n",
      "tensor(90.7203, device='cuda:0')\n",
      "tensor(830.9634, device='cuda:0', grad_fn=<SumBackward0>)\n",
      "tensor([False, False, False, False, False,  True,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "         True, False, False, False, False, False,  True,  True, False, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True,  True,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True,  True, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False,  True, False, False, False,\n",
      "        False, False,  True,  True, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True,  True, False, False,\n",
      "        False, False, False,  True,  True, False, False, False, False, False,\n",
      "         True,  True], device='cuda:0')\n",
      "tensor(669.4634, device='cuda:0', grad_fn=<SumBackward0>)\n",
      "tensor(41.8415, device='cuda:0', grad_fn=<DivBackward0>)\n",
      "torch.int64\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the attention criterion locally: 4233 classes, targets padded with -1,\n",
    "# 0.1 smoothing, summed loss divided by the batch size (normalize_length=False).\n",
    "criteron = LabelSmoothingLoss(size=4233,\n",
    "                              padding_idx=-1,\n",
    "                              smoothing=0.1,\n",
    "                              normalize_length=False)\n",
    "loss_att = criteron(decoder_out, ys_out_pad)\n",
    "# Show the resulting loss followed by the dtype of the targets.\n",
    "print(loss_att, ys_out_pad.dtype, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "badc410d",
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-31-e578f72cfac5>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mloss_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdecoder_out\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/workspace/wenet/venv/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    183\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    184\u001b[0m         \"\"\"\n\u001b[0;32m--> 185\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    186\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    187\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/workspace/wenet/venv/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m    123\u001b[0m         \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    124\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 125\u001b[0;31m     Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m    126\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    127\u001b[0m         allow_unreachable=True)  # allow_unreachable flag\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time."
     ]
    }
   ],
   "source": [
    "# Backpropagate the attention loss, then print the .grad of the loss and of decoder_out.\n",
    "# NOTE(review): as saved, this cell fails with 'Trying to backward through the graph a\n",
    "# second time' - the graph behind decoder_out appears to have been consumed by an earlier\n",
    "# backward(); re-run the forward pass first, or pass retain_graph=True to that earlier call.\n",
    "# NOTE(review): loss_att is produced by a division, i.e. it is not a leaf tensor, so its\n",
    "# .grad is not populated by backward() unless retain_grad() was called beforehand;\n",
    "# decoder_out.grad presumably has the same limitation - confirm.\n",
    "loss_att.backward()\n",
    "print(loss_att.grad)\n",
    "print(decoder_out.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "219eb41f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 0.0024,  0.0019, -0.1098,  ...,  0.0028,  0.0020, -1.7978],\n",
      "       device='cuda:0')\n",
      "tensor([[ 6.5052e-04,  6.4419e-05, -6.1955e-06,  ...,  9.8220e-04,\n",
      "         -2.5918e-05,  3.3754e-04],\n",
      "        [ 3.9305e-04,  4.5799e-04,  1.4362e-04,  ...,  4.6800e-04,\n",
      "          1.6911e-04,  2.7067e-04],\n",
      "        [-1.3593e-01,  5.2201e-02,  3.2895e-02,  ...,  2.4580e-02,\n",
      "          1.4590e-01, -4.6850e-02],\n",
      "        ...,\n",
      "        [ 1.0434e-03,  4.2251e-04,  6.5688e-04,  ...,  1.2144e-03,\n",
      "          2.1159e-04,  6.6838e-04],\n",
      "        [ 6.4997e-04,  4.4301e-04,  4.1550e-04,  ...,  1.0420e-03,\n",
      "          2.4114e-04,  1.5338e-04],\n",
      "        [-9.9337e-01,  5.4573e-01, -1.1371e-02,  ..., -4.3175e-01,\n",
      "         -2.7850e-01, -4.4679e-01]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Gradients stored on the decoder's output_layer: bias first, then weight.\n",
    "print(\n",
    "    model.decoder.output_layer.bias.grad,\n",
    "    model.decoder.output_layer.weight.grad,\n",
    "    sep='\\n',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "40d00a54",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[-5.3698e-01, -1.9911e-01, -3.4997e-01,  ..., -8.2428e-01,\n",
      "          -1.0265e+00, -9.6301e-01],\n",
      "         [-4.4642e-02,  2.3176e-01, -3.2539e-01,  ..., -9.0159e-01,\n",
      "          -1.0325e+00, -7.5987e-01],\n",
      "         [ 5.0035e-01,  2.2691e-01, -7.3052e-01,  ..., -1.0055e+00,\n",
      "          -8.7123e-01, -1.0306e+00],\n",
      "         ...,\n",
      "         [-4.0024e-01, -1.4325e-01, -5.7947e-01,  ..., -1.0718e+00,\n",
      "          -1.2806e+00, -1.0518e+00],\n",
      "         [ 1.5755e-01, -1.8495e-03, -2.8703e-01,  ..., -1.1090e+00,\n",
      "          -9.4519e-01, -7.2506e-01],\n",
      "         [-4.7520e-01, -1.3942e+00, -2.5754e-01,  ..., -1.1365e+00,\n",
      "          -1.1943e+00, -1.2290e+00]],\n",
      "\n",
      "        [[ 9.5454e-01,  3.6428e-01, -1.3891e+00,  ..., -1.1637e+00,\n",
      "          -1.2845e+00, -1.2015e+00],\n",
      "         [-8.5735e-02, -1.0579e+00, -8.9173e-01,  ..., -9.6441e-01,\n",
      "          -1.1255e+00, -1.2599e+00],\n",
      "         [ 4.7654e-01,  3.2887e-01, -5.9201e-01,  ..., -1.1942e+00,\n",
      "          -1.1430e+00, -1.0242e+00],\n",
      "         ...,\n",
      "         [-4.7431e-01, -3.3559e-01, -7.2326e-01,  ..., -1.4506e+00,\n",
      "          -1.3957e+00, -1.0464e+00],\n",
      "         [ 3.6113e-01,  1.0381e-01, -1.1599e+00,  ..., -1.0439e+00,\n",
      "          -1.0221e+00, -1.0208e+00],\n",
      "         [-1.2717e+00, -2.1460e+00, -7.5677e-01,  ..., -9.7822e-01,\n",
      "          -9.3785e-01, -1.0371e+00]],\n",
      "\n",
      "        [[-1.5465e+00, -1.0152e+00, -8.8901e-01,  ..., -4.8522e-01,\n",
      "          -7.5163e-01, -6.7765e-01],\n",
      "         [-7.6101e-01, -7.3352e-01, -9.1588e-01,  ..., -2.4836e-01,\n",
      "          -5.8927e-01, -7.3723e-01],\n",
      "         [-2.4714e-02,  1.7016e-01, -4.2326e-01,  ..., -3.3204e-01,\n",
      "          -7.6696e-01, -7.1652e-01],\n",
      "         ...,\n",
      "         [-1.7032e+00, -1.2591e+00, -1.1449e+00,  ..., -1.1810e+00,\n",
      "          -1.1163e+00, -9.3108e-01],\n",
      "         [-6.0434e+00, -4.9397e+00, -3.4235e+00,  ..., -3.9949e+00,\n",
      "          -3.9869e+00, -3.6797e+00],\n",
      "         [-6.0434e+00, -4.9397e+00, -3.4235e+00,  ..., -3.9949e+00,\n",
      "          -3.9869e+00, -3.6797e+00]],\n",
      "\n",
      "        ...,\n",
      "\n",
      "        [[ 6.4983e-01,  2.6117e-01, -8.4197e-01,  ..., -8.7213e-01,\n",
      "          -1.1073e+00, -1.3253e+00],\n",
      "         [ 3.5391e-01, -1.5846e-02, -4.0425e-01,  ..., -9.9173e-01,\n",
      "          -1.0727e+00, -1.1924e+00],\n",
      "         [ 3.7704e-01, -6.2785e-02, -1.1468e-01,  ..., -1.1021e+00,\n",
      "          -1.0952e+00, -1.1182e+00],\n",
      "         ...,\n",
      "         [-6.0434e+00, -4.9397e+00, -3.4235e+00,  ..., -3.9949e+00,\n",
      "          -3.9869e+00, -3.6797e+00],\n",
      "         [-6.0434e+00, -4.9397e+00, -3.4235e+00,  ..., -3.9949e+00,\n",
      "          -3.9869e+00, -3.6797e+00],\n",
      "         [-6.0434e+00, -4.9397e+00, -3.4235e+00,  ..., -3.9949e+00,\n",
      "          -3.9869e+00, -3.6797e+00]],\n",
      "\n",
      "        [[ 4.4458e-02, -1.7547e-01, -6.7475e-01,  ..., -4.9801e-01,\n",
      "          -5.6783e-01, -7.7852e-01],\n",
      "         [-1.3428e+00, -8.0343e-01, -9.0457e-01,  ..., -6.5902e-01,\n",
      "          -7.2550e-01, -6.2796e-01],\n",
      "         [-7.6253e-01, -1.3071e-01, -1.3280e-01,  ..., -5.6133e-01,\n",
      "          -6.0588e-01, -7.2115e-01],\n",
      "         ...,\n",
      "         [-6.0434e+00, -4.9397e+00, -3.4235e+00,  ..., -3.9949e+00,\n",
      "          -3.9869e+00, -3.6797e+00],\n",
      "         [-6.0434e+00, -4.9397e+00, -3.4235e+00,  ..., -3.9949e+00,\n",
      "          -3.9869e+00, -3.6797e+00],\n",
      "         [-6.0434e+00, -4.9397e+00, -3.4235e+00,  ..., -3.9949e+00,\n",
      "          -3.9869e+00, -3.6797e+00]],\n",
      "\n",
      "        [[-1.0798e+00, -1.0834e+00, -1.1797e+00,  ..., -1.7757e-01,\n",
      "          -4.3747e-01, -4.0007e-02],\n",
      "         [ 9.2354e-01,  6.3771e-01, -5.2810e-01,  ..., -1.2928e-01,\n",
      "          -2.0342e-01,  1.6656e-01],\n",
      "         [ 4.9337e-01, -9.1133e-03, -7.3302e-01,  ...,  1.0074e-01,\n",
      "          -9.8115e-02, -9.2357e-03],\n",
      "         ...,\n",
      "         [-6.0434e+00, -4.9397e+00, -3.4235e+00,  ..., -3.9949e+00,\n",
      "          -3.9869e+00, -3.6797e+00],\n",
      "         [-6.0434e+00, -4.9397e+00, -3.4235e+00,  ..., -3.9949e+00,\n",
      "          -3.9869e+00, -3.6797e+00],\n",
      "         [-6.0434e+00, -4.9397e+00, -3.4235e+00,  ..., -3.9949e+00,\n",
      "          -3.9869e+00, -3.6797e+00]]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Pass feat through the encoder's global_cmvn module and show the result.\n",
    "xs = model.encoder.global_cmvn(feat)\n",
    "print(xs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "505ca294",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ True,  True,  True,  ...,  True,  True,  True]],\n",
      "\n",
      "        [[ True,  True,  True,  ...,  True,  True,  True]],\n",
      "\n",
      "        [[ True,  True,  True,  ...,  True, False, False]],\n",
      "\n",
      "        ...,\n",
      "\n",
      "        [[ True,  True,  True,  ..., False, False, False]],\n",
      "\n",
      "        [[ True,  True,  True,  ..., False, False, False]],\n",
      "\n",
      "        [[ True,  True,  True,  ..., False, False, False]]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.\n",
    "from wenet.utils.mask import make_pad_mask\n",
    "# Invert the padding mask built from feat_len and insert a singleton axis at dim 1.\n",
    "# Presumably True then marks real frames and False marks padding - confirm against make_pad_mask.\n",
    "masks = ~make_pad_mask(feat_len).unsqueeze(1)  # (B, 1, L)\n",
    "print(masks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "aa03c2b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the encoder's embed module on xs and masks and unpack its three results.\n",
    "# NOTE(review): xs and masks are overwritten with the outputs, so this cell is not idempotent -\n",
    "# re-running it feeds the already-embedded xs back into embed. Re-run the cells that build\n",
    "# xs and masks first.\n",
    "xs, pos_emb, masks = model.encoder.embed(xs, masks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "ebc0ea12",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[-0.5482,  2.2866, -1.0750,  ...,  1.4504,  0.2895, -0.6945],\n",
      "         [-0.8013,  1.7688, -1.6639,  ...,  1.8332,  0.6791, -0.2000],\n",
      "         [-1.7112,  2.7057, -1.3363,  ...,  1.2336,  0.1870, -0.5735],\n",
      "         ...,\n",
      "         [-0.9697,  2.3129, -0.8752,  ...,  0.8584,  0.4853, -0.4177],\n",
      "         [-1.3609,  2.1779, -1.7813,  ...,  2.0928,  0.2528, -0.3650],\n",
      "         [-1.6967,  2.3544, -1.7417,  ...,  1.3670,  0.5951, -0.7415]],\n",
      "\n",
      "        [[-1.9828,  2.3178, -0.9079,  ...,  0.4117,  0.5006,  0.0872],\n",
      "         [-0.7640,  1.3558, -1.3613,  ...,  0.7317,  0.6784,  0.1685],\n",
      "         [-0.9504,  1.6038, -1.3030,  ...,  0.5754,  0.2677,  0.3343],\n",
      "         ...,\n",
      "         [-1.4757,  2.5317, -1.2321,  ...,  1.2997,  0.5019, -0.1034],\n",
      "         [-1.1731,  2.3172, -1.2542,  ...,  1.7391,  0.2171, -0.4445],\n",
      "         [-1.2700,  3.2229, -0.8872,  ...,  1.6461,  0.0973, -0.7679]],\n",
      "\n",
      "        [[-0.5873,  1.4291, -1.3950,  ...,  0.2102,  0.1027,  0.0918],\n",
      "         [ 0.1743,  1.7834, -1.6422,  ...,  0.8113,  0.3137,  0.5634],\n",
      "         [-0.3492,  1.8310, -1.0685,  ...,  0.6924,  0.1378,  0.4594],\n",
      "         ...,\n",
      "         [-1.0869,  2.3002, -1.2638,  ...,  1.7998,  0.5134, -0.5223],\n",
      "         [-1.2614,  2.7240, -1.3734,  ...,  1.4445,  0.5742, -0.3320],\n",
      "         [-2.2068,  4.3462, -3.8289,  ...,  2.1426,  1.2034, -1.3795]],\n",
      "\n",
      "        ...,\n",
      "\n",
      "        [[-0.3914,  1.8553, -0.5747,  ...,  1.0062,  0.4632, -1.0452],\n",
      "         [-0.8605,  2.0172, -1.4437,  ...,  1.4526,  0.1657,  0.5923],\n",
      "         [-0.7307,  2.2841, -1.0699,  ...,  1.5825, -0.0980,  0.5503],\n",
      "         ...,\n",
      "         [-5.0821,  8.5920, -4.2137,  ...,  6.2693,  0.0539, -2.9270],\n",
      "         [-5.0821,  8.5920, -4.2137,  ...,  6.2693,  0.0539, -2.9270],\n",
      "         [-5.0821,  8.5920, -4.2137,  ...,  6.2693,  0.0539, -2.9270]],\n",
      "\n",
      "        [[-0.1619,  0.6255, -1.1323,  ...,  0.0724, -0.2204,  0.4636],\n",
      "         [-0.0831,  0.5750, -1.0930,  ...,  0.9110, -0.0650,  0.7299],\n",
      "         [-0.2820,  0.0801, -0.9418,  ...,  0.3379, -0.1166,  0.4451],\n",
      "         ...,\n",
      "         [-5.0821,  8.5920, -4.2137,  ...,  6.2693,  0.0539, -2.9270],\n",
      "         [-5.0821,  8.5920, -4.2137,  ...,  6.2693,  0.0539, -2.9270],\n",
      "         [-5.0821,  8.5920, -4.2137,  ...,  6.2693,  0.0539, -2.9270]],\n",
      "\n",
      "        [[-0.5458, -0.6909, -1.3597,  ..., -0.7818,  0.6875,  0.9843],\n",
      "         [ 0.0421, -1.1062, -1.4389,  ..., -0.0239,  0.9115,  0.5287],\n",
      "         [-0.2909, -0.1886, -1.5487,  ..., -0.1392,  0.0580,  0.3066],\n",
      "         ...,\n",
      "         [-5.0821,  8.5920, -4.2137,  ...,  6.2693,  0.0539, -2.9270],\n",
      "         [-5.0821,  8.5920, -4.2137,  ...,  6.2693,  0.0539, -2.9270],\n",
      "         [-5.0821,  8.5920, -4.2137,  ...,  6.2693,  0.0539, -2.9270]]],\n",
      "       device='cuda:0', grad_fn=<MulBackward0>)\n",
      "tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,\n",
      "           0.0000e+00,  1.0000e+00],\n",
      "         [ 8.4147e-01,  5.4030e-01,  8.0196e-01,  ...,  1.0000e+00,\n",
      "           1.0746e-04,  1.0000e+00],\n",
      "         [ 9.0930e-01, -4.1615e-01,  9.5814e-01,  ...,  1.0000e+00,\n",
      "           2.1492e-04,  1.0000e+00],\n",
      "         ...,\n",
      "         [-7.6825e-01, -6.4014e-01,  6.3280e-01,  ...,  9.9998e-01,\n",
      "           5.1581e-03,  9.9999e-01],\n",
      "         [-9.5375e-01,  3.0059e-01,  9.9899e-01,  ...,  9.9998e-01,\n",
      "           5.2656e-03,  9.9999e-01],\n",
      "         [-2.6237e-01,  9.6497e-01,  5.6075e-01,  ...,  9.9998e-01,\n",
      "           5.3730e-03,  9.9999e-01]]], device='cuda:0')\n",
      "tensor([[[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "          False]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "          False]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True, False,\n",
      "          False]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True, False, False, False,\n",
      "          False]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True, False, False, False,\n",
      "          False]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True, False, False, False,\n",
      "          False]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True, False, False, False,\n",
      "          False]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True, False, False, False, False, False,\n",
      "          False]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True, False, False, False, False, False, False, False, False,\n",
      "          False]],\n",
      "\n",
      "        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "           True, False, False, False, False, False, False, False, False, False,\n",
      "          False]]], device='cuda:0')\n",
      "torch.Size([16, 1, 51])\n"
     ]
    }
   ],
   "source": [
    "# Show xs, pos_emb and masks in full, then the shape of masks.\n",
    "print(xs, pos_emb, masks, masks.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "4289461b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[-0.54822     2.2866027  -1.0750197  ...  1.4503604   0.28950194\n",
      "   -0.6945408 ]\n",
      "  [-0.8012542   1.7687558  -1.6638877  ...  1.833158    0.6791494\n",
      "   -0.1999542 ]\n",
      "  [-1.7112465   2.7057455  -1.3363413  ...  1.2336441   0.18697014\n",
      "   -0.5735198 ]\n",
      "  ...\n",
      "  [-0.96968573  2.312949   -0.87524825 ...  0.85838526  0.4853347\n",
      "   -0.41773027]\n",
      "  [-1.3609431   2.1778803  -1.7812773  ...  2.0927877   0.25282228\n",
      "   -0.36496443]\n",
      "  [-1.6967483   2.3543842  -1.7416853  ...  1.366951    0.59511113\n",
      "   -0.74147725]]\n",
      "\n",
      " [[-1.9828408   2.31777    -0.9078527  ...  0.41170627  0.5006162\n",
      "    0.08721463]\n",
      "  [-0.76404583  1.3557773  -1.3612567  ...  0.7317046   0.678426\n",
      "    0.16851945]\n",
      "  [-0.95044655  1.6037656  -1.3029968  ...  0.57544005  0.26769355\n",
      "    0.33433008]\n",
      "  ...\n",
      "  [-1.475677    2.531713   -1.2320715  ...  1.2996731   0.50191855\n",
      "   -0.10343577]\n",
      "  [-1.1730809   2.3172235  -1.2542105  ...  1.7391105   0.21709818\n",
      "   -0.44447583]\n",
      "  [-1.2699623   3.2228963  -0.8871915  ...  1.6460502   0.09731755\n",
      "   -0.7678688 ]]\n",
      "\n",
      " [[-0.5872559   1.4290544  -1.3950099  ...  0.21024795  0.10272825\n",
      "    0.09179455]\n",
      "  [ 0.1742807   1.783423   -1.6421788  ...  0.8112701   0.31371105\n",
      "    0.56344515]\n",
      "  [-0.34916472  1.8310343  -1.0685117  ...  0.69243336  0.13782299\n",
      "    0.45937473]\n",
      "  ...\n",
      "  [-1.0868638   2.300204   -1.2638408  ...  1.7998282   0.5133892\n",
      "   -0.52227837]\n",
      "  [-1.2614481   2.7239661  -1.3733778  ...  1.444533    0.57420933\n",
      "   -0.33201432]\n",
      "  [-2.2067683   4.346218   -3.828867   ...  2.1426017   1.2033664\n",
      "   -1.3795122 ]]\n",
      "\n",
      " ...\n",
      "\n",
      " [[-0.39141566  1.8553346  -0.5747178  ...  1.0062351   0.46320182\n",
      "   -1.045236  ]\n",
      "  [-0.86054784  2.0171793  -1.4436853  ...  1.452623    0.16571884\n",
      "    0.5923172 ]\n",
      "  [-0.73066384  2.2840502  -1.0698992  ...  1.5824941  -0.0979555\n",
      "    0.55030036]\n",
      "  ...\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]]\n",
      "\n",
      " [[-0.16194311  0.6255052  -1.1323429  ...  0.07242929 -0.22042468\n",
      "    0.46362036]\n",
      "  [-0.08306468  0.575043   -1.09298    ...  0.9109665  -0.06501988\n",
      "    0.72986233]\n",
      "  [-0.28202093  0.08014385 -0.9417719  ...  0.3379485  -0.11664233\n",
      "    0.44514441]\n",
      "  ...\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]]\n",
      "\n",
      " [[-0.5458492  -0.69092435 -1.3596548  ... -0.78182435  0.68747747\n",
      "    0.9842716 ]\n",
      "  [ 0.04212743 -1.1061852  -1.438915   ... -0.02385022  0.91146135\n",
      "    0.52870303]\n",
      "  [-0.2909345  -0.18858244 -1.5487324  ... -0.13923697  0.05795169\n",
      "    0.30663735]\n",
      "  ...\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]]]\n"
     ]
    }
   ],
   "source": [
    "# Recompute the front end from feat and feat_len in one cell, so it can be re-run safely:\n",
    "# global_cmvn, inverted padding mask, then embed with offset=0 passed explicitly.\n",
    "xs = model.encoder.global_cmvn(feat)\n",
    "masks = ~make_pad_mask(feat_len).unsqueeze(1)  # (B, 1, L)\n",
    "xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)\n",
    "# Print xs as a NumPy array (moved to CPU and detached first).\n",
    "print(xs.cpu().detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "67e10d73",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 2.0908e-03],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 1.1943e-02, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           4.6105e-02, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 9.6723e-03,\n",
      "           4.6135e-02, 0.0000e+00]],\n",
      "\n",
      "         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[2.2816e-01, 2.4615e-01, 2.5304e-01,  ..., 2.0402e-01,\n",
      "           2.3248e-01, 3.1191e-01],\n",
      "          [1.3587e-01, 2.8877e-01, 2.7991e-01,  ..., 1.9210e-01,\n",
      "           2.0346e-01, 1.9934e-01],\n",
      "          [2.5739e-01, 3.9348e-01, 2.7877e-01,  ..., 2.7483e-01,\n",
      "           1.9302e-01, 2.3810e-01],\n",
      "          ...,\n",
      "          [1.1939e-01, 2.8473e-01, 3.3082e-01,  ..., 2.3838e-01,\n",
      "           2.2104e-01, 2.3906e-01],\n",
      "          [1.7388e-01, 2.0402e-01, 4.0263e-01,  ..., 2.4782e-01,\n",
      "           2.6742e-01, 1.5427e-01],\n",
      "          [0.0000e+00, 2.9081e-01, 2.7726e-01,  ..., 1.7540e-01,\n",
      "           1.8479e-01, 2.2483e-01]],\n",
      "\n",
      "         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[3.5447e-01, 3.8861e-01, 3.9724e-01,  ..., 3.8680e-01,\n",
      "           3.3568e-01, 3.4552e-01],\n",
      "          [4.1739e-01, 5.1039e-01, 4.1730e-01,  ..., 3.3993e-01,\n",
      "           3.7082e-01, 3.5110e-01],\n",
      "          [3.6117e-01, 4.0745e-01, 4.8491e-01,  ..., 3.4849e-01,\n",
      "           3.2321e-01, 3.5189e-01],\n",
      "          ...,\n",
      "          [2.3144e-01, 3.8021e-01, 5.1526e-01,  ..., 3.6499e-01,\n",
      "           3.7412e-01, 3.9986e-01],\n",
      "          [3.4679e-01, 4.0238e-01, 5.0077e-01,  ..., 3.6185e-01,\n",
      "           3.1597e-01, 3.6335e-01],\n",
      "          [3.6498e-01, 3.7943e-01, 5.1719e-01,  ..., 3.1798e-01,\n",
      "           3.3657e-01, 3.4130e-01]]],\n",
      "\n",
      "\n",
      "        [[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[1.4560e-02, 9.4475e-02, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [1.5002e-02, 2.9632e-02, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [3.2952e-02, 0.0000e+00, 0.0000e+00,  ..., 4.5850e-02,\n",
      "           2.0439e-02, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 4.4258e-02],\n",
      "          [0.0000e+00, 0.0000e+00, 2.5565e-02,  ..., 0.0000e+00,\n",
      "           9.0044e-03, 4.9084e-02]],\n",
      "\n",
      "         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [1.1141e-01, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[3.3697e-01, 3.8527e-01, 3.2900e-01,  ..., 2.8704e-01,\n",
      "           2.3351e-01, 1.9004e-01],\n",
      "          [1.3575e-01, 3.5783e-01, 3.3573e-01,  ..., 2.2082e-01,\n",
      "           1.5855e-01, 1.3587e-01],\n",
      "          [2.1929e-01, 2.8900e-01, 2.8255e-01,  ..., 2.0603e-01,\n",
      "           2.3927e-01, 2.1909e-01],\n",
      "          ...,\n",
      "          [2.3292e-01, 3.9097e-01, 3.6399e-01,  ..., 2.0598e-01,\n",
      "           2.5374e-01, 2.3137e-01],\n",
      "          [1.8739e-01, 3.0794e-01, 3.0297e-01,  ..., 2.7251e-01,\n",
      "           2.5192e-01, 2.0837e-01],\n",
      "          [2.2454e-01, 4.1402e-01, 5.4083e-01,  ..., 3.1875e-01,\n",
      "           2.5080e-01, 2.5939e-01]],\n",
      "\n",
      "         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[2.6457e-01, 4.9519e-01, 5.6702e-01,  ..., 3.0955e-01,\n",
      "           3.5292e-01, 3.2669e-01],\n",
      "          [2.1577e-01, 5.1833e-01, 4.9183e-01,  ..., 3.6043e-01,\n",
      "           3.8524e-01, 3.6155e-01],\n",
      "          [2.0068e-01, 4.2784e-01, 5.2818e-01,  ..., 3.1871e-01,\n",
      "           3.2452e-01, 3.1036e-01],\n",
      "          ...,\n",
      "          [4.9855e-01, 5.1001e-01, 5.2279e-01,  ..., 3.6450e-01,\n",
      "           3.4338e-01, 3.3603e-01],\n",
      "          [4.1233e-01, 5.5518e-01, 5.2828e-01,  ..., 4.0676e-01,\n",
      "           3.3873e-01, 3.6724e-01],\n",
      "          [4.0820e-01, 4.6187e-01, 4.7338e-01,  ..., 3.8691e-01,\n",
      "           3.6039e-01, 3.8022e-01]]],\n",
      "\n",
      "\n",
      "        [[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[0.0000e+00, 5.7852e-03, 0.0000e+00,  ..., 7.4838e-03,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 3.0351e-02,\n",
      "           0.0000e+00, 2.6720e-04],\n",
      "          [9.4807e-04, 0.0000e+00, 0.0000e+00,  ..., 7.9551e-03,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [2.0326e-02, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           1.0801e-02, 0.0000e+00],\n",
      "          [1.8470e-01, 0.0000e+00, 0.0000e+00,  ..., 5.0584e-02,\n",
      "           9.4758e-02, 5.9146e-02]],\n",
      "\n",
      "         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[3.8708e-01, 2.8022e-01, 3.5893e-01,  ..., 1.6595e-01,\n",
      "           1.6031e-01, 2.1136e-01],\n",
      "          [1.5595e-01, 3.0544e-01, 2.4666e-01,  ..., 2.2675e-01,\n",
      "           2.5765e-01, 1.9682e-01],\n",
      "          [2.9518e-01, 4.1210e-01, 2.0063e-01,  ..., 1.7595e-01,\n",
      "           2.2537e-01, 2.2214e-01],\n",
      "          ...,\n",
      "          [2.4745e-01, 2.6259e-01, 3.8654e-01,  ..., 2.3620e-01,\n",
      "           2.3157e-01, 1.8514e-01],\n",
      "          [2.5715e-01, 2.9593e-01, 4.7745e-01,  ..., 2.3546e-01,\n",
      "           2.5073e-01, 2.0976e-01],\n",
      "          [1.2015e+00, 8.4644e-01, 7.3386e-01,  ..., 1.0252e+00,\n",
      "           9.5310e-01, 1.0013e+00]],\n",
      "\n",
      "         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[4.5013e-01, 4.7484e-01, 4.0540e-01,  ..., 1.9346e-01,\n",
      "           1.7826e-01, 1.4777e-01],\n",
      "          [4.7546e-01, 4.8187e-01, 3.6760e-01,  ..., 2.7809e-01,\n",
      "           3.2997e-01, 3.2337e-01],\n",
      "          [4.6160e-01, 4.0050e-01, 3.9061e-01,  ..., 3.6613e-01,\n",
      "           3.5243e-01, 2.9739e-01],\n",
      "          ...,\n",
      "          [5.5148e-01, 5.1018e-01, 4.0132e-01,  ..., 3.8948e-01,\n",
      "           3.5737e-01, 3.3088e-01],\n",
      "          [4.1973e-01, 4.5475e-01, 4.5320e-01,  ..., 3.8343e-01,\n",
      "           4.0126e-01, 3.6181e-01],\n",
      "          [3.4280e-01, 3.1606e-01, 4.4701e-01,  ..., 2.1665e-01,\n",
      "           2.3985e-01, 2.3903e-01]]],\n",
      "\n",
      "\n",
      "        ...,\n",
      "\n",
      "\n",
      "        [[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [4.1783e-02, 0.0000e+00, 1.5805e-02,  ..., 0.0000e+00,\n",
      "           2.2508e-02, 0.0000e+00],\n",
      "          [4.3234e-02, 7.7864e-02, 0.0000e+00,  ..., 1.6347e-02,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[3.2092e-02, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [1.3563e-01, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[0.0000e+00, 2.5187e-01, 2.4979e-01,  ..., 2.4775e-01,\n",
      "           2.2354e-01, 1.9149e-01],\n",
      "          [1.6541e-01, 1.9586e-01, 1.9813e-01,  ..., 2.7344e-01,\n",
      "           2.0928e-01, 2.6150e-01],\n",
      "          [1.0495e-01, 6.3299e-02, 3.3844e-01,  ..., 2.5138e-01,\n",
      "           1.2470e-01, 2.3927e-01],\n",
      "          ...,\n",
      "          [1.1257e+00, 8.7341e-01, 7.8169e-01,  ..., 1.0458e+00,\n",
      "           1.0094e+00, 1.0221e+00],\n",
      "          [1.1257e+00, 8.7341e-01, 7.8169e-01,  ..., 1.0458e+00,\n",
      "           1.0094e+00, 1.0221e+00],\n",
      "          [1.1257e+00, 8.7341e-01, 7.8169e-01,  ..., 1.0458e+00,\n",
      "           1.0094e+00, 1.0221e+00]],\n",
      "\n",
      "         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[1.1428e-01, 4.5667e-01, 4.6821e-01,  ..., 3.2058e-01,\n",
      "           3.3579e-01, 3.9013e-01],\n",
      "          [1.0441e-01, 4.5739e-01, 4.6107e-01,  ..., 3.8468e-01,\n",
      "           3.8291e-01, 3.6686e-01],\n",
      "          [1.9868e-01, 3.5520e-01, 4.4313e-01,  ..., 4.0679e-01,\n",
      "           3.8068e-01, 3.0646e-01],\n",
      "          ...,\n",
      "          [1.4488e+00, 1.0212e+00, 9.4473e-01,  ..., 1.2363e+00,\n",
      "           1.2189e+00, 1.2380e+00],\n",
      "          [1.4488e+00, 1.0212e+00, 9.4473e-01,  ..., 1.2363e+00,\n",
      "           1.2189e+00, 1.2380e+00],\n",
      "          [1.4488e+00, 1.0212e+00, 9.4473e-01,  ..., 1.2363e+00,\n",
      "           1.2189e+00, 1.2380e+00]]],\n",
      "\n",
      "\n",
      "        [[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[2.4654e-02, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 3.3902e-02],\n",
      "          [0.0000e+00, 0.0000e+00, 1.8307e-02,  ..., 5.1669e-02,\n",
      "           9.4838e-03, 7.4535e-02],\n",
      "          [9.9215e-02, 0.0000e+00, 1.5872e-02,  ..., 1.6203e-02,\n",
      "           5.1401e-02, 1.9239e-03],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[4.0034e-01, 2.5306e-01, 2.0218e-01,  ..., 9.8162e-02,\n",
      "           7.0643e-02, 4.9741e-02],\n",
      "          [1.2568e-01, 2.1031e-01, 1.1182e-01,  ..., 4.2781e-02,\n",
      "           1.1969e-01, 1.2005e-01],\n",
      "          [2.8787e-01, 2.4031e-01, 2.2566e-01,  ..., 0.0000e+00,\n",
      "           6.4181e-02, 5.8730e-02],\n",
      "          ...,\n",
      "          [1.1257e+00, 8.7341e-01, 7.8169e-01,  ..., 1.0458e+00,\n",
      "           1.0094e+00, 1.0221e+00],\n",
      "          [1.1257e+00, 8.7341e-01, 7.8169e-01,  ..., 1.0458e+00,\n",
      "           1.0094e+00, 1.0221e+00],\n",
      "          [1.1257e+00, 8.7341e-01, 7.8169e-01,  ..., 1.0458e+00,\n",
      "           1.0094e+00, 1.0221e+00]],\n",
      "\n",
      "         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[3.8405e-01, 3.0990e-01, 3.7156e-01,  ..., 1.8125e-01,\n",
      "           1.5051e-01, 1.9620e-01],\n",
      "          [4.7286e-01, 4.0529e-01, 3.9718e-01,  ..., 2.4710e-01,\n",
      "           4.5657e-02, 1.1501e-01],\n",
      "          [3.2621e-01, 3.0073e-01, 3.0477e-01,  ..., 2.3529e-01,\n",
      "           2.1357e-01, 1.6986e-01],\n",
      "          ...,\n",
      "          [1.4488e+00, 1.0212e+00, 9.4473e-01,  ..., 1.2363e+00,\n",
      "           1.2189e+00, 1.2380e+00],\n",
      "          [1.4488e+00, 1.0212e+00, 9.4473e-01,  ..., 1.2363e+00,\n",
      "           1.2189e+00, 1.2380e+00],\n",
      "          [1.4488e+00, 1.0212e+00, 9.4473e-01,  ..., 1.2363e+00,\n",
      "           1.2189e+00, 1.2380e+00]]],\n",
      "\n",
      "\n",
      "        [[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[3.3438e-02, 1.2378e-03, 5.2972e-02,  ..., 7.2712e-02,\n",
      "           8.6563e-02, 1.4494e-01],\n",
      "          [1.1043e-01, 6.1431e-02, 6.3630e-02,  ..., 8.1278e-02,\n",
      "           6.2590e-02, 8.3154e-02],\n",
      "          [1.7677e-02, 2.0111e-03, 7.8750e-02,  ..., 6.9633e-02,\n",
      "           8.9799e-02, 5.3263e-02],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[1.0034e-01, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [1.5627e-01, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [5.1447e-02, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 4.3641e-03],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[2.5142e-01, 4.5964e-01, 3.7346e-01,  ..., 4.7631e-02,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [1.9760e-01, 2.6627e-01, 1.1191e-01,  ..., 3.0450e-02,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [1.6341e-01, 3.2938e-01, 2.5690e-01,  ..., 5.5694e-02,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [1.1257e+00, 8.7341e-01, 7.8169e-01,  ..., 1.0458e+00,\n",
      "           1.0094e+00, 1.0221e+00],\n",
      "          [1.1257e+00, 8.7341e-01, 7.8169e-01,  ..., 1.0458e+00,\n",
      "           1.0094e+00, 1.0221e+00],\n",
      "          [1.1257e+00, 8.7341e-01, 7.8169e-01,  ..., 1.0458e+00,\n",
      "           1.0094e+00, 1.0221e+00]],\n",
      "\n",
      "         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           2.2189e-02, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 2.8490e-02],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          ...,\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00],\n",
      "          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n",
      "           0.0000e+00, 0.0000e+00]],\n",
      "\n",
      "         [[2.5810e-01, 6.3017e-01, 3.7038e-01,  ..., 1.8704e-01,\n",
      "           8.2694e-02, 9.9127e-02],\n",
      "          [1.7293e-01, 5.0679e-01, 4.0739e-01,  ..., 1.6006e-01,\n",
      "           1.1725e-01, 9.9405e-02],\n",
      "          [2.4175e-01, 4.1616e-01, 4.1257e-01,  ..., 1.3520e-01,\n",
      "           7.9126e-02, 1.2846e-01],\n",
      "          ...,\n",
      "          [1.4488e+00, 1.0212e+00, 9.4473e-01,  ..., 1.2363e+00,\n",
      "           1.2189e+00, 1.2380e+00],\n",
      "          [1.4488e+00, 1.0212e+00, 9.4473e-01,  ..., 1.2363e+00,\n",
      "           1.2189e+00, 1.2380e+00],\n",
      "          [1.4488e+00, 1.0212e+00, 9.4473e-01,  ..., 1.2363e+00,\n",
      "           1.2189e+00, 1.2380e+00]]]], device='cuda:0',\n",
      "       grad_fn=<ReluBackward0>)\n"
     ]
    }
   ],
   "source": [
    "xs = model.encoder.global_cmvn(feat)\n",
    "masks = ~make_pad_mask(feat_len).unsqueeze(1)  # (B, 1, L)\n",
    "\n",
    "x = xs.unsqueeze(1)\n",
    "x = model.encoder.embed.conv(x)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "9a9478ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[-0.03426375  0.14291267 -0.06718873 ...  0.09064753  0.01809387\n",
      "   -0.0434088 ]\n",
      "  [-0.05007839  0.11054724 -0.10399298 ...  0.11457238  0.04244684\n",
      "   -0.01249714]\n",
      "  [-0.10695291  0.16910909 -0.08352133 ...  0.07710276  0.01168563\n",
      "   -0.03584499]\n",
      "  ...\n",
      "  [-0.06060536  0.14455931 -0.05470302 ...  0.05364908  0.03033342\n",
      "   -0.02610814]\n",
      "  [-0.08505894  0.13611752 -0.11132983 ...  0.13079923  0.01580139\n",
      "   -0.02281028]\n",
      "  [-0.10604677  0.14714901 -0.10885533 ...  0.08543444  0.03719445\n",
      "   -0.04634233]]\n",
      "\n",
      " [[-0.12392755  0.14486063 -0.05674079 ...  0.02573164  0.03128851\n",
      "    0.00545091]\n",
      "  [-0.04775286  0.08473608 -0.08507854 ...  0.04573154  0.04240163\n",
      "    0.01053247]\n",
      "  [-0.05940291  0.10023535 -0.0814373  ...  0.035965    0.01673085\n",
      "    0.02089563]\n",
      "  ...\n",
      "  [-0.09222981  0.15823206 -0.07700447 ...  0.08122957  0.03136991\n",
      "   -0.00646474]\n",
      "  [-0.07331756  0.14482647 -0.07838815 ...  0.1086944   0.01356864\n",
      "   -0.02777974]\n",
      "  [-0.07937264  0.20143102 -0.05544947 ...  0.10287814  0.00608235\n",
      "   -0.0479918 ]]\n",
      "\n",
      " [[-0.03670349  0.0893159  -0.08718812 ...  0.0131405   0.00642052\n",
      "    0.00573716]\n",
      "  [ 0.01089254  0.11146393 -0.10263617 ...  0.05070438  0.01960694\n",
      "    0.03521532]\n",
      "  [-0.0218228   0.11443964 -0.06678198 ...  0.04327708  0.00861394\n",
      "    0.02871092]\n",
      "  ...\n",
      "  [-0.06792898  0.14376275 -0.07899005 ...  0.11248926  0.03208683\n",
      "   -0.0326424 ]\n",
      "  [-0.07884051  0.17024788 -0.08583611 ...  0.09028331  0.03588808\n",
      "   -0.0207509 ]\n",
      "  [-0.13792302  0.27163863 -0.23930418 ...  0.13391261  0.0752104\n",
      "   -0.08621951]]\n",
      "\n",
      " ...\n",
      "\n",
      " [[-0.02446348  0.11595841 -0.03591986 ...  0.0628897   0.02895011\n",
      "   -0.06532725]\n",
      "  [-0.05378424  0.1260737  -0.09023033 ...  0.09078894  0.01035743\n",
      "    0.03701983]\n",
      "  [-0.04566649  0.14275314 -0.0668687  ...  0.09890588 -0.00612222\n",
      "    0.03439377]\n",
      "  ...\n",
      "  [-0.31763062  0.5370021  -0.2633542  ...  0.39182857  0.00337184\n",
      "   -0.18293698]\n",
      "  [-0.31763062  0.5370021  -0.2633542  ...  0.39182857  0.00337184\n",
      "   -0.18293698]\n",
      "  [-0.31763062  0.5370021  -0.2633542  ...  0.39182857  0.00337184\n",
      "   -0.18293698]]\n",
      "\n",
      " [[-0.01012144  0.03909408 -0.07077143 ...  0.00452683 -0.01377654\n",
      "    0.02897627]\n",
      "  [-0.00519154  0.03594019 -0.06831125 ...  0.05693541 -0.00406374\n",
      "    0.0456164 ]\n",
      "  [-0.01762631  0.00500899 -0.05886075 ...  0.02112178 -0.00729015\n",
      "    0.02782153]\n",
      "  ...\n",
      "  [-0.31763062  0.5370021  -0.2633542  ...  0.39182857  0.00337184\n",
      "   -0.18293698]\n",
      "  [-0.31763062  0.5370021  -0.2633542  ...  0.39182857  0.00337184\n",
      "   -0.18293698]\n",
      "  [-0.31763062  0.5370021  -0.2633542  ...  0.39182857  0.00337184\n",
      "   -0.18293698]]\n",
      "\n",
      " [[-0.03411558 -0.04318277 -0.08497842 ... -0.04886402  0.04296734\n",
      "    0.06151697]\n",
      "  [ 0.00263296 -0.06913657 -0.08993219 ... -0.00149064  0.05696633\n",
      "    0.03304394]\n",
      "  [-0.01818341 -0.0117864  -0.09679577 ... -0.00870231  0.00362198\n",
      "    0.01916483]\n",
      "  ...\n",
      "  [-0.31763062  0.5370021  -0.2633542  ...  0.39182857  0.00337184\n",
      "   -0.18293698]\n",
      "  [-0.31763062  0.5370021  -0.2633542  ...  0.39182857  0.00337184\n",
      "   -0.18293698]\n",
      "  [-0.31763062  0.5370021  -0.2633542  ...  0.39182857  0.00337184\n",
      "   -0.18293698]]]\n",
      "torch.Size([16, 51, 256])\n"
     ]
    }
   ],
   "source": [
    "# Regroup the 4-D activations from (b, c, t, f) to (b, t, c * f) and pass them\n",
    "# through the embed module's `out` layer, then show the values and the shape.\n",
    "b, c, t, f = x.shape\n",
    "x = model.encoder.embed.out(x.permute(0, 2, 1, 3).reshape(b, t, c * f))\n",
    "print(x.detach().cpu().numpy())\n",
    "print(x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "fd69003f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[-0.54822     2.2866027  -1.0750197  ...  1.4503604   0.28950194\n",
      "   -0.6945408 ]\n",
      "  [-0.8012542   1.7687558  -1.6638877  ...  1.833158    0.6791494\n",
      "   -0.1999542 ]\n",
      "  [-1.7112465   2.7057455  -1.3363413  ...  1.2336441   0.18697014\n",
      "   -0.5735198 ]\n",
      "  ...\n",
      "  [-0.96968573  2.312949   -0.87524825 ...  0.85838526  0.4853347\n",
      "   -0.41773027]\n",
      "  [-1.3609431   2.1778803  -1.7812773  ...  2.0927877   0.25282228\n",
      "   -0.36496443]\n",
      "  [-1.6967483   2.3543842  -1.7416853  ...  1.366951    0.59511113\n",
      "   -0.74147725]]\n",
      "\n",
      " [[-1.9828408   2.31777    -0.9078527  ...  0.41170627  0.5006162\n",
      "    0.08721463]\n",
      "  [-0.76404583  1.3557773  -1.3612567  ...  0.7317046   0.678426\n",
      "    0.16851945]\n",
      "  [-0.95044655  1.6037656  -1.3029968  ...  0.57544005  0.26769355\n",
      "    0.33433008]\n",
      "  ...\n",
      "  [-1.475677    2.531713   -1.2320715  ...  1.2996731   0.50191855\n",
      "   -0.10343577]\n",
      "  [-1.1730809   2.3172235  -1.2542105  ...  1.7391105   0.21709818\n",
      "   -0.44447583]\n",
      "  [-1.2699623   3.2228963  -0.8871915  ...  1.6460502   0.09731755\n",
      "   -0.7678688 ]]\n",
      "\n",
      " [[-0.5872559   1.4290544  -1.3950099  ...  0.21024795  0.10272825\n",
      "    0.09179455]\n",
      "  [ 0.1742807   1.783423   -1.6421788  ...  0.8112701   0.31371105\n",
      "    0.56344515]\n",
      "  [-0.34916472  1.8310343  -1.0685117  ...  0.69243336  0.13782299\n",
      "    0.45937473]\n",
      "  ...\n",
      "  [-1.0868638   2.300204   -1.2638408  ...  1.7998282   0.5133892\n",
      "   -0.52227837]\n",
      "  [-1.2614481   2.7239661  -1.3733778  ...  1.444533    0.57420933\n",
      "   -0.33201432]\n",
      "  [-2.2067683   4.346218   -3.828867   ...  2.1426017   1.2033664\n",
      "   -1.3795122 ]]\n",
      "\n",
      " ...\n",
      "\n",
      " [[-0.39141566  1.8553346  -0.5747178  ...  1.0062351   0.46320182\n",
      "   -1.045236  ]\n",
      "  [-0.86054784  2.0171793  -1.4436853  ...  1.452623    0.16571884\n",
      "    0.5923172 ]\n",
      "  [-0.73066384  2.2840502  -1.0698992  ...  1.5824941  -0.0979555\n",
      "    0.55030036]\n",
      "  ...\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]]\n",
      "\n",
      " [[-0.16194311  0.6255052  -1.1323429  ...  0.07242929 -0.22042468\n",
      "    0.46362036]\n",
      "  [-0.08306468  0.575043   -1.09298    ...  0.9109665  -0.06501988\n",
      "    0.72986233]\n",
      "  [-0.28202093  0.08014385 -0.9417719  ...  0.3379485  -0.11664233\n",
      "    0.44514441]\n",
      "  ...\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]]\n",
      "\n",
      " [[-0.5458492  -0.69092435 -1.3596548  ... -0.78182435  0.68747747\n",
      "    0.9842716 ]\n",
      "  [ 0.04212743 -1.1061852  -1.438915   ... -0.02385022  0.91146135\n",
      "    0.52870303]\n",
      "  [-0.2909345  -0.18858244 -1.5487324  ... -0.13923697  0.05795169\n",
      "    0.30663735]\n",
      "  ...\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]\n",
      "  [-5.08209     8.592033   -4.2136674  ...  6.269257    0.05394945\n",
      "   -2.9269917 ]]]\n"
     ]
    }
   ],
   "source": [
    "# Apply the embed module's pos_enc step (the second argument, 0, is presumably\n",
    "# the position offset -- confirm); it hands back the updated activations and\n",
    "# the positional embedding.\n",
    "x, pos_emb = model.encoder.embed.pos_enc(x, 0)\n",
    "print(x.detach().cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "8ed88489",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.float32\n",
      "[[[ 0.0000000e+00  1.0000000e+00  0.0000000e+00 ...  1.0000000e+00\n",
      "    0.0000000e+00  1.0000000e+00]\n",
      "  [ 8.4147096e-01  5.4030234e-01  8.0196178e-01 ...  1.0000000e+00\n",
      "    1.0746076e-04  1.0000000e+00]\n",
      "  [ 9.0929741e-01 -4.1614684e-01  9.5814437e-01 ...  1.0000000e+00\n",
      "    2.1492151e-04  1.0000000e+00]\n",
      "  ...\n",
      "  [-7.6825464e-01 -6.4014435e-01  6.3279724e-01 ...  9.9998462e-01\n",
      "    5.1580933e-03  9.9998671e-01]\n",
      "  [-9.5375264e-01  3.0059254e-01  9.9899054e-01 ...  9.9998397e-01\n",
      "    5.2655530e-03  9.9998611e-01]\n",
      "  [-2.6237485e-01  9.6496606e-01  5.6074661e-01 ...  9.9998331e-01\n",
      "    5.3730118e-03  9.9998558e-01]]]\n"
     ]
    }
   ],
   "source": [
    "# Inspect the positional embedding: its dtype first, then its values.\n",
    "print(pos_emb.dtype)\n",
    "print(pos_emb.detach().cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "5e277881",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 51, 256])\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'mask' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-54-bcc891cacca8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m    141\u001b[0m              \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    142\u001b[0m              \u001b[0mpos_emb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpos_emb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 143\u001b[0;31m              \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    144\u001b[0m              \u001b[0mx_att\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    145\u001b[0m             )\n",
      "\u001b[0;31mNameError\u001b[0m: name 'mask' is not defined"
     ]
    }
   ],
   "source": [
    "def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,\n",
    "                            use_dynamic_chunk: bool,\n",
    "                            use_dynamic_left_chunk: bool,\n",
    "                            decoding_chunk_size: int, static_chunk_size: int,\n",
    "                            num_decoding_left_chunks: int):\n",
    "    \"\"\" Apply optional mask for encoder.\n",
    "\n",
    "    When chunking is requested the padding mask is AND-ed with a chunk mask;\n",
    "    otherwise the padding mask is returned unchanged.\n",
    "\n",
    "    NOTE(review): the chunked branches call subsequent_chunk_mask, which this\n",
    "    cell neither defines nor imports -- presumably it is the helper from\n",
    "    wenet.utils.mask; confirm it is in scope before enabling chunking.\n",
    "\n",
    "    Args:\n",
    "        xs (torch.Tensor): padded input, (B, L, D), L for max length\n",
    "        masks (torch.Tensor): mask for xs, (B, 1, L)\n",
    "        use_dynamic_chunk (bool): whether to use dynamic chunk or not\n",
    "        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for\n",
    "            training.\n",
    "        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's\n",
    "            0: default for training, use random dynamic chunk.\n",
    "            <0: for decoding, use full chunk.\n",
    "            >0: for decoding, use fixed chunk size as set.\n",
    "        static_chunk_size (int): chunk size for static chunk training/decoding\n",
    "            if it's greater than 0, if use_dynamic_chunk is true,\n",
    "            this parameter will be ignored\n",
    "        num_decoding_left_chunks: number of left chunks, this is for decoding,\n",
    "            the chunk size is decoding_chunk_size.\n",
    "            >=0: use num_decoding_left_chunks\n",
    "            <0: use all left chunks\n",
    "    Returns:\n",
    "        torch.Tensor: chunk mask of the input xs.\n",
    "    \"\"\"\n",
    "    # Whether to use chunk mask or not\n",
    "    if use_dynamic_chunk:\n",
    "        max_len = xs.size(1)\n",
    "        if decoding_chunk_size < 0:\n",
    "            chunk_size = max_len\n",
    "            num_left_chunks = -1\n",
    "        elif decoding_chunk_size > 0:\n",
    "            chunk_size = decoding_chunk_size\n",
    "            num_left_chunks = num_decoding_left_chunks\n",
    "        else:\n",
    "            # chunk size is either [1, 25] or full context(max_len).\n",
    "            # Since we use 4 times subsampling and allow up to 1s(100 frames)\n",
    "            # delay, the maximum frame is 100 / 4 = 25.\n",
    "            num_left_chunks = -1\n",
    "            # torch.randint needs low < high, so an input with fewer than two\n",
    "            # frames cannot be sampled; it falls through to the checks below.\n",
    "            if max_len > 1:\n",
    "                chunk_size = torch.randint(1, max_len, (1, )).item()\n",
    "            else:\n",
    "                chunk_size = max_len\n",
    "            if chunk_size > max_len // 2:\n",
    "                chunk_size = max_len\n",
    "            else:\n",
    "                chunk_size = chunk_size % 25 + 1\n",
    "                if use_dynamic_left_chunk:\n",
    "                    max_left_chunks = (max_len - 1) // chunk_size\n",
    "                    # Without a complete chunk on the left there is nothing\n",
    "                    # to sample (randint(0, 0) raises): use zero left chunks.\n",
    "                    if max_left_chunks > 0:\n",
    "                        num_left_chunks = torch.randint(0, max_left_chunks,\n",
    "                                                        (1, )).item()\n",
    "                    else:\n",
    "                        num_left_chunks = 0\n",
    "        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,\n",
    "                                            num_left_chunks,\n",
    "                                            xs.device)  # (L, L)\n",
    "        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)\n",
    "        chunk_masks = masks & chunk_masks  # (B, L, L)\n",
    "    elif static_chunk_size > 0:\n",
    "        num_left_chunks = num_decoding_left_chunks\n",
    "        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,\n",
    "                                            num_left_chunks,\n",
    "                                            xs.device)  # (L, L)\n",
    "        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)\n",
    "        chunk_masks = masks & chunk_masks  # (B, L, L)\n",
    "    else:\n",
    "        # No chunking requested: the padding mask is used as is.\n",
    "        chunk_masks = masks\n",
    "    return chunk_masks\n",
    "\n",
    "from wenet.utils.mask import make_pad_mask\n",
    "\n",
    "# Every reference archive written by this cell goes below this directory.\n",
    "DUMP_DIR = '/workspace/DeepSpeech-2.x/.notebook'\n",
    "\n",
    "# Switch to eval mode before the first forward pass so that the embedding dump\n",
    "# is produced in the same mode as the per-layer dumps further down.\n",
    "model.eval()\n",
    "\n",
    "# Front-end: padding mask, global CMVN, then the embed module.\n",
    "masks = ~make_pad_mask(feat_len).unsqueeze(1)\n",
    "xs = model.encoder.global_cmvn(feat)\n",
    "xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)\n",
    "\n",
    "# With use_dynamic_chunk False and a non-positive static_chunk_size,\n",
    "# add_optional_chunk_mask returns the padding mask unchanged.\n",
    "mask_pad = masks\n",
    "decoding_chunk_size = 0\n",
    "num_decoding_left_chunks = -1\n",
    "use_dynamic_left_chunk = False  # annotated as bool by add_optional_chunk_mask\n",
    "use_dynamic_chunk = False\n",
    "static_chunk_size = -1\n",
    "chunk_masks = add_optional_chunk_mask(\n",
    "    xs,\n",
    "    masks,\n",
    "    use_dynamic_chunk,\n",
    "    use_dynamic_left_chunk,\n",
    "    decoding_chunk_size,\n",
    "    static_chunk_size,\n",
    "    num_decoding_left_chunks)\n",
    "\n",
    "np.savez(f'{DUMP_DIR}/enc_embed',\n",
    "         embed_out=xs.cpu().detach().numpy(),\n",
    "         pos_emb=pos_emb.cpu().detach().numpy(),\n",
    "         chunk_masks=chunk_masks.cpu().detach().numpy(),\n",
    "         mask_pad=mask_pad.cpu().detach().numpy())\n",
    "\n",
    "print(xs.shape)\n",
    "\n",
    "# Step through the first encoder layer by hand (the loop exits after a single\n",
    "# iteration) and dump its intermediate tensors.\n",
    "for layer in model.encoder.encoders:\n",
    "    # Macaron feed-forward: layer norm, then a scaled residual connection.\n",
    "    x = xs\n",
    "    residual = x\n",
    "    x_norm = layer.norm_ff_macaron(x)\n",
    "    # -f keeps the cleanup quiet when the archive does not exist yet.\n",
    "    !rm -f {DUMP_DIR}/enc_0_norm_ff.npz\n",
    "    np.savez(f'{DUMP_DIR}/enc_0_norm_ff',\n",
    "             norm_ff=x_norm.cpu().detach().numpy(),\n",
    "             xs=xs.cpu().detach().numpy())\n",
    "\n",
    "    x = residual + layer.ff_scale * layer.feed_forward_macaron(x_norm)\n",
    "\n",
    "    # Parameters of the macaron feed-forward module, in state_dict order.\n",
    "    ps = [p.cpu().data.numpy()\n",
    "          for p in layer.feed_forward_macaron.state_dict().values()]\n",
    "\n",
    "    # The same feed-forward replayed sub-module by sub-module for the dump.\n",
    "    ff_l_x = layer.feed_forward_macaron.w_1(x_norm)\n",
    "    ff_l_a_x = layer.feed_forward_macaron.activation(ff_l_x)\n",
    "    ff_l_a_l_x = layer.feed_forward_macaron.w_2(ff_l_a_x)\n",
    "    np.savez(f'{DUMP_DIR}/enc_0_ff_out',\n",
    "             norm_ff=x_norm.cpu().detach().numpy(),\n",
    "             ff_out=x.cpu().detach().numpy(),\n",
    "             ff_l_x=ff_l_x.cpu().detach().numpy(),\n",
    "             ff_l_a_x=ff_l_a_x.cpu().detach().numpy(),\n",
    "             ff_l_a_l_x=ff_l_a_l_x.cpu().detach().numpy(),\n",
    "             ps=ps)\n",
    "\n",
    "    # Self-attention on the normalised activations.\n",
    "    residual = x\n",
    "    x = layer.norm_mha(x)\n",
    "    x_q = x\n",
    "    x_att = layer.self_attn(x_q, x, x, pos_emb, masks)\n",
    "    # `masks` is the tensor handed to self_attn above; the archive key stays\n",
    "    # `mask` so that readers of the file keep working.\n",
    "    np.savez(f'{DUMP_DIR}/enc_0_selattn_out',\n",
    "             x_q=x_q.cpu().detach().numpy(),\n",
    "             x=x.cpu().detach().numpy(),\n",
    "             pos_emb=pos_emb.cpu().detach().numpy(),\n",
    "             mask=masks.cpu().detach().numpy(),\n",
    "             x_att=x_att.cpu().detach().numpy())\n",
    "\n",
    "    break\n",
    "\n",
    "# Run the full encoder stack, keeping the output after the second layer and\n",
    "# after the last one.\n",
    "for i, layer in enumerate(model.encoder.encoders, start=1):\n",
    "    xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n",
    "    if i == 2:\n",
    "        np.savez(f'{DUMP_DIR}/enc_2', enc_2=xs.cpu().detach().numpy())\n",
    "\n",
    "np.savez(f'{DUMP_DIR}/enc_all', enc_all=xs.cpu().detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c43fd4f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# End-to-end reference: a single call through the whole encoder, which\n",
    "# returns two values here (bound to `out` and `mask`).\n",
    "out, mask = model.encoder(feat, feat_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e73db22",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f506114",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}