{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "choice-grade",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/DeepSpeech-2.x\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'/workspace/DeepSpeech-2.x'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%cd ..\n",
    "%pwd"
   ]
  },
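  {
   "cell_type": "markdown",
   "id": "cwd-note",
   "metadata": {},
   "source": [
    "The cells below assume the repository root (`/workspace/DeepSpeech-2.x`) as the working directory: `import deepspeech` resolves against it, and the config path opened later is relative to it. A minimal sanity check (a sketch; the expected directory name comes from the `%pwd` output above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cwd-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Later cells import `deepspeech` and open a repo-relative config file,\n",
    "# so fail fast if we are not at the repository root.\n",
    "assert os.path.basename(os.getcwd()) == 'DeepSpeech-2.x', os.getcwd()"
   ]
  },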
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "broke-broad",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  def convert_to_list(value, n, name, dtype=np.int):\n",
      "register user softmax to paddle, remove this when fixed!\n",
      "register user log_softmax to paddle, remove this when fixed!\n",
      "register user sigmoid to paddle, remove this when fixed!\n",
      "register user log_sigmoid to paddle, remove this when fixed!\n",
      "register user relu to paddle, remove this when fixed!\n",
      "override cat of paddle if exists or register, remove this when fixed!\n",
      "override item of paddle.Tensor if exists or register, remove this when fixed!\n",
      "override long of paddle.Tensor if exists or register, remove this when fixed!\n",
      "override new_full of paddle.Tensor if exists or register, remove this when fixed!\n",
      "override eq of paddle.Tensor if exists or register, remove this when fixed!\n",
      "override eq of paddle if exists or register, remove this when fixed!\n",
      "override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n",
      "override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n",
      "register user view to paddle.Tensor, remove this when fixed!\n",
      "register user view_as to paddle.Tensor, remove this when fixed!\n",
      "register user masked_fill to paddle.Tensor, remove this when fixed!\n",
      "register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
      "register user fill_ to paddle.Tensor, remove this when fixed!\n",
      "register user repeat to paddle.Tensor, remove this when fixed!\n",
      "register user softmax to paddle.Tensor, remove this when fixed!\n",
      "register user sigmoid to paddle.Tensor, remove this when fixed!\n",
      "register user relu to paddle.Tensor, remove this when fixed!\n",
      "register user type_as to paddle.Tensor, remove this when fixed!\n",
      "register user to to paddle.Tensor, remove this when fixed!\n",
      "register user float to paddle.Tensor, remove this when fixed!\n",
      "register user tolist to paddle.Tensor, remove this when fixed!\n",
      "register user glu to paddle.nn.functional, remove this when fixed!\n",
      "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n",
      "register user Module to paddle.nn, remove this when fixed!\n",
      "register user ModuleList to paddle.nn, remove this when fixed!\n",
      "register user GLU to paddle.nn, remove this when fixed!\n",
      "register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
      "register user export to paddle.jit, remove this when fixed!\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import paddle\n",
    "from yacs.config import CfgNode as CN\n",
    "\n",
    "from deepspeech.models.u2 import U2Model\n",
    "from deepspeech.utils.layer_tools import print_params\n",
    "from deepspeech.utils.layer_tools import summary"
   ]
  },
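  {
   "cell_type": "markdown",
   "id": "patching-note",
   "metadata": {},
   "source": [
    "The stderr lines above are emitted by `deepspeech` at import time: its compatibility layer registers torch-style helpers (`view`, `masked_fill`, `repeat`, ...) on `paddle.Tensor` and overrides a few `paddle` functions until the upstream API catches up. A quick check that the patching took effect (a sketch; the attribute names are taken verbatim from the log lines above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "patching-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# None of these are native paddle.Tensor APIs; they exist here only\n",
    "# because importing deepspeech patched them in (see \"register user ...\").\n",
    "for name in ['view', 'view_as', 'masked_fill', 'repeat', 'type_as']:\n",
    "    print(name, hasattr(paddle.Tensor, name))"
   ]
  },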
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "permanent-summary",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
      "  and should_run_async(code)\n",
      "[INFO 2021/04/20 03:32:21 u2.py:834] U2 Encoder type: conformer\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n",
      "encoder.embed.conv.0.bias | [256] | 256 | True\n",
      "encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n",
      "encoder.embed.conv.2.bias | [256] | 256 | True\n",
      "encoder.embed.out.0.weight | [4864, 256] | 1245184 | True\n",
      "encoder.embed.out.0.bias | [256] | 256 | True\n",
      "encoder.after_norm.weight | [256] | 256 | True\n",
      "encoder.after_norm.bias | [256] | 256 | True\n",
      "encoder.encoders.0.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
      "encoder.encoders.0.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
      "encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n",
      "encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n",
      "encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n",
      "encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n",
      "encoder.encoders.0.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.0.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.0.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.0.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.0.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.0.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
      "encoder.encoders.0.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
      "encoder.encoders.0.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
      "encoder.encoders.0.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.0.conv_module.norm.weight | [256] | 256 | True\n",
      "encoder.encoders.0.conv_module.norm.bias | [256] | 256 | True\n",
      "encoder.encoders.0.conv_module.norm._mean | [256] | 256 | False\n",
      "encoder.encoders.0.conv_module.norm._variance | [256] | 256 | False\n",
      "encoder.encoders.0.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
      "encoder.encoders.0.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
      "encoder.encoders.0.norm_ff.weight | [256] | 256 | True\n",
      "encoder.encoders.0.norm_ff.bias | [256] | 256 | True\n",
      "encoder.encoders.0.norm_mha.weight | [256] | 256 | True\n",
      "encoder.encoders.0.norm_mha.bias | [256] | 256 | True\n",
      "encoder.encoders.0.norm_ff_macaron.weight | [256] | 256 | True\n",
      "encoder.encoders.0.norm_ff_macaron.bias | [256] | 256 | True\n",
      "encoder.encoders.0.norm_conv.weight | [256] | 256 | True\n",
      "encoder.encoders.0.norm_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.0.norm_final.weight | [256] | 256 | True\n",
      "encoder.encoders.0.norm_final.bias | [256] | 256 | True\n",
      "encoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n",
      "encoder.encoders.0.concat_linear.bias | [256] | 256 | True\n",
      "encoder.encoders.1.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
      "encoder.encoders.1.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
      "encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.1.self_attn.linear_q.bias | [256] | 256 | True\n",
      "encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.1.self_attn.linear_k.bias | [256] | 256 | True\n",
      "encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.1.self_attn.linear_v.bias | [256] | 256 | True\n",
      "encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.1.self_attn.linear_out.bias | [256] | 256 | True\n",
      "encoder.encoders.1.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.1.feed_forward.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.1.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.1.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.1.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.1.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.1.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
      "encoder.encoders.1.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
      "encoder.encoders.1.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
      "encoder.encoders.1.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.1.conv_module.norm.weight | [256] | 256 | True\n",
      "encoder.encoders.1.conv_module.norm.bias | [256] | 256 | True\n",
      "encoder.encoders.1.conv_module.norm._mean | [256] | 256 | False\n",
      "encoder.encoders.1.conv_module.norm._variance | [256] | 256 | False\n",
      "encoder.encoders.1.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
      "encoder.encoders.1.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
      "encoder.encoders.1.norm_ff.weight | [256] | 256 | True\n",
      "encoder.encoders.1.norm_ff.bias | [256] | 256 | True\n",
      "encoder.encoders.1.norm_mha.weight | [256] | 256 | True\n",
      "encoder.encoders.1.norm_mha.bias | [256] | 256 | True\n",
      "encoder.encoders.1.norm_ff_macaron.weight | [256] | 256 | True\n",
      "encoder.encoders.1.norm_ff_macaron.bias | [256] | 256 | True\n",
      "encoder.encoders.1.norm_conv.weight | [256] | 256 | True\n",
      "encoder.encoders.1.norm_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.1.norm_final.weight | [256] | 256 | True\n",
      "encoder.encoders.1.norm_final.bias | [256] | 256 | True\n",
      "encoder.encoders.1.concat_linear.weight | [512, 256] | 131072 | True\n",
      "encoder.encoders.1.concat_linear.bias | [256] | 256 | True\n",
      "encoder.encoders.2.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
      "encoder.encoders.2.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
      "encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.2.self_attn.linear_q.bias | [256] | 256 | True\n",
      "encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.2.self_attn.linear_k.bias | [256] | 256 | True\n",
      "encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.2.self_attn.linear_v.bias | [256] | 256 | True\n",
      "encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.2.self_attn.linear_out.bias | [256] | 256 | True\n",
      "encoder.encoders.2.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.2.feed_forward.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.2.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.2.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.2.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.2.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.2.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
      "encoder.encoders.2.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
      "encoder.encoders.2.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
      "encoder.encoders.2.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.2.conv_module.norm.weight | [256] | 256 | True\n",
      "encoder.encoders.2.conv_module.norm.bias | [256] | 256 | True\n",
      "encoder.encoders.2.conv_module.norm._mean | [256] | 256 | False\n",
      "encoder.encoders.2.conv_module.norm._variance | [256] | 256 | False\n",
      "encoder.encoders.2.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
      "encoder.encoders.2.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
      "encoder.encoders.2.norm_ff.weight | [256] | 256 | True\n",
      "encoder.encoders.2.norm_ff.bias | [256] | 256 | True\n",
      "encoder.encoders.2.norm_mha.weight | [256] | 256 | True\n",
      "encoder.encoders.2.norm_mha.bias | [256] | 256 | True\n",
      "encoder.encoders.2.norm_ff_macaron.weight | [256] | 256 | True\n",
      "encoder.encoders.2.norm_ff_macaron.bias | [256] | 256 | True\n",
      "encoder.encoders.2.norm_conv.weight | [256] | 256 | True\n",
      "encoder.encoders.2.norm_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.2.norm_final.weight | [256] | 256 | True\n",
      "encoder.encoders.2.norm_final.bias | [256] | 256 | True\n",
      "encoder.encoders.2.concat_linear.weight | [512, 256] | 131072 | True\n",
      "encoder.encoders.2.concat_linear.bias | [256] | 256 | True\n",
      "encoder.encoders.3.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
      "encoder.encoders.3.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
      "encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.3.self_attn.linear_q.bias | [256] | 256 | True\n",
      "encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.3.self_attn.linear_k.bias | [256] | 256 | True\n",
      "encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.3.self_attn.linear_v.bias | [256] | 256 | True\n",
      "encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.3.self_attn.linear_out.bias | [256] | 256 | True\n",
      "encoder.encoders.3.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.3.feed_forward.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.3.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.3.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.3.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.3.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.3.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
      "encoder.encoders.3.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
      "encoder.encoders.3.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
      "encoder.encoders.3.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.3.conv_module.norm.weight | [256] | 256 | True\n",
      "encoder.encoders.3.conv_module.norm.bias | [256] | 256 | True\n",
      "encoder.encoders.3.conv_module.norm._mean | [256] | 256 | False\n",
      "encoder.encoders.3.conv_module.norm._variance | [256] | 256 | False\n",
      "encoder.encoders.3.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
      "encoder.encoders.3.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
      "encoder.encoders.3.norm_ff.weight | [256] | 256 | True\n",
      "encoder.encoders.3.norm_ff.bias | [256] | 256 | True\n",
      "encoder.encoders.3.norm_mha.weight | [256] | 256 | True\n",
      "encoder.encoders.3.norm_mha.bias | [256] | 256 | True\n",
      "encoder.encoders.3.norm_ff_macaron.weight | [256] | 256 | True\n",
      "encoder.encoders.3.norm_ff_macaron.bias | [256] | 256 | True\n",
      "encoder.encoders.3.norm_conv.weight | [256] | 256 | True\n",
      "encoder.encoders.3.norm_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.3.norm_final.weight | [256] | 256 | True\n",
      "encoder.encoders.3.norm_final.bias | [256] | 256 | True\n",
      "encoder.encoders.3.concat_linear.weight | [512, 256] | 131072 | True\n",
      "encoder.encoders.3.concat_linear.bias | [256] | 256 | True\n",
      "encoder.encoders.4.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
      "encoder.encoders.4.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
      "encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.4.self_attn.linear_q.bias | [256] | 256 | True\n",
      "encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.4.self_attn.linear_k.bias | [256] | 256 | True\n",
      "encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.4.self_attn.linear_v.bias | [256] | 256 | True\n",
      "encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.4.self_attn.linear_out.bias | [256] | 256 | True\n",
      "encoder.encoders.4.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.4.feed_forward.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.4.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.4.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.4.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.4.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.4.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
      "encoder.encoders.4.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
      "encoder.encoders.4.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
      "encoder.encoders.4.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.4.conv_module.norm.weight | [256] | 256 | True\n",
      "encoder.encoders.4.conv_module.norm.bias | [256] | 256 | True\n",
      "encoder.encoders.4.conv_module.norm._mean | [256] | 256 | False\n",
      "encoder.encoders.4.conv_module.norm._variance | [256] | 256 | False\n",
      "encoder.encoders.4.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
      "encoder.encoders.4.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
      "encoder.encoders.4.norm_ff.weight | [256] | 256 | True\n",
      "encoder.encoders.4.norm_ff.bias | [256] | 256 | True\n",
      "encoder.encoders.4.norm_mha.weight | [256] | 256 | True\n",
      "encoder.encoders.4.norm_mha.bias | [256] | 256 | True\n",
      "encoder.encoders.4.norm_ff_macaron.weight | [256] | 256 | True\n",
      "encoder.encoders.4.norm_ff_macaron.bias | [256] | 256 | True\n",
      "encoder.encoders.4.norm_conv.weight | [256] | 256 | True\n",
      "encoder.encoders.4.norm_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.4.norm_final.weight | [256] | 256 | True\n",
      "encoder.encoders.4.norm_final.bias | [256] | 256 | True\n",
      "encoder.encoders.4.concat_linear.weight | [512, 256] | 131072 | True\n",
      "encoder.encoders.4.concat_linear.bias | [256] | 256 | True\n",
      "encoder.encoders.5.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
      "encoder.encoders.5.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
      "encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.5.self_attn.linear_q.bias | [256] | 256 | True\n",
      "encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.5.self_attn.linear_k.bias | [256] | 256 | True\n",
      "encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.5.self_attn.linear_v.bias | [256] | 256 | True\n",
      "encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.5.self_attn.linear_out.bias | [256] | 256 | True\n",
      "encoder.encoders.5.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.5.feed_forward.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.5.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.5.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.5.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.5.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.5.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
      "encoder.encoders.5.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
      "encoder.encoders.5.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
      "encoder.encoders.5.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.5.conv_module.norm.weight | [256] | 256 | True\n",
      "encoder.encoders.5.conv_module.norm.bias | [256] | 256 | True\n",
      "encoder.encoders.5.conv_module.norm._mean | [256] | 256 | False\n",
      "encoder.encoders.5.conv_module.norm._variance | [256] | 256 | False\n",
      "encoder.encoders.5.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
      "encoder.encoders.5.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
      "encoder.encoders.5.norm_ff.weight | [256] | 256 | True\n",
      "encoder.encoders.5.norm_ff.bias | [256] | 256 | True\n",
      "encoder.encoders.5.norm_mha.weight | [256] | 256 | True\n",
      "encoder.encoders.5.norm_mha.bias | [256] | 256 | True\n",
      "encoder.encoders.5.norm_ff_macaron.weight | [256] | 256 | True\n",
      "encoder.encoders.5.norm_ff_macaron.bias | [256] | 256 | True\n",
      "encoder.encoders.5.norm_conv.weight | [256] | 256 | True\n",
      "encoder.encoders.5.norm_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.5.norm_final.weight | [256] | 256 | True\n",
      "encoder.encoders.5.norm_final.bias | [256] | 256 | True\n",
      "encoder.encoders.5.concat_linear.weight | [512, 256] | 131072 | True\n",
      "encoder.encoders.5.concat_linear.bias | [256] | 256 | True\n",
      "encoder.encoders.6.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
      "encoder.encoders.6.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
      "encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.6.self_attn.linear_q.bias | [256] | 256 | True\n",
      "encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.6.self_attn.linear_k.bias | [256] | 256 | True\n",
      "encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.6.self_attn.linear_v.bias | [256] | 256 | True\n",
      "encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.6.self_attn.linear_out.bias | [256] | 256 | True\n",
      "encoder.encoders.6.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.6.feed_forward.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.6.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.6.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.6.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.6.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.6.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
      "encoder.encoders.6.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
      "encoder.encoders.6.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
      "encoder.encoders.6.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.6.conv_module.norm.weight | [256] | 256 | True\n",
      "encoder.encoders.6.conv_module.norm.bias | [256] | 256 | True\n",
      "encoder.encoders.6.conv_module.norm._mean | [256] | 256 | False\n",
      "encoder.encoders.6.conv_module.norm._variance | [256] | 256 | False\n",
      "encoder.encoders.6.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
      "encoder.encoders.6.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
      "encoder.encoders.6.norm_ff.weight | [256] | 256 | True\n",
      "encoder.encoders.6.norm_ff.bias | [256] | 256 | True\n",
      "encoder.encoders.6.norm_mha.weight | [256] | 256 | True\n",
      "encoder.encoders.6.norm_mha.bias | [256] | 256 | True\n",
      "encoder.encoders.6.norm_ff_macaron.weight | [256] | 256 | True\n",
      "encoder.encoders.6.norm_ff_macaron.bias | [256] | 256 | True\n",
      "encoder.encoders.6.norm_conv.weight | [256] | 256 | True\n",
      "encoder.encoders.6.norm_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.6.norm_final.weight | [256] | 256 | True\n",
      "encoder.encoders.6.norm_final.bias | [256] | 256 | True\n",
      "encoder.encoders.6.concat_linear.weight | [512, 256] | 131072 | True\n",
      "encoder.encoders.6.concat_linear.bias | [256] | 256 | True\n",
      "encoder.encoders.7.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
      "encoder.encoders.7.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
      "encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.7.self_attn.linear_q.bias | [256] | 256 | True\n",
      "encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.7.self_attn.linear_k.bias | [256] | 256 | True\n",
      "encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.7.self_attn.linear_v.bias | [256] | 256 | True\n",
      "encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.7.self_attn.linear_out.bias | [256] | 256 | True\n",
      "encoder.encoders.7.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.7.feed_forward.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.7.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.7.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.7.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.7.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.7.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
      "encoder.encoders.7.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
      "encoder.encoders.7.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
      "encoder.encoders.7.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.7.conv_module.norm.weight | [256] | 256 | True\n",
      "encoder.encoders.7.conv_module.norm.bias | [256] | 256 | True\n",
      "encoder.encoders.7.conv_module.norm._mean | [256] | 256 | False\n",
      "encoder.encoders.7.conv_module.norm._variance | [256] | 256 | False\n",
      "encoder.encoders.7.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
      "encoder.encoders.7.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
      "encoder.encoders.7.norm_ff.weight | [256] | 256 | True\n",
      "encoder.encoders.7.norm_ff.bias | [256] | 256 | True\n",
      "encoder.encoders.7.norm_mha.weight | [256] | 256 | True\n",
      "encoder.encoders.7.norm_mha.bias | [256] | 256 | True\n",
      "encoder.encoders.7.norm_ff_macaron.weight | [256] | 256 | True\n",
      "encoder.encoders.7.norm_ff_macaron.bias | [256] | 256 | True\n",
      "encoder.encoders.7.norm_conv.weight | [256] | 256 | True\n",
      "encoder.encoders.7.norm_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.7.norm_final.weight | [256] | 256 | True\n",
      "encoder.encoders.7.norm_final.bias | [256] | 256 | True\n",
      "encoder.encoders.7.concat_linear.weight | [512, 256] | 131072 | True\n",
      "encoder.encoders.7.concat_linear.bias | [256] | 256 | True\n",
      "encoder.encoders.8.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
      "encoder.encoders.8.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
      "encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.8.self_attn.linear_q.bias | [256] | 256 | True\n",
      "encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.8.self_attn.linear_k.bias | [256] | 256 | True\n",
      "encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.8.self_attn.linear_v.bias | [256] | 256 | True\n",
      "encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.8.self_attn.linear_out.bias | [256] | 256 | True\n",
      "encoder.encoders.8.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.8.feed_forward.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.8.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.8.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.8.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.8.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.8.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
      "encoder.encoders.8.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
      "encoder.encoders.8.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
      "encoder.encoders.8.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.8.conv_module.norm.weight | [256] | 256 | True\n",
      "encoder.encoders.8.conv_module.norm.bias | [256] | 256 | True\n",
      "encoder.encoders.8.conv_module.norm._mean | [256] | 256 | False\n",
      "encoder.encoders.8.conv_module.norm._variance | [256] | 256 | False\n",
      "encoder.encoders.8.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
      "encoder.encoders.8.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
      "encoder.encoders.8.norm_ff.weight | [256] | 256 | True\n",
      "encoder.encoders.8.norm_ff.bias | [256] | 256 | True\n",
      "encoder.encoders.8.norm_mha.weight | [256] | 256 | True\n",
      "encoder.encoders.8.norm_mha.bias | [256] | 256 | True\n",
      "encoder.encoders.8.norm_ff_macaron.weight | [256] | 256 | True\n",
      "encoder.encoders.8.norm_ff_macaron.bias | [256] | 256 | True\n",
      "encoder.encoders.8.norm_conv.weight | [256] | 256 | True\n",
      "encoder.encoders.8.norm_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.8.norm_final.weight | [256] | 256 | True\n",
      "encoder.encoders.8.norm_final.bias | [256] | 256 | True\n",
      "encoder.encoders.8.concat_linear.weight | [512, 256] | 131072 | True\n",
      "encoder.encoders.8.concat_linear.bias | [256] | 256 | True\n",
      "encoder.encoders.9.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
      "encoder.encoders.9.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
      "encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.9.self_attn.linear_q.bias | [256] | 256 | True\n",
      "encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.9.self_attn.linear_k.bias | [256] | 256 | True\n",
      "encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.9.self_attn.linear_v.bias | [256] | 256 | True\n",
      "encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.9.self_attn.linear_out.bias | [256] | 256 | True\n",
      "encoder.encoders.9.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.9.feed_forward.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.9.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.9.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.9.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.9.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.9.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
      "encoder.encoders.9.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
      "encoder.encoders.9.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
      "encoder.encoders.9.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.9.conv_module.norm.weight | [256] | 256 | True\n",
      "encoder.encoders.9.conv_module.norm.bias | [256] | 256 | True\n",
      "encoder.encoders.9.conv_module.norm._mean | [256] | 256 | False\n",
      "encoder.encoders.9.conv_module.norm._variance | [256] | 256 | False\n",
      "encoder.encoders.9.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
      "encoder.encoders.9.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
      "encoder.encoders.9.norm_ff.weight | [256] | 256 | True\n",
      "encoder.encoders.9.norm_ff.bias | [256] | 256 | True\n",
      "encoder.encoders.9.norm_mha.weight | [256] | 256 | True\n",
      "encoder.encoders.9.norm_mha.bias | [256] | 256 | True\n",
      "encoder.encoders.9.norm_ff_macaron.weight | [256] | 256 | True\n",
      "encoder.encoders.9.norm_ff_macaron.bias | [256] | 256 | True\n",
      "encoder.encoders.9.norm_conv.weight | [256] | 256 | True\n",
      "encoder.encoders.9.norm_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.9.norm_final.weight | [256] | 256 | True\n",
      "encoder.encoders.9.norm_final.bias | [256] | 256 | True\n",
      "encoder.encoders.9.concat_linear.weight | [512, 256] | 131072 | True\n",
      "encoder.encoders.9.concat_linear.bias | [256] | 256 | True\n",
      "encoder.encoders.10.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
      "encoder.encoders.10.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
      "encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.10.self_attn.linear_q.bias | [256] | 256 | True\n",
      "encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.10.self_attn.linear_k.bias | [256] | 256 | True\n",
      "encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.10.self_attn.linear_v.bias | [256] | 256 | True\n",
      "encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.10.self_attn.linear_out.bias | [256] | 256 | True\n",
      "encoder.encoders.10.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.10.feed_forward.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.10.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.10.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.10.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.10.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.10.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
      "encoder.encoders.10.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
      "encoder.encoders.10.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
      "encoder.encoders.10.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.10.conv_module.norm.weight | [256] | 256 | True\n",
      "encoder.encoders.10.conv_module.norm.bias | [256] | 256 | True\n",
      "encoder.encoders.10.conv_module.norm._mean | [256] | 256 | False\n",
      "encoder.encoders.10.conv_module.norm._variance | [256] | 256 | False\n",
      "encoder.encoders.10.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
      "encoder.encoders.10.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
      "encoder.encoders.10.norm_ff.weight | [256] | 256 | True\n",
      "encoder.encoders.10.norm_ff.bias | [256] | 256 | True\n",
      "encoder.encoders.10.norm_mha.weight | [256] | 256 | True\n",
      "encoder.encoders.10.norm_mha.bias | [256] | 256 | True\n",
      "encoder.encoders.10.norm_ff_macaron.weight | [256] | 256 | True\n",
      "encoder.encoders.10.norm_ff_macaron.bias | [256] | 256 | True\n",
      "encoder.encoders.10.norm_conv.weight | [256] | 256 | True\n",
      "encoder.encoders.10.norm_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.10.norm_final.weight | [256] | 256 | True\n",
      "encoder.encoders.10.norm_final.bias | [256] | 256 | True\n",
      "encoder.encoders.10.concat_linear.weight | [512, 256] | 131072 | True\n",
      "encoder.encoders.10.concat_linear.bias | [256] | 256 | True\n",
      "encoder.encoders.11.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
      "encoder.encoders.11.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
      "encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.11.self_attn.linear_q.bias | [256] | 256 | True\n",
      "encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.11.self_attn.linear_k.bias | [256] | 256 | True\n",
      "encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.11.self_attn.linear_v.bias | [256] | 256 | True\n",
      "encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.11.self_attn.linear_out.bias | [256] | 256 | True\n",
      "encoder.encoders.11.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
      "encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.11.feed_forward.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.11.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
      "encoder.encoders.11.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
      "encoder.encoders.11.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
      "encoder.encoders.11.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
      "encoder.encoders.11.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
      "encoder.encoders.11.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
      "encoder.encoders.11.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
      "encoder.encoders.11.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.11.conv_module.norm.weight | [256] | 256 | True\n",
      "encoder.encoders.11.conv_module.norm.bias | [256] | 256 | True\n",
      "encoder.encoders.11.conv_module.norm._mean | [256] | 256 | False\n",
      "encoder.encoders.11.conv_module.norm._variance | [256] | 256 | False\n",
      "encoder.encoders.11.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
      "encoder.encoders.11.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
      "encoder.encoders.11.norm_ff.weight | [256] | 256 | True\n",
      "encoder.encoders.11.norm_ff.bias | [256] | 256 | True\n",
      "encoder.encoders.11.norm_mha.weight | [256] | 256 | True\n",
      "encoder.encoders.11.norm_mha.bias | [256] | 256 | True\n",
      "encoder.encoders.11.norm_ff_macaron.weight | [256] | 256 | True\n",
      "encoder.encoders.11.norm_ff_macaron.bias | [256] | 256 | True\n",
      "encoder.encoders.11.norm_conv.weight | [256] | 256 | True\n",
      "encoder.encoders.11.norm_conv.bias | [256] | 256 | True\n",
      "encoder.encoders.11.norm_final.weight | [256] | 256 | True\n",
      "encoder.encoders.11.norm_final.bias | [256] | 256 | True\n",
      "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072 | True\n",
      "encoder.encoders.11.concat_linear.bias | [256] | 256 | True\n",
      "decoder.embed.0.weight | [4233, 256] | 1083648 | True\n",
      "decoder.after_norm.weight | [256] | 256 | True\n",
      "decoder.after_norm.bias | [256] | 256 | True\n",
      "decoder.output_layer.weight | [256, 4233] | 1083648 | True\n",
      "decoder.output_layer.bias | [4233] | 4233 | True\n",
      "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256 | True\n",
      "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.0.self_attn.linear_k.bias | [256] | 256 | True\n",
      "decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.0.self_attn.linear_v.bias | [256] | 256 | True\n",
      "decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.0.self_attn.linear_out.bias | [256] | 256 | True\n",
      "decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.0.src_attn.linear_q.bias | [256] | 256 | True\n",
      "decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.0.src_attn.linear_k.bias | [256] | 256 | True\n",
      "decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.0.src_attn.linear_v.bias | [256] | 256 | True\n",
      "decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.0.src_attn.linear_out.bias | [256] | 256 | True\n",
      "decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "decoder.decoders.0.feed_forward.w_2.bias | [256] | 256 | True\n",
      "decoder.decoders.0.norm1.weight | [256] | 256 | True\n",
      "decoder.decoders.0.norm1.bias | [256] | 256 | True\n",
      "decoder.decoders.0.norm2.weight | [256] | 256 | True\n",
      "decoder.decoders.0.norm2.bias | [256] | 256 | True\n",
      "decoder.decoders.0.norm3.weight | [256] | 256 | True\n",
      "decoder.decoders.0.norm3.bias | [256] | 256 | True\n",
      "decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072 | True\n",
      "decoder.decoders.0.concat_linear1.bias | [256] | 256 | True\n",
      "decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072 | True\n",
      "decoder.decoders.0.concat_linear2.bias | [256] | 256 | True\n",
      "decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.1.self_attn.linear_q.bias | [256] | 256 | True\n",
      "decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.1.self_attn.linear_k.bias | [256] | 256 | True\n",
      "decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.1.self_attn.linear_v.bias | [256] | 256 | True\n",
      "decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.1.self_attn.linear_out.bias | [256] | 256 | True\n",
      "decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.1.src_attn.linear_q.bias | [256] | 256 | True\n",
      "decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.1.src_attn.linear_k.bias | [256] | 256 | True\n",
      "decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.1.src_attn.linear_v.bias | [256] | 256 | True\n",
      "decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.1.src_attn.linear_out.bias | [256] | 256 | True\n",
      "decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "decoder.decoders.1.feed_forward.w_2.bias | [256] | 256 | True\n",
      "decoder.decoders.1.norm1.weight | [256] | 256 | True\n",
      "decoder.decoders.1.norm1.bias | [256] | 256 | True\n",
      "decoder.decoders.1.norm2.weight | [256] | 256 | True\n",
      "decoder.decoders.1.norm2.bias | [256] | 256 | True\n",
      "decoder.decoders.1.norm3.weight | [256] | 256 | True\n",
      "decoder.decoders.1.norm3.bias | [256] | 256 | True\n",
      "decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072 | True\n",
      "decoder.decoders.1.concat_linear1.bias | [256] | 256 | True\n",
      "decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072 | True\n",
      "decoder.decoders.1.concat_linear2.bias | [256] | 256 | True\n",
      "decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.2.self_attn.linear_q.bias | [256] | 256 | True\n",
      "decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.2.self_attn.linear_k.bias | [256] | 256 | True\n",
      "decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.2.self_attn.linear_v.bias | [256] | 256 | True\n",
      "decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.2.self_attn.linear_out.bias | [256] | 256 | True\n",
      "decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.2.src_attn.linear_q.bias | [256] | 256 | True\n",
      "decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.2.src_attn.linear_k.bias | [256] | 256 | True\n",
      "decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.2.src_attn.linear_v.bias | [256] | 256 | True\n",
      "decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.2.src_attn.linear_out.bias | [256] | 256 | True\n",
      "decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "decoder.decoders.2.feed_forward.w_2.bias | [256] | 256 | True\n",
      "decoder.decoders.2.norm1.weight | [256] | 256 | True\n",
      "decoder.decoders.2.norm1.bias | [256] | 256 | True\n",
      "decoder.decoders.2.norm2.weight | [256] | 256 | True\n",
      "decoder.decoders.2.norm2.bias | [256] | 256 | True\n",
      "decoder.decoders.2.norm3.weight | [256] | 256 | True\n",
      "decoder.decoders.2.norm3.bias | [256] | 256 | True\n",
      "decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072 | True\n",
      "decoder.decoders.2.concat_linear1.bias | [256] | 256 | True\n",
      "decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072 | True\n",
      "decoder.decoders.2.concat_linear2.bias | [256] | 256 | True\n",
      "decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.3.self_attn.linear_q.bias | [256] | 256 | True\n",
      "decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.3.self_attn.linear_k.bias | [256] | 256 | True\n",
      "decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256 | True\n",
      "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256 | True\n",
      "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256 | True\n",
      "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256 | True\n",
      "decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.3.src_attn.linear_v.bias | [256] | 256 | True\n",
      "decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.3.src_attn.linear_out.bias | [256] | 256 | True\n",
      "decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "decoder.decoders.3.feed_forward.w_2.bias | [256] | 256 | True\n",
      "decoder.decoders.3.norm1.weight | [256] | 256 | True\n",
      "decoder.decoders.3.norm1.bias | [256] | 256 | True\n",
      "decoder.decoders.3.norm2.weight | [256] | 256 | True\n",
      "decoder.decoders.3.norm2.bias | [256] | 256 | True\n",
      "decoder.decoders.3.norm3.weight | [256] | 256 | True\n",
      "decoder.decoders.3.norm3.bias | [256] | 256 | True\n",
      "decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072 | True\n",
      "decoder.decoders.3.concat_linear1.bias | [256] | 256 | True\n",
      "decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072 | True\n",
      "decoder.decoders.3.concat_linear2.bias | [256] | 256 | True\n",
      "decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.4.self_attn.linear_q.bias | [256] | 256 | True\n",
      "decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.4.self_attn.linear_k.bias | [256] | 256 | True\n",
      "decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.4.self_attn.linear_v.bias | [256] | 256 | True\n",
      "decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.4.self_attn.linear_out.bias | [256] | 256 | True\n",
      "decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.4.src_attn.linear_q.bias | [256] | 256 | True\n",
      "decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.4.src_attn.linear_k.bias | [256] | 256 | True\n",
      "decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.4.src_attn.linear_v.bias | [256] | 256 | True\n",
      "decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.4.src_attn.linear_out.bias | [256] | 256 | True\n",
      "decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "decoder.decoders.4.feed_forward.w_2.bias | [256] | 256 | True\n",
      "decoder.decoders.4.norm1.weight | [256] | 256 | True\n",
      "decoder.decoders.4.norm1.bias | [256] | 256 | True\n",
      "decoder.decoders.4.norm2.weight | [256] | 256 | True\n",
      "decoder.decoders.4.norm2.bias | [256] | 256 | True\n",
      "decoder.decoders.4.norm3.weight | [256] | 256 | True\n",
      "decoder.decoders.4.norm3.bias | [256] | 256 | True\n",
      "decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072 | True\n",
      "decoder.decoders.4.concat_linear1.bias | [256] | 256 | True\n",
      "decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072 | True\n",
      "decoder.decoders.4.concat_linear2.bias | [256] | 256 | True\n",
      "decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.5.self_attn.linear_q.bias | [256] | 256 | True\n",
      "decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.5.self_attn.linear_k.bias | [256] | 256 | True\n",
      "decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.5.self_attn.linear_v.bias | [256] | 256 | True\n",
      "decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.5.self_attn.linear_out.bias | [256] | 256 | True\n",
      "decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.5.src_attn.linear_q.bias | [256] | 256 | True\n",
      "decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.5.src_attn.linear_k.bias | [256] | 256 | True\n",
      "decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.5.src_attn.linear_v.bias | [256] | 256 | True\n",
      "decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
      "decoder.decoders.5.src_attn.linear_out.bias | [256] | 256 | True\n",
      "decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
      "decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n",
      "decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
      "decoder.decoders.5.feed_forward.w_2.bias | [256] | 256 | True\n",
      "decoder.decoders.5.norm1.weight | [256] | 256 | True\n",
      "decoder.decoders.5.norm1.bias | [256] | 256 | True\n",
      "decoder.decoders.5.norm2.weight | [256] | 256 | True\n",
      "decoder.decoders.5.norm2.bias | [256] | 256 | True\n",
      "decoder.decoders.5.norm3.weight | [256] | 256 | True\n",
      "decoder.decoders.5.norm3.bias | [256] | 256 | True\n",
      "decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072 | True\n",
      "decoder.decoders.5.concat_linear1.bias | [256] | 256 | True\n",
      "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072 | True\n",
      "decoder.decoders.5.concat_linear2.bias | [256] | 256 | True\n",
      "ctc.ctc_lo.weight | [256, 4233] | 1083648 | True\n",
      "ctc.ctc_lo.bias | [4233] | 4233 | True\n",
      "Total parameters: 687.0, 49355282.0 elements.\n"
     ]
    }
   ],
   "source": [
    "conf_str='examples/aishell/s1/conf/conformer.yaml'\n",
    "cfg = CN().load_cfg(open(conf_str))\n",
    "cfg.model.input_dim = 80\n",
    "cfg.model.output_dim = 4233\n",
    "cfg.model.cmvn_file = \"/workspace/wenet/examples/aishell/s0/raw_wav/train/global_cmvn\"\n",
    "cfg.model.cmvn_file_type = 'json'\n",
    "cfg.freeze()\n",
    "\n",
    "model = U2Model(cfg.model)\n",
    "print_params(model)\n"
   ]
  },
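  {
   "cell_type": "markdown",
   "id": "param-check-note",
   "metadata": {},
   "source": [
    "`print_params` reports 687 tensors and 49,355,282 elements in total. The dump makes this easy to spot-check; the cell below recomputes one decoder block from the printed shapes (a sanity-check sketch: `d=256` and `ff=2048` are read off the dump above, not queried from the model)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "param-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute the element count of one decoder block from the shapes printed\n",
    "# above (a sketch; d/ff are read off the dump, not taken from the config).\n",
    "d, ff = 256, 2048\n",
    "attn = 4 * (d * d + d)             # linear_q/k/v/out, weight + bias\n",
    "ffwd = d * ff + ff + ff * d + d    # feed_forward w_1 and w_2\n",
    "norms = 3 * 2 * d                  # norm1..norm3, weight + bias\n",
    "concat = 2 * (2 * d * d + d)       # concat_linear1 and concat_linear2\n",
    "per_block = 2 * attn + ffwd + norms + concat  # self_attn + src_attn + rest\n",
    "print(per_block)  # 1841408 elements per decoder layer"
   ]
  },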
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "sapphire-agent",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoder.global_cmvn.mean | [80] | 80\n",
      "encoder.global_cmvn.istd | [80] | 80\n",
      "encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304\n",
      "encoder.embed.conv.0.bias | [256] | 256\n",
      "encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824\n",
      "encoder.embed.conv.2.bias | [256] | 256\n",
      "encoder.embed.out.0.weight | [4864, 256] | 1245184\n",
      "encoder.embed.out.0.bias | [256] | 256\n",
      "encoder.after_norm.weight | [256] | 256\n",
      "encoder.after_norm.bias | [256] | 256\n",
      "encoder.encoders.0.self_attn.pos_bias_u | [4, 64] | 256\n",
      "encoder.encoders.0.self_attn.pos_bias_v | [4, 64] | 256\n",
      "encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "encoder.encoders.0.self_attn.linear_q.bias | [256] | 256\n",
      "encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "encoder.encoders.0.self_attn.linear_k.bias | [256] | 256\n",
      "encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "encoder.encoders.0.self_attn.linear_v.bias | [256] | 256\n",
      "encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "encoder.encoders.0.self_attn.linear_out.bias | [256] | 256\n",
      "encoder.encoders.0.self_attn.linear_pos.weight | [256, 256] | 65536\n",
      "encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.0.feed_forward.w_2.bias | [256] | 256\n",
      "encoder.encoders.0.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.0.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.0.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.0.feed_forward_macaron.w_2.bias | [256] | 256\n",
      "encoder.encoders.0.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
      "encoder.encoders.0.conv_module.pointwise_conv1.bias | [512] | 512\n",
      "encoder.encoders.0.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
      "encoder.encoders.0.conv_module.depthwise_conv.bias | [256] | 256\n",
      "encoder.encoders.0.conv_module.norm.weight | [256] | 256\n",
      "encoder.encoders.0.conv_module.norm.bias | [256] | 256\n",
      "encoder.encoders.0.conv_module.norm._mean | [256] | 256\n",
      "encoder.encoders.0.conv_module.norm._variance | [256] | 256\n",
      "encoder.encoders.0.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
      "encoder.encoders.0.conv_module.pointwise_conv2.bias | [256] | 256\n",
      "encoder.encoders.0.norm_ff.weight | [256] | 256\n",
      "encoder.encoders.0.norm_ff.bias | [256] | 256\n",
      "encoder.encoders.0.norm_mha.weight | [256] | 256\n",
      "encoder.encoders.0.norm_mha.bias | [256] | 256\n",
      "encoder.encoders.0.norm_ff_macaron.weight | [256] | 256\n",
      "encoder.encoders.0.norm_ff_macaron.bias | [256] | 256\n",
      "encoder.encoders.0.norm_conv.weight | [256] | 256\n",
      "encoder.encoders.0.norm_conv.bias | [256] | 256\n",
      "encoder.encoders.0.norm_final.weight | [256] | 256\n",
      "encoder.encoders.0.norm_final.bias | [256] | 256\n",
      "encoder.encoders.0.concat_linear.weight | [512, 256] | 131072\n",
      "encoder.encoders.0.concat_linear.bias | [256] | 256\n",
      "encoder.encoders.1.self_attn.pos_bias_u | [4, 64] | 256\n",
      "encoder.encoders.1.self_attn.pos_bias_v | [4, 64] | 256\n",
      "encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "encoder.encoders.1.self_attn.linear_q.bias | [256] | 256\n",
      "encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "encoder.encoders.1.self_attn.linear_k.bias | [256] | 256\n",
      "encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "encoder.encoders.1.self_attn.linear_v.bias | [256] | 256\n",
      "encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "encoder.encoders.1.self_attn.linear_out.bias | [256] | 256\n",
      "encoder.encoders.1.self_attn.linear_pos.weight | [256, 256] | 65536\n",
      "encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.1.feed_forward.w_2.bias | [256] | 256\n",
      "encoder.encoders.1.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.1.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.1.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.1.feed_forward_macaron.w_2.bias | [256] | 256\n",
      "encoder.encoders.1.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
      "encoder.encoders.1.conv_module.pointwise_conv1.bias | [512] | 512\n",
      "encoder.encoders.1.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
      "encoder.encoders.1.conv_module.depthwise_conv.bias | [256] | 256\n",
      "encoder.encoders.1.conv_module.norm.weight | [256] | 256\n",
      "encoder.encoders.1.conv_module.norm.bias | [256] | 256\n",
      "encoder.encoders.1.conv_module.norm._mean | [256] | 256\n",
      "encoder.encoders.1.conv_module.norm._variance | [256] | 256\n",
      "encoder.encoders.1.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
      "encoder.encoders.1.conv_module.pointwise_conv2.bias | [256] | 256\n",
      "encoder.encoders.1.norm_ff.weight | [256] | 256\n",
      "encoder.encoders.1.norm_ff.bias | [256] | 256\n",
      "encoder.encoders.1.norm_mha.weight | [256] | 256\n",
      "encoder.encoders.1.norm_mha.bias | [256] | 256\n",
      "encoder.encoders.1.norm_ff_macaron.weight | [256] | 256\n",
      "encoder.encoders.1.norm_ff_macaron.bias | [256] | 256\n",
      "encoder.encoders.1.norm_conv.weight | [256] | 256\n",
      "encoder.encoders.1.norm_conv.bias | [256] | 256\n",
      "encoder.encoders.1.norm_final.weight | [256] | 256\n",
      "encoder.encoders.1.norm_final.bias | [256] | 256\n",
      "encoder.encoders.1.concat_linear.weight | [512, 256] | 131072\n",
      "encoder.encoders.1.concat_linear.bias | [256] | 256\n",
      "encoder.encoders.2.self_attn.pos_bias_u | [4, 64] | 256\n",
      "encoder.encoders.2.self_attn.pos_bias_v | [4, 64] | 256\n",
      "encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "encoder.encoders.2.self_attn.linear_q.bias | [256] | 256\n",
      "encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "encoder.encoders.2.self_attn.linear_k.bias | [256] | 256\n",
      "encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "encoder.encoders.2.self_attn.linear_v.bias | [256] | 256\n",
      "encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "encoder.encoders.2.self_attn.linear_out.bias | [256] | 256\n",
      "encoder.encoders.2.self_attn.linear_pos.weight | [256, 256] | 65536\n",
      "encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.2.feed_forward.w_2.bias | [256] | 256\n",
      "encoder.encoders.2.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.2.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.2.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.2.feed_forward_macaron.w_2.bias | [256] | 256\n",
      "encoder.encoders.2.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
      "encoder.encoders.2.conv_module.pointwise_conv1.bias | [512] | 512\n",
      "encoder.encoders.2.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
      "encoder.encoders.2.conv_module.depthwise_conv.bias | [256] | 256\n",
      "encoder.encoders.2.conv_module.norm.weight | [256] | 256\n",
      "encoder.encoders.2.conv_module.norm.bias | [256] | 256\n",
      "encoder.encoders.2.conv_module.norm._mean | [256] | 256\n",
      "encoder.encoders.2.conv_module.norm._variance | [256] | 256\n",
      "encoder.encoders.2.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
      "encoder.encoders.2.conv_module.pointwise_conv2.bias | [256] | 256\n",
      "encoder.encoders.2.norm_ff.weight | [256] | 256\n",
      "encoder.encoders.2.norm_ff.bias | [256] | 256\n",
      "encoder.encoders.2.norm_mha.weight | [256] | 256\n",
      "encoder.encoders.2.norm_mha.bias | [256] | 256\n",
      "encoder.encoders.2.norm_ff_macaron.weight | [256] | 256\n",
      "encoder.encoders.2.norm_ff_macaron.bias | [256] | 256\n",
      "encoder.encoders.2.norm_conv.weight | [256] | 256\n",
      "encoder.encoders.2.norm_conv.bias | [256] | 256\n",
      "encoder.encoders.2.norm_final.weight | [256] | 256\n",
      "encoder.encoders.2.norm_final.bias | [256] | 256\n",
      "encoder.encoders.2.concat_linear.weight | [512, 256] | 131072\n",
      "encoder.encoders.2.concat_linear.bias | [256] | 256\n",
      "encoder.encoders.3.self_attn.pos_bias_u | [4, 64] | 256\n",
      "encoder.encoders.3.self_attn.pos_bias_v | [4, 64] | 256\n",
      "encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "encoder.encoders.3.self_attn.linear_q.bias | [256] | 256\n",
      "encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "encoder.encoders.3.self_attn.linear_k.bias | [256] | 256\n",
      "encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "encoder.encoders.3.self_attn.linear_v.bias | [256] | 256\n",
      "encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "encoder.encoders.3.self_attn.linear_out.bias | [256] | 256\n",
      "encoder.encoders.3.self_attn.linear_pos.weight | [256, 256] | 65536\n",
      "encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.3.feed_forward.w_2.bias | [256] | 256\n",
      "encoder.encoders.3.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.3.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.3.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.3.feed_forward_macaron.w_2.bias | [256] | 256\n",
      "encoder.encoders.3.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
      "encoder.encoders.3.conv_module.pointwise_conv1.bias | [512] | 512\n",
      "encoder.encoders.3.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
      "encoder.encoders.3.conv_module.depthwise_conv.bias | [256] | 256\n",
      "encoder.encoders.3.conv_module.norm.weight | [256] | 256\n",
      "encoder.encoders.3.conv_module.norm.bias | [256] | 256\n",
      "encoder.encoders.3.conv_module.norm._mean | [256] | 256\n",
      "encoder.encoders.3.conv_module.norm._variance | [256] | 256\n",
      "encoder.encoders.3.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
      "encoder.encoders.3.conv_module.pointwise_conv2.bias | [256] | 256\n",
      "encoder.encoders.3.norm_ff.weight | [256] | 256\n",
      "encoder.encoders.3.norm_ff.bias | [256] | 256\n",
      "encoder.encoders.3.norm_mha.weight | [256] | 256\n",
      "encoder.encoders.3.norm_mha.bias | [256] | 256\n",
      "encoder.encoders.3.norm_ff_macaron.weight | [256] | 256\n",
      "encoder.encoders.3.norm_ff_macaron.bias | [256] | 256\n",
      "encoder.encoders.3.norm_conv.weight | [256] | 256\n",
      "encoder.encoders.3.norm_conv.bias | [256] | 256\n",
      "encoder.encoders.3.norm_final.weight | [256] | 256\n",
      "encoder.encoders.3.norm_final.bias | [256] | 256\n",
      "encoder.encoders.3.concat_linear.weight | [512, 256] | 131072\n",
      "encoder.encoders.3.concat_linear.bias | [256] | 256\n",
      "encoder.encoders.4.self_attn.pos_bias_u | [4, 64] | 256\n",
      "encoder.encoders.4.self_attn.pos_bias_v | [4, 64] | 256\n",
      "encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "encoder.encoders.4.self_attn.linear_q.bias | [256] | 256\n",
      "encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "encoder.encoders.4.self_attn.linear_k.bias | [256] | 256\n",
      "encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "encoder.encoders.4.self_attn.linear_v.bias | [256] | 256\n",
      "encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "encoder.encoders.4.self_attn.linear_out.bias | [256] | 256\n",
      "encoder.encoders.4.self_attn.linear_pos.weight | [256, 256] | 65536\n",
      "encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.4.feed_forward.w_2.bias | [256] | 256\n",
      "encoder.encoders.4.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.4.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.4.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.4.feed_forward_macaron.w_2.bias | [256] | 256\n",
      "encoder.encoders.4.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
      "encoder.encoders.4.conv_module.pointwise_conv1.bias | [512] | 512\n",
      "encoder.encoders.4.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
      "encoder.encoders.4.conv_module.depthwise_conv.bias | [256] | 256\n",
      "encoder.encoders.4.conv_module.norm.weight | [256] | 256\n",
      "encoder.encoders.4.conv_module.norm.bias | [256] | 256\n",
      "encoder.encoders.4.conv_module.norm._mean | [256] | 256\n",
      "encoder.encoders.4.conv_module.norm._variance | [256] | 256\n",
      "encoder.encoders.4.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
      "encoder.encoders.4.conv_module.pointwise_conv2.bias | [256] | 256\n",
      "encoder.encoders.4.norm_ff.weight | [256] | 256\n",
      "encoder.encoders.4.norm_ff.bias | [256] | 256\n",
      "encoder.encoders.4.norm_mha.weight | [256] | 256\n",
      "encoder.encoders.4.norm_mha.bias | [256] | 256\n",
      "encoder.encoders.4.norm_ff_macaron.weight | [256] | 256\n",
      "encoder.encoders.4.norm_ff_macaron.bias | [256] | 256\n",
      "encoder.encoders.4.norm_conv.weight | [256] | 256\n",
      "encoder.encoders.4.norm_conv.bias | [256] | 256\n",
      "encoder.encoders.4.norm_final.weight | [256] | 256\n",
      "encoder.encoders.4.norm_final.bias | [256] | 256\n",
      "encoder.encoders.4.concat_linear.weight | [512, 256] | 131072\n",
      "encoder.encoders.4.concat_linear.bias | [256] | 256\n",
      "encoder.encoders.5.self_attn.pos_bias_u | [4, 64] | 256\n",
      "encoder.encoders.5.self_attn.pos_bias_v | [4, 64] | 256\n",
      "encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "encoder.encoders.5.self_attn.linear_q.bias | [256] | 256\n",
      "encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "encoder.encoders.5.self_attn.linear_k.bias | [256] | 256\n",
      "encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "encoder.encoders.5.self_attn.linear_v.bias | [256] | 256\n",
      "encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "encoder.encoders.5.self_attn.linear_out.bias | [256] | 256\n",
      "encoder.encoders.5.self_attn.linear_pos.weight | [256, 256] | 65536\n",
      "encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.5.feed_forward.w_2.bias | [256] | 256\n",
      "encoder.encoders.5.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.5.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.5.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.5.feed_forward_macaron.w_2.bias | [256] | 256\n",
      "encoder.encoders.5.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
      "encoder.encoders.5.conv_module.pointwise_conv1.bias | [512] | 512\n",
      "encoder.encoders.5.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
      "encoder.encoders.5.conv_module.depthwise_conv.bias | [256] | 256\n",
      "encoder.encoders.5.conv_module.norm.weight | [256] | 256\n",
      "encoder.encoders.5.conv_module.norm.bias | [256] | 256\n",
      "encoder.encoders.5.conv_module.norm._mean | [256] | 256\n",
      "encoder.encoders.5.conv_module.norm._variance | [256] | 256\n",
      "encoder.encoders.5.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
      "encoder.encoders.5.conv_module.pointwise_conv2.bias | [256] | 256\n",
      "encoder.encoders.5.norm_ff.weight | [256] | 256\n",
      "encoder.encoders.5.norm_ff.bias | [256] | 256\n",
      "encoder.encoders.5.norm_mha.weight | [256] | 256\n",
      "encoder.encoders.5.norm_mha.bias | [256] | 256\n",
      "encoder.encoders.5.norm_ff_macaron.weight | [256] | 256\n",
      "encoder.encoders.5.norm_ff_macaron.bias | [256] | 256\n",
      "encoder.encoders.5.norm_conv.weight | [256] | 256\n",
      "encoder.encoders.5.norm_conv.bias | [256] | 256\n",
      "encoder.encoders.5.norm_final.weight | [256] | 256\n",
      "encoder.encoders.5.norm_final.bias | [256] | 256\n",
      "encoder.encoders.5.concat_linear.weight | [512, 256] | 131072\n",
      "encoder.encoders.5.concat_linear.bias | [256] | 256\n",
      "encoder.encoders.6.self_attn.pos_bias_u | [4, 64] | 256\n",
      "encoder.encoders.6.self_attn.pos_bias_v | [4, 64] | 256\n",
      "encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "encoder.encoders.6.self_attn.linear_q.bias | [256] | 256\n",
      "encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "encoder.encoders.6.self_attn.linear_k.bias | [256] | 256\n",
      "encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "encoder.encoders.6.self_attn.linear_v.bias | [256] | 256\n",
      "encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "encoder.encoders.6.self_attn.linear_out.bias | [256] | 256\n",
      "encoder.encoders.6.self_attn.linear_pos.weight | [256, 256] | 65536\n",
      "encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.6.feed_forward.w_2.bias | [256] | 256\n",
      "encoder.encoders.6.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.6.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.6.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.6.feed_forward_macaron.w_2.bias | [256] | 256\n",
      "encoder.encoders.6.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
      "encoder.encoders.6.conv_module.pointwise_conv1.bias | [512] | 512\n",
      "encoder.encoders.6.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
      "encoder.encoders.6.conv_module.depthwise_conv.bias | [256] | 256\n",
      "encoder.encoders.6.conv_module.norm.weight | [256] | 256\n",
      "encoder.encoders.6.conv_module.norm.bias | [256] | 256\n",
      "encoder.encoders.6.conv_module.norm._mean | [256] | 256\n",
      "encoder.encoders.6.conv_module.norm._variance | [256] | 256\n",
      "encoder.encoders.6.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
      "encoder.encoders.6.conv_module.pointwise_conv2.bias | [256] | 256\n",
      "encoder.encoders.6.norm_ff.weight | [256] | 256\n",
      "encoder.encoders.6.norm_ff.bias | [256] | 256\n",
      "encoder.encoders.6.norm_mha.weight | [256] | 256\n",
      "encoder.encoders.6.norm_mha.bias | [256] | 256\n",
      "encoder.encoders.6.norm_ff_macaron.weight | [256] | 256\n",
      "encoder.encoders.6.norm_ff_macaron.bias | [256] | 256\n",
      "encoder.encoders.6.norm_conv.weight | [256] | 256\n",
      "encoder.encoders.6.norm_conv.bias | [256] | 256\n",
      "encoder.encoders.6.norm_final.weight | [256] | 256\n",
      "encoder.encoders.6.norm_final.bias | [256] | 256\n",
      "encoder.encoders.6.concat_linear.weight | [512, 256] | 131072\n",
      "encoder.encoders.6.concat_linear.bias | [256] | 256\n",
      "encoder.encoders.7.self_attn.pos_bias_u | [4, 64] | 256\n",
      "encoder.encoders.7.self_attn.pos_bias_v | [4, 64] | 256\n",
      "encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "encoder.encoders.7.self_attn.linear_q.bias | [256] | 256\n",
      "encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "encoder.encoders.7.self_attn.linear_k.bias | [256] | 256\n",
      "encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "encoder.encoders.7.self_attn.linear_v.bias | [256] | 256\n",
      "encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "encoder.encoders.7.self_attn.linear_out.bias | [256] | 256\n",
      "encoder.encoders.7.self_attn.linear_pos.weight | [256, 256] | 65536\n",
      "encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.7.feed_forward.w_2.bias | [256] | 256\n",
      "encoder.encoders.7.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.7.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.7.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.7.feed_forward_macaron.w_2.bias | [256] | 256\n",
      "encoder.encoders.7.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
      "encoder.encoders.7.conv_module.pointwise_conv1.bias | [512] | 512\n",
      "encoder.encoders.7.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
      "encoder.encoders.7.conv_module.depthwise_conv.bias | [256] | 256\n",
      "encoder.encoders.7.conv_module.norm.weight | [256] | 256\n",
      "encoder.encoders.7.conv_module.norm.bias | [256] | 256\n",
      "encoder.encoders.7.conv_module.norm._mean | [256] | 256\n",
      "encoder.encoders.7.conv_module.norm._variance | [256] | 256\n",
      "encoder.encoders.7.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
      "encoder.encoders.7.conv_module.pointwise_conv2.bias | [256] | 256\n",
      "encoder.encoders.7.norm_ff.weight | [256] | 256\n",
      "encoder.encoders.7.norm_ff.bias | [256] | 256\n",
      "encoder.encoders.7.norm_mha.weight | [256] | 256\n",
      "encoder.encoders.7.norm_mha.bias | [256] | 256\n",
      "encoder.encoders.7.norm_ff_macaron.weight | [256] | 256\n",
      "encoder.encoders.7.norm_ff_macaron.bias | [256] | 256\n",
      "encoder.encoders.7.norm_conv.weight | [256] | 256\n",
      "encoder.encoders.7.norm_conv.bias | [256] | 256\n",
      "encoder.encoders.7.norm_final.weight | [256] | 256\n",
      "encoder.encoders.7.norm_final.bias | [256] | 256\n",
      "encoder.encoders.7.concat_linear.weight | [512, 256] | 131072\n",
      "encoder.encoders.7.concat_linear.bias | [256] | 256\n",
      "encoder.encoders.8.self_attn.pos_bias_u | [4, 64] | 256\n",
      "encoder.encoders.8.self_attn.pos_bias_v | [4, 64] | 256\n",
      "encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "encoder.encoders.8.self_attn.linear_q.bias | [256] | 256\n",
      "encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "encoder.encoders.8.self_attn.linear_k.bias | [256] | 256\n",
      "encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "encoder.encoders.8.self_attn.linear_v.bias | [256] | 256\n",
      "encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "encoder.encoders.8.self_attn.linear_out.bias | [256] | 256\n",
      "encoder.encoders.8.self_attn.linear_pos.weight | [256, 256] | 65536\n",
      "encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.8.feed_forward.w_2.bias | [256] | 256\n",
      "encoder.encoders.8.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.8.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.8.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.8.feed_forward_macaron.w_2.bias | [256] | 256\n",
      "encoder.encoders.8.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
      "encoder.encoders.8.conv_module.pointwise_conv1.bias | [512] | 512\n",
      "encoder.encoders.8.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
      "encoder.encoders.8.conv_module.depthwise_conv.bias | [256] | 256\n",
      "encoder.encoders.8.conv_module.norm.weight | [256] | 256\n",
      "encoder.encoders.8.conv_module.norm.bias | [256] | 256\n",
      "encoder.encoders.8.conv_module.norm._mean | [256] | 256\n",
      "encoder.encoders.8.conv_module.norm._variance | [256] | 256\n",
      "encoder.encoders.8.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
      "encoder.encoders.8.conv_module.pointwise_conv2.bias | [256] | 256\n",
      "encoder.encoders.8.norm_ff.weight | [256] | 256\n",
      "encoder.encoders.8.norm_ff.bias | [256] | 256\n",
      "encoder.encoders.8.norm_mha.weight | [256] | 256\n",
      "encoder.encoders.8.norm_mha.bias | [256] | 256\n",
      "encoder.encoders.8.norm_ff_macaron.weight | [256] | 256\n",
      "encoder.encoders.8.norm_ff_macaron.bias | [256] | 256\n",
      "encoder.encoders.8.norm_conv.weight | [256] | 256\n",
      "encoder.encoders.8.norm_conv.bias | [256] | 256\n",
      "encoder.encoders.8.norm_final.weight | [256] | 256\n",
      "encoder.encoders.8.norm_final.bias | [256] | 256\n",
      "encoder.encoders.8.concat_linear.weight | [512, 256] | 131072\n",
      "encoder.encoders.8.concat_linear.bias | [256] | 256\n",
      "encoder.encoders.9.self_attn.pos_bias_u | [4, 64] | 256\n",
      "encoder.encoders.9.self_attn.pos_bias_v | [4, 64] | 256\n",
      "encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "encoder.encoders.9.self_attn.linear_q.bias | [256] | 256\n",
      "encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "encoder.encoders.9.self_attn.linear_k.bias | [256] | 256\n",
      "encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "encoder.encoders.9.self_attn.linear_v.bias | [256] | 256\n",
      "encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "encoder.encoders.9.self_attn.linear_out.bias | [256] | 256\n",
      "encoder.encoders.9.self_attn.linear_pos.weight | [256, 256] | 65536\n",
      "encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.9.feed_forward.w_2.bias | [256] | 256\n",
      "encoder.encoders.9.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.9.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.9.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.9.feed_forward_macaron.w_2.bias | [256] | 256\n",
      "encoder.encoders.9.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
      "encoder.encoders.9.conv_module.pointwise_conv1.bias | [512] | 512\n",
      "encoder.encoders.9.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
      "encoder.encoders.9.conv_module.depthwise_conv.bias | [256] | 256\n",
      "encoder.encoders.9.conv_module.norm.weight | [256] | 256\n",
      "encoder.encoders.9.conv_module.norm.bias | [256] | 256\n",
      "encoder.encoders.9.conv_module.norm._mean | [256] | 256\n",
      "encoder.encoders.9.conv_module.norm._variance | [256] | 256\n",
      "encoder.encoders.9.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
      "encoder.encoders.9.conv_module.pointwise_conv2.bias | [256] | 256\n",
      "encoder.encoders.9.norm_ff.weight | [256] | 256\n",
      "encoder.encoders.9.norm_ff.bias | [256] | 256\n",
      "encoder.encoders.9.norm_mha.weight | [256] | 256\n",
      "encoder.encoders.9.norm_mha.bias | [256] | 256\n",
      "encoder.encoders.9.norm_ff_macaron.weight | [256] | 256\n",
      "encoder.encoders.9.norm_ff_macaron.bias | [256] | 256\n",
      "encoder.encoders.9.norm_conv.weight | [256] | 256\n",
      "encoder.encoders.9.norm_conv.bias | [256] | 256\n",
      "encoder.encoders.9.norm_final.weight | [256] | 256\n",
      "encoder.encoders.9.norm_final.bias | [256] | 256\n",
      "encoder.encoders.9.concat_linear.weight | [512, 256] | 131072\n",
      "encoder.encoders.9.concat_linear.bias | [256] | 256\n",
      "encoder.encoders.10.self_attn.pos_bias_u | [4, 64] | 256\n",
      "encoder.encoders.10.self_attn.pos_bias_v | [4, 64] | 256\n",
      "encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "encoder.encoders.10.self_attn.linear_q.bias | [256] | 256\n",
      "encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "encoder.encoders.10.self_attn.linear_k.bias | [256] | 256\n",
      "encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "encoder.encoders.10.self_attn.linear_v.bias | [256] | 256\n",
      "encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "encoder.encoders.10.self_attn.linear_out.bias | [256] | 256\n",
      "encoder.encoders.10.self_attn.linear_pos.weight | [256, 256] | 65536\n",
      "encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.10.feed_forward.w_2.bias | [256] | 256\n",
      "encoder.encoders.10.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.10.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.10.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.10.feed_forward_macaron.w_2.bias | [256] | 256\n",
      "encoder.encoders.10.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
      "encoder.encoders.10.conv_module.pointwise_conv1.bias | [512] | 512\n",
      "encoder.encoders.10.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
      "encoder.encoders.10.conv_module.depthwise_conv.bias | [256] | 256\n",
      "encoder.encoders.10.conv_module.norm.weight | [256] | 256\n",
      "encoder.encoders.10.conv_module.norm.bias | [256] | 256\n",
      "encoder.encoders.10.conv_module.norm._mean | [256] | 256\n",
      "encoder.encoders.10.conv_module.norm._variance | [256] | 256\n",
      "encoder.encoders.10.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
      "encoder.encoders.10.conv_module.pointwise_conv2.bias | [256] | 256\n",
      "encoder.encoders.10.norm_ff.weight | [256] | 256\n",
      "encoder.encoders.10.norm_ff.bias | [256] | 256\n",
      "encoder.encoders.10.norm_mha.weight | [256] | 256\n",
      "encoder.encoders.10.norm_mha.bias | [256] | 256\n",
      "encoder.encoders.10.norm_ff_macaron.weight | [256] | 256\n",
      "encoder.encoders.10.norm_ff_macaron.bias | [256] | 256\n",
      "encoder.encoders.10.norm_conv.weight | [256] | 256\n",
      "encoder.encoders.10.norm_conv.bias | [256] | 256\n",
      "encoder.encoders.10.norm_final.weight | [256] | 256\n",
      "encoder.encoders.10.norm_final.bias | [256] | 256\n",
      "encoder.encoders.10.concat_linear.weight | [512, 256] | 131072\n",
      "encoder.encoders.10.concat_linear.bias | [256] | 256\n",
      "encoder.encoders.11.self_attn.pos_bias_u | [4, 64] | 256\n",
      "encoder.encoders.11.self_attn.pos_bias_v | [4, 64] | 256\n",
      "encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "encoder.encoders.11.self_attn.linear_q.bias | [256] | 256\n",
      "encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "encoder.encoders.11.self_attn.linear_k.bias | [256] | 256\n",
      "encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "encoder.encoders.11.self_attn.linear_v.bias | [256] | 256\n",
      "encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "encoder.encoders.11.self_attn.linear_out.bias | [256] | 256\n",
      "encoder.encoders.11.self_attn.linear_pos.weight | [256, 256] | 65536\n",
      "encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.11.feed_forward.w_2.bias | [256] | 256\n",
      "encoder.encoders.11.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
      "encoder.encoders.11.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
      "encoder.encoders.11.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
      "encoder.encoders.11.feed_forward_macaron.w_2.bias | [256] | 256\n",
      "encoder.encoders.11.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
      "encoder.encoders.11.conv_module.pointwise_conv1.bias | [512] | 512\n",
      "encoder.encoders.11.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
      "encoder.encoders.11.conv_module.depthwise_conv.bias | [256] | 256\n",
      "encoder.encoders.11.conv_module.norm.weight | [256] | 256\n",
      "encoder.encoders.11.conv_module.norm.bias | [256] | 256\n",
      "encoder.encoders.11.conv_module.norm._mean | [256] | 256\n",
      "encoder.encoders.11.conv_module.norm._variance | [256] | 256\n",
      "encoder.encoders.11.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
      "encoder.encoders.11.conv_module.pointwise_conv2.bias | [256] | 256\n",
      "encoder.encoders.11.norm_ff.weight | [256] | 256\n",
      "encoder.encoders.11.norm_ff.bias | [256] | 256\n",
      "encoder.encoders.11.norm_mha.weight | [256] | 256\n",
      "encoder.encoders.11.norm_mha.bias | [256] | 256\n",
      "encoder.encoders.11.norm_ff_macaron.weight | [256] | 256\n",
      "encoder.encoders.11.norm_ff_macaron.bias | [256] | 256\n",
      "encoder.encoders.11.norm_conv.weight | [256] | 256\n",
      "encoder.encoders.11.norm_conv.bias | [256] | 256\n",
      "encoder.encoders.11.norm_final.weight | [256] | 256\n",
      "encoder.encoders.11.norm_final.bias | [256] | 256\n",
      "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072\n",
      "encoder.encoders.11.concat_linear.bias | [256] | 256\n",
      "decoder.embed.0.weight | [4233, 256] | 1083648\n",
      "decoder.after_norm.weight | [256] | 256\n",
      "decoder.after_norm.bias | [256] | 256\n",
      "decoder.output_layer.weight | [256, 4233] | 1083648\n",
      "decoder.output_layer.bias | [4233] | 4233\n",
      "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256\n",
      "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "decoder.decoders.0.self_attn.linear_k.bias | [256] | 256\n",
      "decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "decoder.decoders.0.self_attn.linear_v.bias | [256] | 256\n",
      "decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "decoder.decoders.0.self_attn.linear_out.bias | [256] | 256\n",
      "decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536\n",
      "decoder.decoders.0.src_attn.linear_q.bias | [256] | 256\n",
      "decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536\n",
      "decoder.decoders.0.src_attn.linear_k.bias | [256] | 256\n",
      "decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536\n",
      "decoder.decoders.0.src_attn.linear_v.bias | [256] | 256\n",
      "decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536\n",
      "decoder.decoders.0.src_attn.linear_out.bias | [256] | 256\n",
      "decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048\n",
      "decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "decoder.decoders.0.feed_forward.w_2.bias | [256] | 256\n",
      "decoder.decoders.0.norm1.weight | [256] | 256\n",
      "decoder.decoders.0.norm1.bias | [256] | 256\n",
      "decoder.decoders.0.norm2.weight | [256] | 256\n",
      "decoder.decoders.0.norm2.bias | [256] | 256\n",
      "decoder.decoders.0.norm3.weight | [256] | 256\n",
      "decoder.decoders.0.norm3.bias | [256] | 256\n",
      "decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072\n",
      "decoder.decoders.0.concat_linear1.bias | [256] | 256\n",
      "decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072\n",
      "decoder.decoders.0.concat_linear2.bias | [256] | 256\n",
      "decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "decoder.decoders.1.self_attn.linear_q.bias | [256] | 256\n",
      "decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "decoder.decoders.1.self_attn.linear_k.bias | [256] | 256\n",
      "decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "decoder.decoders.1.self_attn.linear_v.bias | [256] | 256\n",
      "decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "decoder.decoders.1.self_attn.linear_out.bias | [256] | 256\n",
      "decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536\n",
      "decoder.decoders.1.src_attn.linear_q.bias | [256] | 256\n",
      "decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536\n",
      "decoder.decoders.1.src_attn.linear_k.bias | [256] | 256\n",
      "decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536\n",
      "decoder.decoders.1.src_attn.linear_v.bias | [256] | 256\n",
      "decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536\n",
      "decoder.decoders.1.src_attn.linear_out.bias | [256] | 256\n",
      "decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048\n",
      "decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "decoder.decoders.1.feed_forward.w_2.bias | [256] | 256\n",
      "decoder.decoders.1.norm1.weight | [256] | 256\n",
      "decoder.decoders.1.norm1.bias | [256] | 256\n",
      "decoder.decoders.1.norm2.weight | [256] | 256\n",
      "decoder.decoders.1.norm2.bias | [256] | 256\n",
      "decoder.decoders.1.norm3.weight | [256] | 256\n",
      "decoder.decoders.1.norm3.bias | [256] | 256\n",
      "decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072\n",
      "decoder.decoders.1.concat_linear1.bias | [256] | 256\n",
      "decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072\n",
      "decoder.decoders.1.concat_linear2.bias | [256] | 256\n",
      "decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "decoder.decoders.2.self_attn.linear_q.bias | [256] | 256\n",
      "decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "decoder.decoders.2.self_attn.linear_k.bias | [256] | 256\n",
      "decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "decoder.decoders.2.self_attn.linear_v.bias | [256] | 256\n",
      "decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "decoder.decoders.2.self_attn.linear_out.bias | [256] | 256\n",
      "decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536\n",
      "decoder.decoders.2.src_attn.linear_q.bias | [256] | 256\n",
      "decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536\n",
      "decoder.decoders.2.src_attn.linear_k.bias | [256] | 256\n",
      "decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536\n",
      "decoder.decoders.2.src_attn.linear_v.bias | [256] | 256\n",
      "decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536\n",
      "decoder.decoders.2.src_attn.linear_out.bias | [256] | 256\n",
      "decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048\n",
      "decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "decoder.decoders.2.feed_forward.w_2.bias | [256] | 256\n",
      "decoder.decoders.2.norm1.weight | [256] | 256\n",
      "decoder.decoders.2.norm1.bias | [256] | 256\n",
      "decoder.decoders.2.norm2.weight | [256] | 256\n",
      "decoder.decoders.2.norm2.bias | [256] | 256\n",
      "decoder.decoders.2.norm3.weight | [256] | 256\n",
      "decoder.decoders.2.norm3.bias | [256] | 256\n",
      "decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072\n",
      "decoder.decoders.2.concat_linear1.bias | [256] | 256\n",
      "decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072\n",
      "decoder.decoders.2.concat_linear2.bias | [256] | 256\n",
      "decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "decoder.decoders.3.self_attn.linear_q.bias | [256] | 256\n",
      "decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "decoder.decoders.3.self_attn.linear_k.bias | [256] | 256\n",
      "decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256\n",
      "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256\n",
      "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536\n",
      "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256\n",
      "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536\n",
      "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256\n",
      "decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536\n",
      "decoder.decoders.3.src_attn.linear_v.bias | [256] | 256\n",
      "decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536\n",
      "decoder.decoders.3.src_attn.linear_out.bias | [256] | 256\n",
      "decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048\n",
      "decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "decoder.decoders.3.feed_forward.w_2.bias | [256] | 256\n",
      "decoder.decoders.3.norm1.weight | [256] | 256\n",
      "decoder.decoders.3.norm1.bias | [256] | 256\n",
      "decoder.decoders.3.norm2.weight | [256] | 256\n",
      "decoder.decoders.3.norm2.bias | [256] | 256\n",
      "decoder.decoders.3.norm3.weight | [256] | 256\n",
      "decoder.decoders.3.norm3.bias | [256] | 256\n",
      "decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072\n",
      "decoder.decoders.3.concat_linear1.bias | [256] | 256\n",
      "decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072\n",
      "decoder.decoders.3.concat_linear2.bias | [256] | 256\n",
      "decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "decoder.decoders.4.self_attn.linear_q.bias | [256] | 256\n",
      "decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "decoder.decoders.4.self_attn.linear_k.bias | [256] | 256\n",
      "decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "decoder.decoders.4.self_attn.linear_v.bias | [256] | 256\n",
      "decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "decoder.decoders.4.self_attn.linear_out.bias | [256] | 256\n",
      "decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536\n",
      "decoder.decoders.4.src_attn.linear_q.bias | [256] | 256\n",
      "decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536\n",
      "decoder.decoders.4.src_attn.linear_k.bias | [256] | 256\n",
      "decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536\n",
      "decoder.decoders.4.src_attn.linear_v.bias | [256] | 256\n",
      "decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536\n",
      "decoder.decoders.4.src_attn.linear_out.bias | [256] | 256\n",
      "decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048\n",
      "decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "decoder.decoders.4.feed_forward.w_2.bias | [256] | 256\n",
      "decoder.decoders.4.norm1.weight | [256] | 256\n",
      "decoder.decoders.4.norm1.bias | [256] | 256\n",
      "decoder.decoders.4.norm2.weight | [256] | 256\n",
      "decoder.decoders.4.norm2.bias | [256] | 256\n",
      "decoder.decoders.4.norm3.weight | [256] | 256\n",
      "decoder.decoders.4.norm3.bias | [256] | 256\n",
      "decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072\n",
      "decoder.decoders.4.concat_linear1.bias | [256] | 256\n",
      "decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072\n",
      "decoder.decoders.4.concat_linear2.bias | [256] | 256\n",
      "decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536\n",
      "decoder.decoders.5.self_attn.linear_q.bias | [256] | 256\n",
      "decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536\n",
      "decoder.decoders.5.self_attn.linear_k.bias | [256] | 256\n",
      "decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536\n",
      "decoder.decoders.5.self_attn.linear_v.bias | [256] | 256\n",
      "decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536\n",
      "decoder.decoders.5.self_attn.linear_out.bias | [256] | 256\n",
      "decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536\n",
      "decoder.decoders.5.src_attn.linear_q.bias | [256] | 256\n",
      "decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536\n",
      "decoder.decoders.5.src_attn.linear_k.bias | [256] | 256\n",
      "decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536\n",
      "decoder.decoders.5.src_attn.linear_v.bias | [256] | 256\n",
      "decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536\n",
      "decoder.decoders.5.src_attn.linear_out.bias | [256] | 256\n",
      "decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288\n",
      "decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048\n",
      "decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288\n",
      "decoder.decoders.5.feed_forward.w_2.bias | [256] | 256\n",
      "decoder.decoders.5.norm1.weight | [256] | 256\n",
      "decoder.decoders.5.norm1.bias | [256] | 256\n",
      "decoder.decoders.5.norm2.weight | [256] | 256\n",
      "decoder.decoders.5.norm2.bias | [256] | 256\n",
      "decoder.decoders.5.norm3.weight | [256] | 256\n",
      "decoder.decoders.5.norm3.bias | [256] | 256\n",
      "decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072\n",
      "decoder.decoders.5.concat_linear1.bias | [256] | 256\n",
      "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072\n",
      "decoder.decoders.5.concat_linear2.bias | [256] | 256\n",
      "ctc.ctc_lo.weight | [256, 4233] | 1083648\n",
      "ctc.ctc_lo.bias | [4233] | 4233\n",
      "Total parameters: 689, 49355442 elements.\n"
     ]
    }
   ],
   "source": [
    "summary(model)"
   ]
  },
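  {
   "cell_type": "markdown",
   "id": "buffer-count-note",
   "metadata": {},
   "source": [
    "`summary` counts 689 tensors and 49,355,442 elements, i.e. two tensors and 160 elements more than `print_params`. Judging from the two listings, the extras are exactly the `encoder.global_cmvn.mean` and `encoder.global_cmvn.istd` buffers (80 elements each); the BatchNorm `_mean`/`_variance` entries are likewise running statistics rather than learned weights. The quick check below is just arithmetic on the printed totals."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "buffer-count-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The gap between the two listings matches the two GlobalCMVN buffers\n",
    "# (an observation from the printed totals, not an API guarantee).\n",
    "print(49355442 - 49355282)  # 160 == 2 buffers * 80 feature dims"
   ]
  },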
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ruled-invitation",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "U2Model(\n",
      "  (encoder): ConformerEncoder(\n",
      "    (global_cmvn): GlobalCMVN()\n",
      "    (embed): Conv2dSubsampling4(\n",
      "      (pos_enc): RelPositionalEncoding(\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "      )\n",
      "      (conv): Sequential(\n",
      "        (0): Conv2D(1, 256, kernel_size=[3, 3], stride=[2, 2], data_format=NCHW)\n",
      "        (1): ReLU()\n",
      "        (2): Conv2D(256, 256, kernel_size=[3, 3], stride=[2, 2], data_format=NCHW)\n",
      "        (3): ReLU()\n",
      "      )\n",
      "      (out): Sequential(\n",
      "        (0): Linear(in_features=4864, out_features=256, dtype=float32)\n",
      "      )\n",
      "    )\n",
      "    (after_norm): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "    (encoders): LayerList(\n",
      "      (0): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
      "          (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
      "          (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
      "          (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (1): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
      "          (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
      "          (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
      "          (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (2): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
      "          (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
      "          (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
      "          (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (3): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
      "          (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
      "          (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
      "          (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (4): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
      "          (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
      "          (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
      "          (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (5): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
      "          (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
      "          (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
      "          (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (6): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
      "          (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
      "          (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
      "          (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (7): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
      "          (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
      "          (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
      "          (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (8): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
      "          (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
      "          (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
      "          (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (9): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
      "          (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
      "          (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
      "          (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (10): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
      "          (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
      "          (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
      "          (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (11): ConformerEncoderLayer(\n",
      "        (self_attn): RelPositionMultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "          (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (feed_forward_macaron): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): Swish()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (conv_module): ConvolutionModule(\n",
      "          (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
      "          (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
      "          (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
      "          (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
      "          (activation): Swish()\n",
      "        )\n",
      "        (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (decoder): TransformerDecoder(\n",
      "    (embed): Sequential(\n",
      "      (0): Embedding(4233, 256, sparse=False)\n",
      "      (1): PositionalEncoding(\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "      )\n",
      "    )\n",
      "    (after_norm): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "    (output_layer): Linear(in_features=256, out_features=4233, dtype=float32)\n",
      "    (decoders): LayerList(\n",
      "      (0): DecoderLayer(\n",
      "        (self_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "        )\n",
      "        (src_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): ReLU()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "        (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (1): DecoderLayer(\n",
      "        (self_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "        )\n",
      "        (src_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): ReLU()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "        (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (2): DecoderLayer(\n",
      "        (self_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "        )\n",
      "        (src_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): ReLU()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "        (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (3): DecoderLayer(\n",
      "        (self_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "        )\n",
      "        (src_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): ReLU()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "        (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (4): DecoderLayer(\n",
      "        (self_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "        )\n",
      "        (src_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): ReLU()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "        (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "      (5): DecoderLayer(\n",
      "        (self_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "        )\n",
      "        (src_attn): MultiHeadedAttention(\n",
      "          (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
      "          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
      "        )\n",
      "        (feed_forward): PositionwiseFeedForward(\n",
      "          (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
      "          (activation): ReLU()\n",
      "          (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "          (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
      "        )\n",
      "        (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
      "        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
      "        (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "        (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (ctc): CTCDecoder(\n",
      "    (ctc_lo): Linear(in_features=256, out_features=4233, dtype=float32)\n",
      "    (criterion): CTCLoss(\n",
      "      (loss): CTCLoss()\n",
      "    )\n",
      "  )\n",
      "  (criterion_att): LabelSmoothingLoss(\n",
      "    (criterion): KLDivLoss()\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(model)"
   ]
  },
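  {
   "cell_type": "code",
   "execution_count": null,
   "id": "counted-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity-check sketch (not part of the original run): tally the trainable\n",
    "# parameters of the model printed above via the standard paddle.nn.Layer\n",
    "# parameters() API\n",
    "import numpy as np\n",
    "n_params = sum(int(np.prod(p.shape)) for p in model.parameters())\n",
    "print(f'total parameters: {n_params}')"
   ]
  },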
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fossil-means",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load feat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fleet-despite",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "compute_cmvn_loader_test.ipynb         encoder.npz\r\n",
      "dataloader.ipynb                       hack_api_test.ipynb\r\n",
      "dataloader_with_tokens_tokenids.ipynb  jit_infer.ipynb\r\n",
      "data.npz                               layer_norm_test.ipynb\r\n",
      "decoder.npz                            Linear_test.ipynb\r\n",
      "enc_0_ff_out.npz                       mask_and_masked_fill_test.ipynb\r\n",
      "enc_0_norm_ff.npz                      model.npz\r\n",
      "enc_0.npz                              position_embeding_check.ipynb\r\n",
      "enc_0_selattn_out.npz                  python_test.ipynb\r\n",
      "enc_2.npz                              train_test.ipynb\r\n",
      "enc_all.npz                            u2_model.ipynb\r\n",
      "enc_embed.npz\r\n"
     ]
    }
   ],
   "source": [
    "%ls .notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "abroad-oracle",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['BAC009S0739W0246' 'BAC009S0727W0424' 'BAC009S0753W0412'\n",
      " 'BAC009S0756W0206' 'BAC009S0740W0414' 'BAC009S0728W0426'\n",
      " 'BAC009S0739W0214' 'BAC009S0753W0423' 'BAC009S0734W0201'\n",
      " 'BAC009S0740W0427' 'BAC009S0730W0423' 'BAC009S0728W0367'\n",
      " 'BAC009S0730W0418' 'BAC009S0727W0157' 'BAC009S0749W0409'\n",
      " 'BAC009S0727W0418']\n",
      "(16, 207, 80)\n",
      "[[[ 8.994624   9.538309   9.191589  ... 10.507416   9.563305   8.256403 ]\n",
      "  [ 9.798841  10.405224   9.26511   ... 10.251211   9.543982   8.873768 ]\n",
      "  [10.6890745 10.395469   8.053548  ...  9.906749  10.064903   8.050915 ]\n",
      "  ...\n",
      "  [ 9.217986   9.65069    8.505259  ...  9.687183   8.742463   7.9865475]\n",
      "  [10.129122   9.935194   9.37982   ...  9.563894   9.825992   8.979543 ]\n",
      "  [ 9.095531   7.1338377  9.468001  ...  9.472748   9.021235   7.447914 ]]\n",
      "\n",
      " [[11.430976  10.671858   6.0841026 ...  9.382682   8.729745   7.5315614]\n",
      "  [ 9.731717   7.8104815  7.5714607 ... 10.043035   9.243595   7.3540792]\n",
      "  [10.65017   10.600604   8.467784  ...  9.281448   9.186885   8.070343 ]\n",
      "  ...\n",
      "  [ 9.096987   9.2637     8.075275  ...  8.431845   8.370505   8.002926 ]\n",
      "  [10.461651  10.147784   6.7693496 ...  9.779426   9.577453   8.080652 ]\n",
      "  [ 7.794432   5.621059   7.9750648 ...  9.997245   9.849678   8.031287 ]]\n",
      "\n",
      " [[ 7.3455667  7.896357   7.5795946 ... 11.631024  10.451254   9.123633 ]\n",
      "  [ 8.628678   8.4630575  7.499242  ... 12.415986  10.975749   8.9425745]\n",
      "  [ 9.831394  10.2812805  8.97241   ... 12.1386795 10.40175    9.005517 ]\n",
      "  ...\n",
      "  [ 7.089641   7.405548   6.8142557 ...  9.325196   9.273162   8.353427 ]\n",
      "  [ 0.         0.         0.        ...  0.         0.         0.       ]\n",
      "  [ 0.         0.         0.        ...  0.         0.         0.       ]]\n",
      "\n",
      " ...\n",
      "\n",
      " [[10.933237  10.464394   7.7202725 ... 10.348816   9.302338   7.1553144]\n",
      "  [10.449866   9.907033   9.029272  ...  9.952465   9.414051   7.559279 ]\n",
      "  [10.487655   9.81259    9.895244  ...  9.58662    9.341254   7.7849016]\n",
      "  ...\n",
      "  [ 0.         0.         0.        ...  0.         0.         0.       ]\n",
      "  [ 0.         0.         0.        ...  0.         0.         0.       ]\n",
      "  [ 0.         0.         0.        ...  0.         0.         0.       ]]\n",
      "\n",
      " [[ 9.944384   9.585867   8.220328  ... 11.588647  11.045029   8.817075 ]\n",
      "  [ 7.678356   8.322397   7.533047  ... 11.055085  10.535685   9.27465  ]\n",
      "  [ 8.626197   9.675917   9.841045  ... 11.378827  10.922112   8.991444 ]\n",
      "  ...\n",
      "  [ 0.         0.         0.        ...  0.         0.         0.       ]\n",
      "  [ 0.         0.         0.        ...  0.         0.         0.       ]\n",
      "  [ 0.         0.         0.        ...  0.         0.         0.       ]]\n",
      "\n",
      " [[ 8.107938   7.759043   6.710301  ... 12.650573  11.466156  11.061517 ]\n",
      "  [11.380332  11.222007   8.658889  ... 12.810616  12.222216  11.689288 ]\n",
      "  [10.677676   9.920579   8.046089  ... 13.572894  12.5624075 11.155033 ]\n",
      "  ...\n",
      "  [ 0.         0.         0.        ...  0.         0.         0.       ]\n",
      "  [ 0.         0.         0.        ...  0.         0.         0.       ]\n",
      "  [ 0.         0.         0.        ...  0.         0.         0.       ]]]\n",
      "[207 207 205 205 203 203 198 197 195 188 186 186 185 180 166 163]\n",
      "[[2995 3116 1209  565   -1   -1]\n",
      " [ 236 1176  331   66 3925 4077]\n",
      " [2693  524  234 1145  366   -1]\n",
      " [3875 4211 3062  700   -1   -1]\n",
      " [ 272  987 1134  494 2959   -1]\n",
      " [1936 3715  120 2553 2695 2710]\n",
      " [  25 1149 3930   -1   -1   -1]\n",
      " [1753 1778 1237  482 3925  110]\n",
      " [3703    2  565 3827   -1   -1]\n",
      " [1150 2734   10 2478 3490   -1]\n",
      " [ 426  811   95  489  144   -1]\n",
      " [2313 2006  489  975   -1   -1]\n",
      " [3702 3414  205 1488 2966 1347]\n",
      " [  70 1741  702 1666   -1   -1]\n",
      " [ 703 1778 1030  849   -1   -1]\n",
      " [ 814 1674  115 3827   -1   -1]]\n",
      "[4 6 5 4 5 6 3 6 4 5 5 4 6 4 4 4]\n"
     ]
    }
   ],
   "source": [
    "data = np.load('.notebook/data.npz', allow_pickle=True)\n",
    "keys=data['keys']\n",
    "feat=data['feat']\n",
    "feat_len=data['feat_len']\n",
    "text=data['text']\n",
    "text_len=data['text_len']\n",
    "print(keys)\n",
    "print(feat.shape)\n",
    "print(feat)\n",
    "print(feat_len)\n",
    "print(text)\n",
    "print(text_len)"
   ]
  },
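  {
   "cell_type": "code",
   "execution_count": null,
   "id": "padded-sanity",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity-check sketch: the printed batch uses -1 as the text padding id,\n",
    "# so the count of non-padding token ids per utterance should equal\n",
    "# text_len, and feat is padded to the longest utterance\n",
    "assert feat.shape[0] == len(keys) == len(feat_len) == len(text_len)\n",
    "assert ((text != -1).sum(axis=1) == text_len).all()\n",
    "assert feat.shape[1] == feat_len.max()"
   ]
  },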
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "arctic-proxy",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ['BAC009S0739W0246', 'BAC009S0727W0424', 'BAC009S0753W0412', 'BAC009S0756W0206', 'BAC009S0740W0414', 'BAC009S0728W0426', 'BAC009S0739W0214', 'BAC009S0753W0423', 'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']\n",
    "# torch.Size([16, 207, 80])\n",
    "# tensor([[[ 8.9946,  9.5383,  9.1916,  ..., 10.5074,  9.5633,  8.2564],\n",
    "#          [ 9.7988, 10.4052,  9.2651,  ..., 10.2512,  9.5440,  8.8738],\n",
    "#          [10.6891, 10.3955,  8.0535,  ...,  9.9067, 10.0649,  8.0509],\n",
    "#          ...,\n",
    "#          [ 9.2180,  9.6507,  8.5053,  ...,  9.6872,  8.7425,  7.9865],\n",
    "#          [10.1291,  9.9352,  9.3798,  ...,  9.5639,  9.8260,  8.9795],\n",
    "#          [ 9.0955,  7.1338,  9.4680,  ...,  9.4727,  9.0212,  7.4479]],\n",
    "\n",
    "#         [[11.4310, 10.6719,  6.0841,  ...,  9.3827,  8.7297,  7.5316],\n",
    "#          [ 9.7317,  7.8105,  7.5715,  ..., 10.0430,  9.2436,  7.3541],\n",
    "#          [10.6502, 10.6006,  8.4678,  ...,  9.2814,  9.1869,  8.0703],\n",
    "#          ...,\n",
    "#          [ 9.0970,  9.2637,  8.0753,  ...,  8.4318,  8.3705,  8.0029],\n",
    "#          [10.4617, 10.1478,  6.7693,  ...,  9.7794,  9.5775,  8.0807],\n",
    "#          [ 7.7944,  5.6211,  7.9751,  ...,  9.9972,  9.8497,  8.0313]],\n",
    "\n",
    "#         [[ 7.3456,  7.8964,  7.5796,  ..., 11.6310, 10.4513,  9.1236],\n",
    "#          [ 8.6287,  8.4631,  7.4992,  ..., 12.4160, 10.9757,  8.9426],\n",
    "#          [ 9.8314, 10.2813,  8.9724,  ..., 12.1387, 10.4017,  9.0055],\n",
    "#          ...,\n",
    "#          [ 7.0896,  7.4055,  6.8143,  ...,  9.3252,  9.2732,  8.3534],\n",
    "#          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
    "#          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],\n",
    "\n",
    "#         ...,\n",
    "\n",
    "#         [[10.9332, 10.4644,  7.7203,  ..., 10.3488,  9.3023,  7.1553],\n",
    "#          [10.4499,  9.9070,  9.0293,  ...,  9.9525,  9.4141,  7.5593],\n",
    "#          [10.4877,  9.8126,  9.8952,  ...,  9.5866,  9.3413,  7.7849],\n",
    "#          ...,\n",
    "#          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
    "#          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
    "#          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],\n",
    "\n",
    "#         [[ 9.9444,  9.5859,  8.2203,  ..., 11.5886, 11.0450,  8.8171],\n",
    "#          [ 7.6784,  8.3224,  7.5330,  ..., 11.0551, 10.5357,  9.2746],\n",
    "#          [ 8.6262,  9.6759,  9.8410,  ..., 11.3788, 10.9221,  8.9914],\n",
    "#          ...,\n",
    "#          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
    "#          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
    "#          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],\n",
    "\n",
    "#         [[ 8.1079,  7.7590,  6.7103,  ..., 12.6506, 11.4662, 11.0615],\n",
    "#          [11.3803, 11.2220,  8.6589,  ..., 12.8106, 12.2222, 11.6893],\n",
    "#          [10.6777,  9.9206,  8.0461,  ..., 13.5729, 12.5624, 11.1550],\n",
    "#          ...,\n",
    "#          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
    "#          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
    "#          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])\n",
    "# tensor([207, 207, 205, 205, 203, 203, 198, 197, 195, 188, 186, 186, 185, 180,\n",
    "#         166, 163], dtype=torch.int32)\n",
    "# tensor([[2995, 3116, 1209,  565,   -1,   -1],\n",
    "#         [ 236, 1176,  331,   66, 3925, 4077],\n",
    "#         [2693,  524,  234, 1145,  366,   -1],\n",
    "#         [3875, 4211, 3062,  700,   -1,   -1],\n",
    "#         [ 272,  987, 1134,  494, 2959,   -1],\n",
    "#         [1936, 3715,  120, 2553, 2695, 2710],\n",
    "#         [  25, 1149, 3930,   -1,   -1,   -1],\n",
    "#         [1753, 1778, 1237,  482, 3925,  110],\n",
    "#         [3703,    2,  565, 3827,   -1,   -1],\n",
    "#         [1150, 2734,   10, 2478, 3490,   -1],\n",
    "#         [ 426,  811,   95,  489,  144,   -1],\n",
    "#         [2313, 2006,  489,  975,   -1,   -1],\n",
    "#         [3702, 3414,  205, 1488, 2966, 1347],\n",
    "#         [  70, 1741,  702, 1666,   -1,   -1],\n",
    "#         [ 703, 1778, 1030,  849,   -1,   -1],\n",
    "#         [ 814, 1674,  115, 3827,   -1,   -1]], dtype=torch.int32)\n",
    "# tensor([4, 6, 5, 4, 5, 6, 3, 6, 4, 5, 5, 4, 6, 4, 4, 4], dtype=torch.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "defined-brooks",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "compute_cmvn_loader_test.ipynb\t       encoder.npz\r\n",
      "dataloader.ipynb\t\t       hack_api_test.ipynb\r\n",
      "dataloader_with_tokens_tokenids.ipynb  jit_infer.ipynb\r\n",
      "data.npz\t\t\t       layer_norm_test.ipynb\r\n",
      "decoder.npz\t\t\t       Linear_test.ipynb\r\n",
      "enc_0_ff_out.npz\t\t       mask_and_masked_fill_test.ipynb\r\n",
      "enc_0_norm_ff.npz\t\t       model.npz\r\n",
      "enc_0.npz\t\t\t       position_embeding_check.ipynb\r\n",
      "enc_0_selattn_out.npz\t\t       python_test.ipynb\r\n",
      "enc_2.npz\t\t\t       train_test.ipynb\r\n",
      "enc_all.npz\t\t\t       u2_model.ipynb\r\n",
      "enc_embed.npz\r\n"
     ]
    }
   ],
   "source": [
    "# load model param\n",
    "!ls .notebook\n",
    "data = np.load('.notebook/model.npz', allow_pickle=True)\n",
    "state_dict = data['state'].item()\n",
    "\n",
    "for key, _ in model.state_dict().items():\n",
    "    if key not in state_dict:\n",
    "        print(f\"{key} not find.\")\n",
    "\n",
    "model.set_state_dict(state_dict)\n",
    "\n",
    "now_state_dict = model.state_dict()\n",
    "for key, value in now_state_dict.items():\n",
    "    if not np.allclose(value.numpy(), state_dict[key]):\n",
    "        print(key)"
   ]
  },
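  {
   "cell_type": "code",
   "execution_count": null,
   "id": "reverse-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sketch: the cell above only checks that every paddle key exists in the\n",
    "# exported state; the reverse direction catches exported parameters that\n",
    "# were never loaded (e.g. renamed layers)\n",
    "paddle_keys = set(model.state_dict().keys())\n",
    "for key in state_dict:\n",
    "    if key not in paddle_keys:\n",
    "        print(f'{key} only in the npz, missing from the paddle model.')"
   ]
  },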
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "confident-piano",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/framework.py:687: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  elif dtype == np.bool:\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
      "       [142.48880005]) Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
      "       [41.84146118]) Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
      "       [377.33258057])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dygraph/math_op_patch.py:238: UserWarning: The dtype of left and right variables are not the same, left dtype is VarType.FP32, but right dtype is VarType.INT32, the right dtype will convert to VarType.FP32\n",
      "  format(lhs_dtype, rhs_dtype, lhs_dtype))\n"
     ]
    }
   ],
   "source": [
    "# compute loss\n",
    "import paddle\n",
    "feat=paddle.to_tensor(feat)\n",
    "feat_len=paddle.to_tensor(feat_len, dtype='int64')\n",
    "text=paddle.to_tensor(text, dtype='int64')\n",
    "text_len=paddle.to_tensor(text_len, dtype='int64')\n",
    "\n",
    "model.eval()\n",
    "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n",
    "                                         text, text_len)\n",
    "print(total_loss, attention_loss, ctc_loss )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "better-senator",
   "metadata": {},
   "outputs": [],
   "source": [
    "# tensor(142.4888, device='cuda:0', grad_fn=<AddBackward0>) \n",
    "# tensor(41.8415, device='cuda:0', grad_fn=<DivBackward0>) \n",
    "# tensor(377.3326, device='cuda:0', grad_fn=<DivBackward0>)\n",
    "# 142.4888 41.84146 377.33258"
   ]
  },
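  {
   "cell_type": "code",
   "execution_count": null,
   "id": "weighted-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sketch: the three losses are consistent with the usual hybrid\n",
    "# CTC/attention objective total = w * ctc + (1 - w) * att; w = 0.3 is an\n",
    "# assumption here (the common U2 default), not read from the config\n",
    "w = 0.3\n",
    "print(w * 377.33258 + (1 - w) * 41.84146)  # ~142.4888, matches total_loss"
   ]
  },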
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "olympic-problem",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[16, 51, 256]\n",
      "[16, 1, 51]\n",
      "Tensor(shape=[51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
      "       [[-0.70194179,  0.56254166,  0.68803459, ...,  1.12373221,  0.78039235,  1.13693869],\n",
      "        [-0.77877808,  0.39126658,  0.71887815, ...,  1.25188220,  0.88616788,  1.31734526],\n",
      "        [-0.95908946,  0.63460249,  0.87671334, ...,  0.98183727,  0.74401081,  1.29032660],\n",
      "        ...,\n",
      "        [-1.07322502,  0.67236906,  0.92303109, ...,  0.90754563,  0.81767166,  1.32396567],\n",
      "        [-1.16541159,  0.68199694,  0.69394493, ...,  1.22383487,  0.80282891,  1.45065081],\n",
      "        [-1.27320945,  0.71458030,  0.75819558, ...,  0.94154912,  0.87748396,  1.26230514]])\n"
     ]
    }
   ],
   "source": [
    "# ecnoder\n",
    "encoder_out, encoder_mask = model.encoder(feat, feat_len)\n",
    "print(encoder_out.shape)\n",
    "print(encoder_mask.shape)\n",
    "print(encoder_out[0])"
   ]
  },
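  {
   "cell_type": "code",
   "execution_count": null,
   "id": "subsample-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sketch: the time axis shrinks 207 -> 51, consistent with a 4x conv\n",
    "# front end (two stride-2, kernel-3 convs), assuming Conv2dSubsampling4\n",
    "# is the embed module configured here\n",
    "t = 207\n",
    "print(((t - 1) // 2 - 1) // 2)  # 51, matching encoder_out.shape[1]"
   ]
  },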
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "shaped-alaska",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "deepspeech  examples  README_cn.md\tsetup.sh     tools\r\n",
      "docs\t    LICENSE   README.md\t\ttests\t     utils\r\n",
      "env.sh\t    log       requirements.txt\tthird_party\r\n"
     ]
    }
   ],
   "source": [
    "!ls\n",
    "data = np.load('.notebook/encoder.npz', allow_pickle=True)\n",
    "torch_mask = data['mask']\n",
    "torch_encoder_out = data['out']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "federal-rover",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "None\n"
     ]
    }
   ],
   "source": [
    "print(np.testing.assert_equal(torch_mask, encoder_mask.numpy()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "regulated-interstate",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n",
      "[[-0.7019424   0.56254166  0.6880345  ...  1.1237322   0.78039217\n",
      "   1.1369387 ]\n",
      " [-0.778778    0.39126638  0.7188779  ...  1.2518823   0.8861681\n",
      "   1.3173454 ]\n",
      " [-0.9590891   0.6346026   0.87671363 ...  0.9818373   0.74401116\n",
      "   1.2903274 ]\n",
      " ...\n",
      " [-1.0732253   0.6723689   0.9230311  ...  0.9075457   0.8176713\n",
      "   1.3239657 ]\n",
      " [-1.165412    0.6819976   0.69394535 ...  1.2238353   0.80282927\n",
      "   1.4506509 ]\n",
      " [-1.2732087   0.71458083  0.7581961  ...  0.9415482   0.877484\n",
      "   1.2623053 ]]\n",
      "----\n",
      "[[-0.7019418   0.56254166  0.6880346  ...  1.1237322   0.78039235\n",
      "   1.1369387 ]\n",
      " [-0.7787781   0.39126658  0.71887815 ...  1.2518822   0.8861679\n",
      "   1.3173453 ]\n",
      " [-0.95908946  0.6346025   0.87671334 ...  0.9818373   0.7440108\n",
      "   1.2903266 ]\n",
      " ...\n",
      " [-1.073225    0.67236906  0.9230311  ...  0.9075456   0.81767166\n",
      "   1.3239657 ]\n",
      " [-1.1654116   0.68199694  0.69394493 ...  1.2238349   0.8028289\n",
      "   1.4506508 ]\n",
      " [-1.2732095   0.7145803   0.7581956  ...  0.9415491   0.87748396\n",
      "   1.2623051 ]]\n",
      "True\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "print(np.allclose(torch_encoder_out, encoder_out.numpy()))\n",
    "print(torch_encoder_out[0])\n",
    "print(\"----\")\n",
    "print(encoder_out.numpy()[0])\n",
    "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-5, rtol=1e-6))\n",
    "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-6, rtol=1e-6))"
   ]
  },
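  {
   "cell_type": "markdown",
   "id": "encoder-gap-note",
   "metadata": {},
   "source": [
    "The encoder outputs agree at `atol=1e-5` but not at `atol=1e-6`, which is the usual float32 noise from differing kernel and reduction order across frameworks. A small sketch to quantify the largest gap:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "encoder-gap-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest absolute and relative element-wise gaps between the two outputs.\n",
    "diff = np.abs(encoder_out.numpy() - torch_encoder_out)\n",
    "print(diff.max())\n",
    "print((diff / (np.abs(torch_encoder_out) + 1e-8)).max())"
   ]
  },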
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "proof-scheduling",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
      "       [377.33258057])\n",
      "[1.]\n",
      "[[ 3.16902876e+00 -1.51763987e-02  4.91095744e-02 ... -2.47971853e-03\n",
      "  -5.93360700e-03 -7.26609165e-03]\n",
      " [-1.74184477e+00  7.75874173e-03 -4.49434854e-02 ...  9.92412097e-04\n",
      "   2.46337592e-03  2.31892057e-03]\n",
      " [-2.33343339e+00  1.30475955e-02 -2.66557075e-02 ...  2.27532350e-03\n",
      "   5.76924905e-03  7.48788286e-03]\n",
      " ...\n",
      " [-4.30358458e+00  2.46054661e-02 -9.00950655e-02 ...  4.43156436e-03\n",
      "   1.16122244e-02  1.44715561e-02]\n",
      " [-3.36921120e+00  1.73153952e-02 -6.36872873e-02 ...  3.28363618e-03\n",
      "   8.58010259e-03  1.07794888e-02]\n",
      " [-6.62045336e+00  3.49955931e-02 -1.23962618e-01 ...  6.36671018e-03\n",
      "   1.60814095e-02  2.03891303e-02]]\n",
      "[-4.3777819e+00  2.3245810e-02 -9.3339294e-02 ...  4.2569344e-03\n",
      "  1.0919910e-02  1.3787797e-02]\n"
     ]
    }
   ],
   "source": [
    "from paddle.nn import functional as F\n",
    "def ctc_loss(logits,\n",
    "             labels,\n",
    "             input_lengths,\n",
    "             label_lengths,\n",
    "             blank=0,\n",
    "             reduction='mean',\n",
    "             norm_by_times=False):\n",
    "    loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,\n",
    "                                           input_lengths, label_lengths)\n",
    "    loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])\n",
    "    assert reduction in ['mean', 'sum', 'none']\n",
    "    if reduction == 'mean':\n",
    "        loss_out = paddle.mean(loss_out / label_lengths)\n",
    "    elif reduction == 'sum':\n",
    "        loss_out = paddle.sum(loss_out)\n",
    "    return loss_out\n",
    "\n",
    "F.ctc_loss = ctc_loss\n",
    "\n",
    "torch_mask_t = paddle.to_tensor(torch_mask, dtype='int64')\n",
    "encoder_out_lens = torch_mask_t.squeeze(1).sum(1)\n",
    "loss_ctc = model.ctc(paddle.to_tensor(torch_encoder_out), encoder_out_lens, text, text_len)\n",
    "print(loss_ctc)\n",
    "loss_ctc.backward()\n",
    "print(loss_ctc.grad)\n",
    "print(model.ctc.ctc_lo.weight.grad)\n",
    "print(model.ctc.ctc_lo.bias.grad)\n",
    "\n",
    "\n",
    "# tensor(377.3326, device='cuda:0', grad_fn=<DivBackward0>)\n",
    "# None\n",
    "# [[ 3.16902351e+00 -1.51765049e-02  4.91097234e-02 ... -2.47973716e-03\n",
    "#   -5.93366381e-03 -7.26613170e-03]\n",
    "#  [-1.74185038e+00  7.75875803e-03 -4.49435972e-02 ...  9.92415240e-04\n",
    "#    2.46338220e-03  2.31891591e-03]\n",
    "#  [-2.33343077e+00  1.30476682e-02 -2.66557615e-02 ...  2.27533933e-03\n",
    "#    5.76929189e-03  7.48792710e-03]\n",
    "#  ...\n",
    "#  [-4.30356789e+00  2.46056803e-02 -9.00955945e-02 ...  4.43160534e-03\n",
    "#    1.16123557e-02  1.44716976e-02]\n",
    "#  [-3.36919212e+00  1.73155665e-02 -6.36875406e-02 ...  3.28367390e-03\n",
    "#    8.58021621e-03  1.07796099e-02]\n",
    "#  [-6.62039661e+00  3.49958315e-02 -1.23963736e-01 ...  6.36674836e-03\n",
    "#    1.60815325e-02  2.03892551e-02]]\n",
    "# [-4.3777566e+00  2.3245990e-02 -9.3339972e-02 ...  4.2569702e-03\n",
    "#   1.0920014e-02  1.3787906e-02]"
   ]
  },
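  {
   "cell_type": "markdown",
   "id": "ctc-reduction-note",
   "metadata": {},
   "source": [
    "Note that the patched `ctc_loss` above divides each utterance's loss by its label length before taking the batch mean, matching PyTorch's `reduction='mean'` semantics for CTC. A toy sketch of that reduction (dummy numbers, not real CTC losses):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ctc-reduction-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration of the 'mean' reduction: per-utterance loss is\n",
    "# normalized by its label length first, then averaged over the batch.\n",
    "toy_loss = np.array([10.0, 20.0])  # dummy per-utterance CTC losses\n",
    "toy_lens = np.array([2.0, 4.0])    # dummy label lengths\n",
    "print((toy_loss / toy_lens).mean())  # 5.0, not (10 + 20) / 2"
   ]
  },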
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "synthetic-hungarian",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
      "       [41.84146118]) 0.0\n"
     ]
    }
   ],
   "source": [
    "loss_att, acc_att = model._calc_att_loss(paddle.to_tensor(torch_encoder_out), paddle.to_tensor(torch_mask),\n",
    "                                                    text, text_len)\n",
    "print(loss_att, acc_att)\n",
    "#tensor(41.8416, device='cuda:0', grad_fn=<DivBackward0>) 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "id": "marine-cuisine",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-3.7638968e-01 -8.2272053e-01  7.4276292e-01 ...  3.4200522e-01\n",
      "   1.5034772e-02  4.0337229e-01]\n",
      " [-8.7386459e-01 -3.1389427e-01  4.1987866e-01 ...  3.7723729e-01\n",
      "  -1.4352810e-01 -1.0023664e+00]\n",
      " [-4.3505096e-01  3.4504786e-02 -2.8710306e-01 ...  7.7274129e-02\n",
      "  -1.1672243e+00 -2.6848501e-01]\n",
      " ...\n",
      " [ 4.2471480e-01  5.8885634e-01  2.0203922e-02 ...  3.7405500e-01\n",
      "   4.5470044e-02 -3.7139410e-01]\n",
      " [-3.7978446e-01 -8.1084180e-01  7.5725085e-01 ...  2.6038891e-01\n",
      "  -7.9347193e-04  4.2537671e-01]\n",
      " [-3.8279903e-01 -8.1206715e-01  7.4943429e-01 ...  2.6173013e-01\n",
      "  -1.0499060e-03  4.2678756e-01]]\n"
     ]
    }
   ],
   "source": [
    "data = np.load(\".notebook/decoder.npz\", allow_pickle=True)\n",
    "torch_decoder_out = data['decoder_out']\n",
    "print(torch_decoder_out[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 180,
   "id": "several-result",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,\n",
    "                ignore_id: int):\n",
    "    \"\"\"Add <sos> and <eos> labels.\n",
    "    Args:\n",
    "        ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)\n",
    "        sos (int): index of <sos>\n",
    "        eos (int): index of <eeos>\n",
    "        ignore_id (int): index of padding\n",
    "    Returns:\n",
    "        ys_in (paddle.Tensor) : (B, Lmax + 1)\n",
    "        ys_out (paddle.Tensor) : (B, Lmax + 1)\n",
    "    Examples:\n",
    "        >>> sos_id = 10\n",
    "        >>> eos_id = 11\n",
    "        >>> ignore_id = -1\n",
    "        >>> ys_pad\n",
    "        tensor([[ 1,  2,  3,  4,  5],\n",
    "                [ 4,  5,  6, -1, -1],\n",
    "                [ 7,  8,  9, -1, -1]], dtype=paddle.int32)\n",
    "        >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)\n",
    "        >>> ys_in\n",
    "        tensor([[10,  1,  2,  3,  4,  5],\n",
    "                [10,  4,  5,  6, 11, 11],\n",
    "                [10,  7,  8,  9, 11, 11]])\n",
    "        >>> ys_out\n",
    "        tensor([[ 1,  2,  3,  4,  5, 11],\n",
    "                [ 4,  5,  6, 11, -1, -1],\n",
    "                [ 7,  8,  9, 11, -1, -1]])\n",
    "    \"\"\"\n",
    "    # TODO(Hui Zhang): using comment code, \n",
    "    #_sos = paddle.to_tensor(\n",
    "    #    [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)\n",
    "    #_eos = paddle.to_tensor(\n",
    "    #    [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)\n",
    "    #ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys\n",
    "    #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]\n",
    "    #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]\n",
    "    #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)\n",
    "    B = ys_pad.size(0)\n",
    "    _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos\n",
    "    _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos\n",
    "    ys_in = paddle.cat([_sos, ys_pad], dim=1)\n",
    "    mask_pad = (ys_in == ignore_id)\n",
    "    ys_in = ys_in.masked_fill(mask_pad, eos)\n",
    "    \n",
    "\n",
    "    ys_out = paddle.cat([ys_pad, _eos], dim=1)\n",
    "    ys_out = ys_out.masked_fill(mask_pad, eos)\n",
    "    mask_eos = (ys_out == ignore_id)\n",
    "    ys_out = ys_out.masked_fill(mask_eos, eos)\n",
    "    ys_out = ys_out.masked_fill(mask_pad, ignore_id)\n",
    "    return ys_in, ys_out"
   ]
  },
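  {
   "cell_type": "markdown",
   "id": "sos-eos-note",
   "metadata": {},
   "source": [
    "A minimal check of `add_sos_eos` against its own docstring example (this relies on the `paddle.cat` / `masked_fill` helpers registered at the top of the notebook):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sos-eos-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproduce the docstring example: sos=10, eos=11, ignore_id=-1.\n",
    "toy_ys = paddle.to_tensor([[1, 2, 3, 4, 5],\n",
    "                           [4, 5, 6, -1, -1],\n",
    "                           [7, 8, 9, -1, -1]], dtype='int64')\n",
    "toy_in, toy_out = add_sos_eos(toy_ys, 10, 11, -1)\n",
    "print(toy_in)   # expect sos prepended, padding replaced by eos\n",
    "print(toy_out)  # expect eos appended, padding kept as ignore_id"
   ]
  },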
  {
   "cell_type": "code",
   "execution_count": 181,
   "id": "possible-bulgaria",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[16, 7], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[4232, 2995, 3116, 1209, 565 , 4232, 4232],\n",
      "        [4232, 236 , 1176, 331 , 66  , 3925, 4077],\n",
      "        [4232, 2693, 524 , 234 , 1145, 366 , 4232],\n",
      "        [4232, 3875, 4211, 3062, 700 , 4232, 4232],\n",
      "        [4232, 272 , 987 , 1134, 494 , 2959, 4232],\n",
      "        [4232, 1936, 3715, 120 , 2553, 2695, 2710],\n",
      "        [4232, 25  , 1149, 3930, 4232, 4232, 4232],\n",
      "        [4232, 1753, 1778, 1237, 482 , 3925, 110 ],\n",
      "        [4232, 3703, 2   , 565 , 3827, 4232, 4232],\n",
      "        [4232, 1150, 2734, 10  , 2478, 3490, 4232],\n",
      "        [4232, 426 , 811 , 95  , 489 , 144 , 4232],\n",
      "        [4232, 2313, 2006, 489 , 975 , 4232, 4232],\n",
      "        [4232, 3702, 3414, 205 , 1488, 2966, 1347],\n",
      "        [4232, 70  , 1741, 702 , 1666, 4232, 4232],\n",
      "        [4232, 703 , 1778, 1030, 849 , 4232, 4232],\n",
      "        [4232, 814 , 1674, 115 , 3827, 4232, 4232]])\n",
      "Tensor(shape=[16, 7], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[2995, 3116, 1209,  565, 4232, -1  , -1  ],\n",
      "        [ 236, 1176,  331,  66 , 3925, 4077, 4232],\n",
      "        [2693,  524,  234, 1145,  366, 4232, -1  ],\n",
      "        [3875, 4211, 3062,  700, 4232, -1  , -1  ],\n",
      "        [ 272,  987, 1134,  494, 2959, 4232, -1  ],\n",
      "        [1936, 3715,  120, 2553, 2695, 2710, 4232],\n",
      "        [ 25 , 1149, 3930, 4232, -1  , -1  , -1  ],\n",
      "        [1753, 1778, 1237,  482, 3925,  110, 4232],\n",
      "        [3703,  2  ,  565, 3827, 4232, -1  , -1  ],\n",
      "        [1150, 2734,  10 , 2478, 3490, 4232, -1  ],\n",
      "        [ 426,  811,  95 ,  489,  144, 4232, -1  ],\n",
      "        [2313, 2006,  489,  975, 4232, -1  , -1  ],\n",
      "        [3702, 3414,  205, 1488, 2966, 1347, 4232],\n",
      "        [ 70 , 1741,  702, 1666, 4232, -1  , -1  ],\n",
      "        [ 703, 1778, 1030,  849, 4232, -1  , -1  ],\n",
      "        [ 814, 1674,  115, 3827, 4232, -1  , -1  ]])\n"
     ]
    }
   ],
   "source": [
    "ys_pad = text\n",
    "ys_pad_lens = text_len\n",
    "ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,\n",
    "                                model.ignore_id)\n",
    "ys_in_lens = ys_pad_lens + 1\n",
    "print(ys_in_pad)\n",
    "print(ys_out_pad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 285,
   "id": "north-walter",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n",
      "True\n",
      "False\n",
      "[[-3.76389682e-01 -8.22720408e-01  7.42762923e-01 ...  3.42005253e-01\n",
      "   1.50350705e-02  4.03372347e-01]\n",
      " [-8.73864174e-01 -3.13894272e-01  4.19878662e-01 ...  3.77237231e-01\n",
      "  -1.43528014e-01 -1.00236630e+00]\n",
      " [-4.35050905e-01  3.45046446e-02 -2.87102997e-01 ...  7.72742853e-02\n",
      "  -1.16722476e+00 -2.68485069e-01]\n",
      " ...\n",
      " [ 4.24714804e-01  5.88856399e-01  2.02039629e-02 ...  3.74054879e-01\n",
      "   4.54700664e-02 -3.71394157e-01]\n",
      " [-3.79784584e-01 -8.10841978e-01  7.57250786e-01 ...  2.60389000e-01\n",
      "  -7.93404877e-04  4.25376773e-01]\n",
      " [-3.82798851e-01 -8.12067091e-01  7.49434292e-01 ...  2.61730075e-01\n",
      "  -1.04988366e-03  4.26787734e-01]]\n",
      "---\n",
      "[[-3.7638968e-01 -8.2272053e-01  7.4276292e-01 ...  3.4200522e-01\n",
      "   1.5034772e-02  4.0337229e-01]\n",
      " [-8.7386459e-01 -3.1389427e-01  4.1987866e-01 ...  3.7723729e-01\n",
      "  -1.4352810e-01 -1.0023664e+00]\n",
      " [-4.3505096e-01  3.4504786e-02 -2.8710306e-01 ...  7.7274129e-02\n",
      "  -1.1672243e+00 -2.6848501e-01]\n",
      " ...\n",
      " [ 4.2471480e-01  5.8885634e-01  2.0203922e-02 ...  3.7405500e-01\n",
      "   4.5470044e-02 -3.7139410e-01]\n",
      " [-3.7978446e-01 -8.1084180e-01  7.5725085e-01 ...  2.6038891e-01\n",
      "  -7.9347193e-04  4.2537671e-01]\n",
      " [-3.8279903e-01 -8.1206715e-01  7.4943429e-01 ...  2.6173013e-01\n",
      "  -1.0499060e-03  4.2678756e-01]]\n"
     ]
    }
   ],
   "source": [
    "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n",
    "                                      ys_in_lens)\n",
    "\n",
    "print(np.allclose(decoder_out.numpy(), torch_decoder_out))\n",
    "print(np.allclose(decoder_out.numpy(), torch_decoder_out, atol=1e-6))\n",
    "print(np.allclose(decoder_out.numpy(), torch_decoder_out, atol=1e-7))\n",
    "print(decoder_out.numpy()[0])\n",
    "print('---')\n",
    "print(torch_decoder_out[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 183,
   "id": "assisted-fortune",
   "metadata": {},
   "outputs": [],
   "source": [
    "from paddle import nn\n",
    "import paddle\n",
    "from paddle.nn import functional as F\n",
    "\n",
    "class LabelSmoothingLoss(nn.Layer):\n",
    "\n",
    "    def __init__(self,\n",
    "                 size: int,\n",
    "                 padding_idx: int,\n",
    "                 smoothing: float,\n",
    "                 normalize_length: bool=False):\n",
    "        super().__init__()\n",
    "        self.size = size\n",
    "        self.padding_idx = padding_idx\n",
    "        self.smoothing = smoothing\n",
    "        self.confidence = 1.0 - smoothing\n",
    "        self.normalize_length = normalize_length\n",
    "        self.criterion = nn.KLDivLoss(reduction=\"none\")\n",
    "\n",
    "    def forward(self, x: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor:\n",
    "        \"\"\"Compute loss between x and target.\n",
    "        The model outputs and data labels tensors are flatten to\n",
    "        (batch*seqlen, class) shape and a mask is applied to the\n",
    "        padding part which should not be calculated for loss.\n",
    "        \n",
    "        Args:\n",
    "            x (paddle.Tensor): prediction (batch, seqlen, class)\n",
    "            target (paddle.Tensor):\n",
    "                target signal masked with self.padding_id (batch, seqlen)\n",
    "        Returns:\n",
    "            loss (paddle.Tensor) : The KL loss, scalar float value\n",
    "        \"\"\"\n",
    "        B, T, D = paddle.shape(x)\n",
    "        assert D == self.size\n",
    "        x = x.reshape((-1, self.size))\n",
    "        target = target.reshape([-1])\n",
    "\n",
    "        # use zeros_like instead of torch.no_grad() for true_dist,\n",
    "        # since no_grad() can not be exported by JIT\n",
    "        true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))\n",
    "        ignore = target == self.padding_idx  # (B,)\n",
    "        print(self.smoothing / (self.size - 1))\n",
    "        print(true_dist)\n",
    "\n",
    "        #target = target * (1 - ignore)  # avoid -1 index\n",
    "        target = target.masked_fill(ignore, 0)  # avoid -1 index\n",
    "        \n",
    "        \n",
    "        #true_dist += F.one_hot(target, self.size) * self.confidence\n",
    "        target_mask = F.one_hot(target, self.size)\n",
    "        true_dist *= (1 - target_mask)\n",
    "        true_dist += target_mask * self.confidence\n",
    "        \n",
    "\n",
    "        kl = self.criterion(F.log_softmax(x, axis=1), true_dist)\n",
    "        \n",
    "        #TODO(Hui Zhang): sum not support bool type\n",
    "        #total = len(target) - int(ignore.sum())\n",
    "        total = len(target) - int(ignore.type_as(target).sum())\n",
    "        denom = total if self.normalize_length else B\n",
    "\n",
    "        #numer = (kl * (1 - ignore)).sum()\n",
    "        numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()\n",
    "        return numer / denom\n"
   ]
  },
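  {
   "cell_type": "markdown",
   "id": "smoothing-note",
   "metadata": {},
   "source": [
    "The smoothed target distribution built above puts `1 - smoothing` on the target index and `smoothing / (size - 1)` everywhere else, so each row still sums to 1. A toy numpy sketch (5 classes, not the 4233-class vocab):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "smoothing-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy smoothed one-hot row: confidence on the target, eps elsewhere.\n",
    "size, smoothing, tgt = 5, 0.1, 2\n",
    "row = np.full(size, smoothing / (size - 1))\n",
    "row[tgt] = 1.0 - smoothing\n",
    "print(row, row.sum())  # sums to 1.0"
   ]
  },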
  {
   "cell_type": "code",
   "execution_count": 184,
   "id": "weighted-delight",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.3629489603024576e-05\n",
      "Tensor(shape=[112, 4233], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n",
      "        [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n",
      "        [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n",
      "        ...,\n",
      "        [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n",
      "        [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n",
      "        [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363]])\n",
      "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [41.84146118])\n",
      "VarType.INT64\n"
     ]
    }
   ],
   "source": [
    "criteron = LabelSmoothingLoss(4233, -1, 0.1, False)\n",
    "loss_att = criteron(paddle.to_tensor(torch_decoder_out), ys_out_pad.astype('int64'))\n",
    "print(loss_att)\n",
    "print(ys_out_pad.dtype)\n",
    "# tensor(41.8416, device='cuda:0', grad_fn=<DivBackward0>)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 286,
   "id": "dress-shelter",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [41.84146118])\n",
      "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
      "       [41.84146118])\n",
      "4233\n",
      "-1\n",
      "0.1\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n",
    "                                      ys_in_lens)\n",
    "\n",
    "loss_att = model.criterion_att(paddle.to_tensor(torch_decoder_out), ys_out_pad)\n",
    "print(loss_att)\n",
    "\n",
    "loss_att = model.criterion_att(decoder_out, ys_out_pad)\n",
    "print(loss_att)\n",
    "\n",
    "print(model.criterion_att.size)\n",
    "print(model.criterion_att.padding_idx)\n",
    "print(model.criterion_att.smoothing)\n",
    "print(model.criterion_att.normalize_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "speaking-shelf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "from typing import Optional\n",
    "from typing import Tuple\n",
    "\n",
    "import paddle\n",
    "from paddle import nn\n",
    "from typeguard import check_argument_types\n",
    "\n",
    "from deepspeech.modules.activation import get_activation\n",
    "from deepspeech.modules.attention import MultiHeadedAttention\n",
    "from deepspeech.modules.attention import RelPositionMultiHeadedAttention\n",
    "from deepspeech.modules.conformer_convolution import ConvolutionModule\n",
    "from deepspeech.modules.embedding import PositionalEncoding\n",
    "from deepspeech.modules.embedding import RelPositionalEncoding\n",
    "from deepspeech.modules.encoder_layer import ConformerEncoderLayer\n",
    "from deepspeech.modules.encoder_layer import TransformerEncoderLayer\n",
    "from deepspeech.modules.mask import add_optional_chunk_mask\n",
    "from deepspeech.modules.mask import make_non_pad_mask\n",
    "from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward\n",
    "from deepspeech.modules.subsampling import Conv2dSubsampling4\n",
    "from deepspeech.modules.subsampling import Conv2dSubsampling6\n",
    "from deepspeech.modules.subsampling import Conv2dSubsampling8\n",
    "from deepspeech.modules.subsampling import LinearNoSubsampling\n",
    "\n",
    "class BaseEncoder(nn.Layer):\n",
    "    def __init__(\n",
    "            self,\n",
    "            input_size: int,\n",
    "            output_size: int=256,\n",
    "            attention_heads: int=4,\n",
    "            linear_units: int=2048,\n",
    "            num_blocks: int=6,\n",
    "            dropout_rate: float=0.1,\n",
    "            positional_dropout_rate: float=0.1,\n",
    "            attention_dropout_rate: float=0.0,\n",
    "            input_layer: str=\"conv2d\",\n",
    "            pos_enc_layer_type: str=\"abs_pos\",\n",
    "            normalize_before: bool=True,\n",
    "            concat_after: bool=False,\n",
    "            static_chunk_size: int=0,\n",
    "            use_dynamic_chunk: bool=False,\n",
    "            global_cmvn: paddle.nn.Layer=None,\n",
    "            use_dynamic_left_chunk: bool=False, ):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            input_size (int): input dim, d_feature\n",
    "            output_size (int): dimension of attention, d_model\n",
    "            attention_heads (int): the number of heads of multi head attention\n",
    "            linear_units (int): the hidden units number of position-wise feed\n",
    "                forward\n",
    "            num_blocks (int): the number of encoder blocks\n",
    "            dropout_rate (float): dropout rate\n",
    "            attention_dropout_rate (float): dropout rate in attention\n",
    "            positional_dropout_rate (float): dropout rate after adding\n",
    "                positional encoding\n",
    "            input_layer (str): input layer type.\n",
    "                optional [linear, conv2d, conv2d6, conv2d8]\n",
    "            pos_enc_layer_type (str): Encoder positional encoding layer type.\n",
    "                opitonal [abs_pos, scaled_abs_pos, rel_pos]\n",
    "            normalize_before (bool):\n",
    "                True: use layer_norm before each sub-block of a layer.\n",
    "                False: use layer_norm after each sub-block of a layer.\n",
    "            concat_after (bool): whether to concat attention layer's input\n",
    "                and output.\n",
    "                True: x -> x + linear(concat(x, att(x)))\n",
    "                False: x -> x + att(x)\n",
    "            static_chunk_size (int): chunk size for static chunk training and\n",
    "                decoding\n",
    "            use_dynamic_chunk (bool): whether use dynamic chunk size for\n",
    "                training or not, You can only use fixed chunk(chunk_size > 0)\n",
    "                or dyanmic chunk size(use_dynamic_chunk = True)\n",
    "            global_cmvn (Optional[paddle.nn.Layer]): Optional GlobalCMVN layer\n",
    "            use_dynamic_left_chunk (bool): whether use dynamic left chunk in\n",
    "                dynamic chunk training\n",
    "        \"\"\"\n",
    "        assert check_argument_types()\n",
    "        super().__init__()\n",
    "        self._output_size = output_size\n",
    "\n",
    "        if pos_enc_layer_type == \"abs_pos\":\n",
    "            pos_enc_class = PositionalEncoding\n",
    "        elif pos_enc_layer_type == \"rel_pos\":\n",
    "            pos_enc_class = RelPositionalEncoding\n",
    "        else:\n",
    "            raise ValueError(\"unknown pos_enc_layer: \" + pos_enc_layer_type)\n",
    "\n",
    "        if input_layer == \"linear\":\n",
    "            subsampling_class = LinearNoSubsampling\n",
    "        elif input_layer == \"conv2d\":\n",
    "            subsampling_class = Conv2dSubsampling4\n",
    "        elif input_layer == \"conv2d6\":\n",
    "            subsampling_class = Conv2dSubsampling6\n",
    "        elif input_layer == \"conv2d8\":\n",
    "            subsampling_class = Conv2dSubsampling8\n",
    "        else:\n",
    "            raise ValueError(\"unknown input_layer: \" + input_layer)\n",
    "\n",
    "        self.global_cmvn = global_cmvn\n",
    "        self.embed = subsampling_class(\n",
    "            idim=input_size,\n",
    "            odim=output_size,\n",
    "            dropout_rate=dropout_rate,\n",
    "            pos_enc_class=pos_enc_class(\n",
    "                d_model=output_size, dropout_rate=positional_dropout_rate), )\n",
    "\n",
    "        self.normalize_before = normalize_before\n",
    "        self.after_norm = nn.LayerNorm(output_size, epsilon=1e-12)\n",
    "        self.static_chunk_size = static_chunk_size\n",
    "        self.use_dynamic_chunk = use_dynamic_chunk\n",
    "        self.use_dynamic_left_chunk = use_dynamic_left_chunk\n",
    "\n",
    "    def output_size(self) -> int:\n",
    "        return self._output_size\n",
    "\n",
    "    def forward(\n",
    "            self,\n",
    "            xs: paddle.Tensor,\n",
    "            xs_lens: paddle.Tensor,\n",
    "            decoding_chunk_size: int=0,\n",
    "            num_decoding_left_chunks: int=-1,\n",
    "    ) -> Tuple[paddle.Tensor, paddle.Tensor]:\n",
    "        \"\"\"Embed positions in tensor.\n",
    "        Args:\n",
    "            xs: padded input tensor (B, L, D)\n",
    "            xs_lens: input length (B)\n",
    "            decoding_chunk_size: decoding chunk size for dynamic chunk\n",
    "                0: default for training, use random dynamic chunk.\n",
    "                <0: for decoding, use full chunk.\n",
    "                >0: for decoding, use fixed chunk size as set.\n",
    "            num_decoding_left_chunks: number of left chunks, this is for decoding,\n",
    "                the chunk size is decoding_chunk_size.\n",
    "                >=0: use num_decoding_left_chunks\n",
    "                <0: use all left chunks\n",
    "        Returns:\n",
    "            encoder output tensor, lens and mask\n",
    "        \"\"\"\n",
    "        masks = make_non_pad_mask(xs_lens).unsqueeze(1)  # (B, 1, L)\n",
    "\n",
    "        if self.global_cmvn is not None:\n",
    "            xs = self.global_cmvn(xs)\n",
    "        #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor\n",
    "        xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)\n",
    "        #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor\n",
    "        masks = masks.astype(paddle.bool)\n",
    "        #TODO(Hui Zhang): mask_pad = ~masks\n",
    "        mask_pad = masks.logical_not()\n",
    "        chunk_masks = add_optional_chunk_mask(\n",
    "            xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,\n",
    "            decoding_chunk_size, self.static_chunk_size,\n",
    "            num_decoding_left_chunks)\n",
    "        for layer in self.encoders:\n",
    "            xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n",
    "        if self.normalize_before:\n",
    "            xs = self.after_norm(xs)\n",
    "        # Here we assume the mask is not changed in encoder layers, so just\n",
    "        # return the masks before encoder layers, and the masks will be used\n",
    "        # for cross attention with decoder later\n",
    "        return xs, masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "sharp-municipality",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class ConformerEncoder(BaseEncoder):\n",
    "    \"\"\"Conformer encoder module.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            input_size: int,\n",
    "            output_size: int=256,\n",
    "            attention_heads: int=4,\n",
    "            linear_units: int=2048,\n",
    "            num_blocks: int=6,\n",
    "            dropout_rate: float=0.1,\n",
    "            positional_dropout_rate: float=0.1,\n",
    "            attention_dropout_rate: float=0.0,\n",
    "            input_layer: str=\"conv2d\",\n",
    "            pos_enc_layer_type: str=\"rel_pos\",\n",
    "            normalize_before: bool=True,\n",
    "            concat_after: bool=False,\n",
    "            static_chunk_size: int=0,\n",
    "            use_dynamic_chunk: bool=False,\n",
    "            global_cmvn: nn.Layer=None,\n",
    "            use_dynamic_left_chunk: bool=False,\n",
    "            positionwise_conv_kernel_size: int=1,\n",
    "            macaron_style: bool=True,\n",
    "            selfattention_layer_type: str=\"rel_selfattn\",\n",
    "            activation_type: str=\"swish\",\n",
    "            use_cnn_module: bool=True,\n",
    "            cnn_module_kernel: int=15,\n",
    "            causal: bool=False,\n",
    "            cnn_module_norm: str=\"batch_norm\", ):\n",
    "        \"\"\"Construct ConformerEncoder\n",
    "        Args:\n",
    "            input_size to use_dynamic_chunk, see in BaseEncoder\n",
    "            positionwise_conv_kernel_size (int): Kernel size of positionwise\n",
    "                conv1d layer.\n",
    "            macaron_style (bool): Whether to use macaron style for\n",
    "                positionwise layer.\n",
    "            selfattention_layer_type (str): Encoder attention layer type,\n",
    "                the parameter has no effect now, it's just for configure\n",
    "                compatibility.\n",
    "            activation_type (str): Encoder activation function type.\n",
    "            use_cnn_module (bool): Whether to use convolution module.\n",
    "            cnn_module_kernel (int): Kernel size of convolution module.\n",
    "            causal (bool): whether to use causal convolution or not.\n",
    "            cnn_module_norm (str): cnn conv norm type, Optional['batch_norm','layer_norm']\n",
    "        \"\"\"\n",
    "        assert check_argument_types()\n",
    "        super().__init__(input_size, output_size, attention_heads, linear_units,\n",
    "                         num_blocks, dropout_rate, positional_dropout_rate,\n",
    "                         attention_dropout_rate, input_layer,\n",
    "                         pos_enc_layer_type, normalize_before, concat_after,\n",
    "                         static_chunk_size, use_dynamic_chunk, global_cmvn,\n",
    "                         use_dynamic_left_chunk)\n",
    "        activation = get_activation(activation_type)\n",
    "\n",
    "        # self-attention module definition\n",
    "        encoder_selfattn_layer = RelPositionMultiHeadedAttention\n",
    "        encoder_selfattn_layer_args = (attention_heads, output_size,\n",
    "                                       attention_dropout_rate)\n",
    "        # feed-forward module definition\n",
    "        positionwise_layer = PositionwiseFeedForward\n",
    "        positionwise_layer_args = (output_size, linear_units, dropout_rate,\n",
    "                                   activation)\n",
    "        # convolution module definition\n",
    "        convolution_layer = ConvolutionModule\n",
    "        convolution_layer_args = (output_size, cnn_module_kernel, activation,\n",
    "                                  cnn_module_norm, causal)\n",
    "\n",
    "        self.encoders = nn.ModuleList([\n",
    "            ConformerEncoderLayer(\n",
    "                size=output_size,\n",
    "                self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),\n",
    "                feed_forward=positionwise_layer(*positionwise_layer_args),\n",
    "                feed_forward_macaron=positionwise_layer(\n",
    "                    *positionwise_layer_args) if macaron_style else None,\n",
    "                conv_module=convolution_layer(*convolution_layer_args)\n",
    "                if use_cnn_module else None,\n",
    "                dropout_rate=dropout_rate,\n",
    "                normalize_before=normalize_before,\n",
    "                concat_after=concat_after) for _ in range(num_blocks)\n",
    "        ])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "tutorial-syndication",
   "metadata": {},
   "outputs": [],
   "source": [
    "from deepspeech.frontend.utility import load_cmvn\n",
    "from deepspeech.modules.cmvn import GlobalCMVN\n",
    "\n",
    "configs=cfg.model\n",
    "mean, istd = load_cmvn(configs['cmvn_file'],\n",
    "                               configs['cmvn_file_type'])\n",
    "global_cmvn = GlobalCMVN(\n",
    "    paddle.to_tensor(mean, dtype=paddle.float),\n",
    "    paddle.to_tensor(istd, dtype=paddle.float))\n",
    "\n",
    "\n",
    "input_dim = configs['input_dim']\n",
    "vocab_size = configs['output_dim']\n",
    "encoder_type = configs.get('encoder', 'transformer')\n",
    "        \n",
    "encoder = ConformerEncoder(\n",
    "                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "fuzzy-register",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "o = global_cmvn(feat)\n",
    "o2 = model.encoder.global_cmvn(feat)\n",
    "print(np.allclose(o.numpy(), o2.numpy()))"
   ]
  },
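  {
   "cell_type": "markdown",
   "id": "cmvn-note",
   "metadata": {},
   "source": [
    "For reference, a hedged toy sketch of what per-feature CMVN computes, assuming the usual `(x - mean) * istd` normalization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cmvn-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy per-feature CMVN: subtract mean, scale by inverse std (assumed form).\n",
    "toy_x = np.array([[1.0, 2.0], [3.0, 4.0]])\n",
    "toy_mean = np.array([2.0, 3.0])\n",
    "toy_istd = np.array([0.5, 0.25])\n",
    "print((toy_x - toy_mean) * toy_istd)"
   ]
  },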
  {
   "cell_type": "code",
   "execution_count": 203,
   "id": "balanced-locator",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[16, 1, 207], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[[True , True , True , ..., True , True , True ]],\n",
      "\n",
      "        [[True , True , True , ..., True , True , True ]],\n",
      "\n",
      "        [[True , True , True , ..., True , False, False]],\n",
      "\n",
      "        ...,\n",
      "\n",
      "        [[True , True , True , ..., False, False, False]],\n",
      "\n",
      "        [[True , True , True , ..., False, False, False]],\n",
      "\n",
      "        [[True , True , True , ..., False, False, False]]])\n"
     ]
    }
   ],
   "source": [
    "from deepspeech.modules.mask import make_non_pad_mask\n",
    "from deepspeech.modules.mask import make_pad_mask\n",
    "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n",
    "print(masks)"
   ]
  },
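  {
   "cell_type": "markdown",
   "id": "mask-note",
   "metadata": {},
   "source": [
    "`make_non_pad_mask` is just a broadcast comparison of time indices against lengths (True for valid frames, False for padding, as in the output above); a numpy sketch of the same idea:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mask-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NumPy equivalent of make_non_pad_mask: True where t < length.\n",
    "toy_lens = np.array([5, 3])\n",
    "toy_mask = np.arange(toy_lens.max())[None, :] < toy_lens[:, None]\n",
    "print(toy_mask)"
   ]
  },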
  {
   "cell_type": "code",
   "execution_count": 204,
   "id": "induced-proposition",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[16, 207, 80], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[[-0.53697914, -0.19910523, -0.34997201, ..., -0.82427669, -1.02650309, -0.96300691],\n",
      "         [-0.04464225,  0.23176001, -0.32538742, ..., -0.90158713, -1.03248465, -0.75986791],\n",
      "         [ 0.50035292,  0.22691160, -0.73052198, ..., -1.00552964, -0.87123060, -1.03062117],\n",
      "         ...,\n",
      "         [-0.40023831, -0.14325078, -0.57947433, ..., -1.07178426, -1.28059900, -1.05180073],\n",
      "         [ 0.15755332, -0.00184949, -0.28702953, ..., -1.10898709, -0.94518697, -0.72506356],\n",
      "         [-0.47520429, -1.39415145, -0.25754252, ..., -1.13649082, -1.19430351, -1.22903371]],\n",
      "\n",
      "        [[ 0.95454037,  0.36427975, -1.38908529, ..., -1.16366839, -1.28453600, -1.20151031],\n",
      "         [-0.08573537, -1.05785275, -0.89172721, ..., -0.96440506, -1.12547100, -1.25990939],\n",
      "         [ 0.47653601,  0.32886592, -0.59200549, ..., -1.19421589, -1.14302588, -1.02422845],\n",
      "         ...,\n",
      "         [-0.47431335, -0.33558893, -0.72325647, ..., -1.45058632, -1.39574063, -1.04641151],\n",
      "         [ 0.36112556,  0.10380996, -1.15994537, ..., -1.04394984, -1.02212358, -1.02083635],\n",
      "         [-1.27172923, -2.14601755, -0.75676596, ..., -0.97822225, -0.93785471, -1.03707945]],\n",
      "\n",
      "        [[-1.54652190, -1.01517177, -0.88900733, ..., -0.48522446, -0.75163364, -0.67765164],\n",
      "         [-0.76100892, -0.73351598, -0.91587651, ..., -0.24835993, -0.58927339, -0.73722762],\n",
      "         [-0.02471367,  0.17015894, -0.42326337, ..., -0.33203802, -0.76695800, -0.71651691],\n",
      "         ...,\n",
      "         [-1.70319796, -1.25910866, -1.14492917, ..., -1.18101490, -1.11631835, -0.93108195],\n",
      "         [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n",
      "         [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n",
      "\n",
      "        ...,\n",
      "\n",
      "        [[ 0.64982772,  0.26116797, -0.84196597, ..., -0.87213463, -1.10728693, -1.32531130],\n",
      "         [ 0.35391113, -0.01584581, -0.40424931, ..., -0.99173468, -1.07270539, -1.19239008],\n",
      "         [ 0.37704495, -0.06278508, -0.11467686, ..., -1.10212946, -1.09524000, -1.11815071],\n",
      "         ...,\n",
      "         [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n",
      "         [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n",
      "         [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n",
      "\n",
      "        [[ 0.04445776, -0.17546852, -0.67475224, ..., -0.49801198, -0.56782746, -0.77852231],\n",
      "         [-1.34279025, -0.80342549, -0.90457231, ..., -0.65901577, -0.72549772, -0.62796098],\n",
      "         [-0.76252806, -0.13071291, -0.13280024, ..., -0.56132573, -0.60587686, -0.72114766],\n",
      "         ...,\n",
      "         [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n",
      "         [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n",
      "         [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n",
      "\n",
      "        [[-1.07980299, -1.08341801, -1.17969072, ..., -0.17757270, -0.43746525, -0.04000654],\n",
      "         [ 0.92353648,  0.63770926, -0.52810186, ..., -0.12927933, -0.20342292,  0.16655664],\n",
      "         [ 0.49337494, -0.00911332, -0.73301607, ...,  0.10074048, -0.09811471, -0.00923573],\n",
      "         ...,\n",
      "         [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n",
      "         [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n",
      "         [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]]])\n"
     ]
    }
   ],
   "source": [
    "xs = model.encoder.global_cmvn(feat)\n",
    "print(xs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 205,
   "id": "cutting-julian",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[16, 256, 51, 19], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
      "       [[[[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.00209083],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.01194306, 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.04610471, 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.00967231, 0.04613467, 0.        ]],\n",
      "\n",
      "         [[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[0.22816099, 0.24614786, 0.25304127, ..., 0.20401822, 0.23248228, 0.31190544],\n",
      "          [0.13587360, 0.28877240, 0.27991283, ..., 0.19210319, 0.20346391, 0.19934426],\n",
      "          [0.25739068, 0.39348233, 0.27877361, ..., 0.27482539, 0.19302306, 0.23810163],\n",
      "          ...,\n",
      "          [0.11939213, 0.28473237, 0.33082074, ..., 0.23838061, 0.22104350, 0.23905794],\n",
      "          [0.17387670, 0.20402060, 0.40263173, ..., 0.24782266, 0.26742202, 0.15426503],\n",
      "          [0.        , 0.29080707, 0.27725950, ..., 0.17539823, 0.18478745, 0.22483408]],\n",
      "\n",
      "         [[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.35446781, 0.38861471, 0.39724261, ..., 0.38680089, 0.33568040, 0.34552398],\n",
      "          [0.41739127, 0.51038563, 0.41729912, ..., 0.33992639, 0.37081629, 0.35109508],\n",
      "          [0.36116859, 0.40744874, 0.48490953, ..., 0.34848654, 0.32321057, 0.35188958],\n",
      "          ...,\n",
      "          [0.23143977, 0.38021481, 0.51526314, ..., 0.36499465, 0.37411752, 0.39986172],\n",
      "          [0.34678638, 0.40238205, 0.50076538, ..., 0.36184520, 0.31596646, 0.36334658],\n",
      "          [0.36498138, 0.37943166, 0.51718897, ..., 0.31798238, 0.33656698, 0.34130475]]],\n",
      "\n",
      "\n",
      "        [[[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.01456045, 0.09447514, 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.01500242, 0.02963220, 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.03295187, 0.        , 0.        , ..., 0.04584959, 0.02043908, 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.04425837],\n",
      "          [0.        , 0.        , 0.02556529, ..., 0.        , 0.00900441, 0.04908358]],\n",
      "\n",
      "         [[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.11141267, 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[0.33696529, 0.38526866, 0.32900479, ..., 0.28703830, 0.23351061, 0.19004467],\n",
      "          [0.13575366, 0.35783342, 0.33573425, ..., 0.22081660, 0.15854910, 0.13587447],\n",
      "          [0.21928655, 0.28900093, 0.28255141, ..., 0.20602837, 0.23927397, 0.21909429],\n",
      "          ...,\n",
      "          [0.23291890, 0.39096734, 0.36399242, ..., 0.20598020, 0.25373828, 0.23137446],\n",
      "          [0.18739152, 0.30793777, 0.30296701, ..., 0.27250600, 0.25191751, 0.20836820],\n",
      "          [0.22454213, 0.41402060, 0.54082996, ..., 0.31874508, 0.25079906, 0.25938687]],\n",
      "\n",
      "         [[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.26456982, 0.49519050, 0.56702250, ..., 0.30954638, 0.35292268, 0.32668519],\n",
      "          [0.21576807, 0.51833367, 0.49183372, ..., 0.36043224, 0.38523889, 0.36154741],\n",
      "          [0.20067888, 0.42784205, 0.52817714, ..., 0.31871423, 0.32452232, 0.31036487],\n",
      "          ...,\n",
      "          [0.49855131, 0.51001430, 0.52278662, ..., 0.36450142, 0.34338164, 0.33602941],\n",
      "          [0.41233343, 0.55517823, 0.52827710, ..., 0.40675971, 0.33873138, 0.36724189],\n",
      "          [0.40820011, 0.46187383, 0.47338152, ..., 0.38690975, 0.36039269, 0.38022059]]],\n",
      "\n",
      "\n",
      "        [[[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.        , 0.00578516, 0.        , ..., 0.00748384, 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.03035110, 0.        , 0.00026720],\n",
      "          [0.00094807, 0.        , 0.        , ..., 0.00795512, 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.02032628, 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.01080076, 0.        ],\n",
      "          [0.18470290, 0.        , 0.        , ..., 0.05058352, 0.09475817, 0.05914564]],\n",
      "\n",
      "         [[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[0.38708323, 0.28021947, 0.35892880, ..., 0.16595127, 0.16031364, 0.21136315],\n",
      "          [0.15595171, 0.30544323, 0.24666184, ..., 0.22675267, 0.25765014, 0.19682154],\n",
      "          [0.29517862, 0.41209796, 0.20063159, ..., 0.17595036, 0.22536841, 0.22214051],\n",
      "          ...,\n",
      "          [0.24744980, 0.26258564, 0.38654143, ..., 0.23620218, 0.23157144, 0.18514194],\n",
      "          [0.25714791, 0.29592845, 0.47744542, ..., 0.23545510, 0.25072727, 0.20976165],\n",
      "          [1.20154655, 0.84644288, 0.73385584, ..., 1.02517247, 0.95309550, 1.00134516]],\n",
      "\n",
      "         [[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.45013186, 0.47484034, 0.40540054, ..., 0.19346163, 0.17825794, 0.14776605],\n",
      "          [0.47545874, 0.48186573, 0.36760187, ..., 0.27809089, 0.32997063, 0.32337096],\n",
      "          [0.46160024, 0.40050328, 0.39060861, ..., 0.36612910, 0.35242686, 0.29738861],\n",
      "          ...,\n",
      "          [0.55148494, 0.51017821, 0.40132499, ..., 0.38948193, 0.35737294, 0.33088297],\n",
      "          [0.41972569, 0.45475486, 0.45320493, ..., 0.38343129, 0.40125814, 0.36180776],\n",
      "          [0.34279808, 0.31606171, 0.44701228, ..., 0.21665487, 0.23984617, 0.23903391]]],\n",
      "\n",
      "\n",
      "        ...,\n",
      "\n",
      "\n",
      "        [[[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.04178291, 0.        , 0.01580476, ..., 0.        , 0.02250817, 0.        ],\n",
      "          [0.04323414, 0.07786420, 0.        , ..., 0.01634724, 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.03209178, 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.13563479, 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[0.        , 0.25187218, 0.24979387, ..., 0.24774717, 0.22354351, 0.19149347],\n",
      "          [0.16540922, 0.19585510, 0.19812922, ..., 0.27344131, 0.20928150, 0.26150429],\n",
      "          [0.10494646, 0.06329897, 0.33843631, ..., 0.25138417, 0.12470355, 0.23926635],\n",
      "          ...,\n",
      "          [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n",
      "          [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n",
      "          [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n",
      "\n",
      "         [[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.11428106, 0.45667490, 0.46820879, ..., 0.32057840, 0.33578536, 0.39012644],\n",
      "          [0.10441341, 0.45739070, 0.46107352, ..., 0.38467997, 0.38291249, 0.36685589],\n",
      "          [0.19867736, 0.35519636, 0.44313061, ..., 0.40679252, 0.38067645, 0.30645671],\n",
      "          ...,\n",
      "          [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n",
      "          [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n",
      "          [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]],\n",
      "\n",
      "\n",
      "        [[[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.02465414, 0.        , 0.        , ..., 0.        , 0.        , 0.03390232],\n",
      "          [0.        , 0.        , 0.01830704, ..., 0.05166877, 0.00948385, 0.07453502],\n",
      "          [0.09921519, 0.        , 0.01587192, ..., 0.01620276, 0.05140074, 0.00192392],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[0.40034360, 0.25306445, 0.20217699, ..., 0.09816189, 0.07064310, 0.04974059],\n",
      "          [0.12567598, 0.21030979, 0.11181555, ..., 0.04278110, 0.11968569, 0.12005232],\n",
      "          [0.28786880, 0.24030517, 0.22565845, ..., 0.        , 0.06418110, 0.05872961],\n",
      "          ...,\n",
      "          [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n",
      "          [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n",
      "          [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n",
      "\n",
      "         [[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.38404641, 0.30990323, 0.37156230, ..., 0.18125033, 0.15050662, 0.19619957],\n",
      "          [0.47285745, 0.40528792, 0.39718056, ..., 0.24709940, 0.04565683, 0.11500744],\n",
      "          [0.32620737, 0.30072594, 0.30477354, ..., 0.23529193, 0.21356541, 0.16985542],\n",
      "          ...,\n",
      "          [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n",
      "          [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n",
      "          [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]],\n",
      "\n",
      "\n",
      "        [[[0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.03343770, 0.00123780, 0.05297198, ..., 0.07271163, 0.08656286, 0.14493589],\n",
      "          [0.11043239, 0.06143146, 0.06362963, ..., 0.08127750, 0.06259022, 0.08315435],\n",
      "          [0.01767678, 0.00201111, 0.07875030, ..., 0.06963293, 0.08979890, 0.05326346],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.10033827, 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.15627117, 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.05144687, 0.        , 0.        , ..., 0.        , 0.        , 0.00436414],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         ...,\n",
      "\n",
      "         [[0.25142455, 0.45964020, 0.37346074, ..., 0.04763087, 0.        , 0.        ],\n",
      "          [0.19760093, 0.26626948, 0.11190540, ..., 0.03044968, 0.        , 0.        ],\n",
      "          [0.16340607, 0.32938001, 0.25689697, ..., 0.05569421, 0.        , 0.        ],\n",
      "          ...,\n",
      "          [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n",
      "          [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n",
      "          [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n",
      "\n",
      "         [[0.        , 0.        , 0.        , ..., 0.        , 0.02218930, 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.02848953],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          ...,\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ],\n",
      "          [0.        , 0.        , 0.        , ..., 0.        , 0.        , 0.        ]],\n",
      "\n",
      "         [[0.25810039, 0.63016868, 0.37037861, ..., 0.18704373, 0.08269356, 0.09912672],\n",
      "          [0.17292863, 0.50678611, 0.40738991, ..., 0.16006103, 0.11725381, 0.09940521],\n",
      "          [0.24175072, 0.41616210, 0.41256818, ..., 0.13519743, 0.07912572, 0.12846369],\n",
      "          ...,\n",
      "          [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n",
      "          [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n",
      "          [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]]])\n"
     ]
    }
   ],
   "source": [
    "xs = model.encoder.global_cmvn(feat)\n",
    "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n",
    "\n",
    "\n",
    "#xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n",
    "# print(xs)\n",
    "\n",
    "x = xs.unsqueeze(1)\n",
    "x = model.encoder.embed.conv(x)\n",
    "print(x)"
   ]
  },
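  {
   "cell_type": "markdown",
   "id": "subsample-shape-note",
   "metadata": {},
   "source": [
    "`embed.conv` maps the `(B, 1, T, F)` features to `(B, C, T', F')`; the `[16, 51, 256]` shape printed just below implies `T' = 51`. A minimal sketch of the length arithmetic, assuming the usual two-layer stride-2 `Conv2dSubsampling4` (kernel 3, stride 2, no padding); the input length 207 is a hypothetical value chosen to reproduce `T' = 51`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "subsample-shape-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch: conv output-length formula for kernel=3, stride=2, pad=0.\n",
    "# t_in is a hypothetical frame count; the real one comes from `feat`.\n",
    "def conv_out_len(t, kernel=3, stride=2):\n",
    "    return (t - kernel) // stride + 1\n",
    "\n",
    "t_in = 207\n",
    "t1 = conv_out_len(t_in)   # (207 - 3) // 2 + 1 = 103\n",
    "t2 = conv_out_len(t1)     # (103 - 3) // 2 + 1 = 51\n",
    "print(t1, t2)"
   ]
  },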
  {
   "cell_type": "code",
   "execution_count": 206,
   "id": "friendly-nightlife",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[16, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
      "       [[[-0.03426375,  0.14291267, -0.06718873, ...,  0.09064753,  0.01809387, -0.04340880],\n",
      "         [-0.05007839,  0.11054724, -0.10399298, ...,  0.11457238,  0.04244684, -0.01249714],\n",
      "         [-0.10695291,  0.16910909, -0.08352133, ...,  0.07710276,  0.01168563, -0.03584499],\n",
      "         ...,\n",
      "         [-0.06060536,  0.14455931, -0.05470302, ...,  0.05364908,  0.03033342, -0.02610814],\n",
      "         [-0.08505894,  0.13611752, -0.11132983, ...,  0.13079923,  0.01580139, -0.02281028],\n",
      "         [-0.10604677,  0.14714901, -0.10885533, ...,  0.08543444,  0.03719445, -0.04634233]],\n",
      "\n",
      "        [[-0.12392755,  0.14486063, -0.05674079, ...,  0.02573164,  0.03128851,  0.00545091],\n",
      "         [-0.04775286,  0.08473608, -0.08507854, ...,  0.04573154,  0.04240163,  0.01053247],\n",
      "         [-0.05940291,  0.10023535, -0.08143730, ...,  0.03596500,  0.01673085,  0.02089563],\n",
      "         ...,\n",
      "         [-0.09222981,  0.15823206, -0.07700447, ...,  0.08122957,  0.03136991, -0.00646474],\n",
      "         [-0.07331756,  0.14482647, -0.07838815, ...,  0.10869440,  0.01356864, -0.02777974],\n",
      "         [-0.07937264,  0.20143102, -0.05544947, ...,  0.10287814,  0.00608235, -0.04799180]],\n",
      "\n",
      "        [[-0.03670349,  0.08931590, -0.08718812, ...,  0.01314050,  0.00642052,  0.00573716],\n",
      "         [ 0.01089254,  0.11146393, -0.10263617, ...,  0.05070438,  0.01960694,  0.03521532],\n",
      "         [-0.02182280,  0.11443964, -0.06678198, ...,  0.04327708,  0.00861394,  0.02871092],\n",
      "         ...,\n",
      "         [-0.06792898,  0.14376275, -0.07899005, ...,  0.11248926,  0.03208683, -0.03264240],\n",
      "         [-0.07884051,  0.17024788, -0.08583611, ...,  0.09028331,  0.03588808, -0.02075090],\n",
      "         [-0.13792302,  0.27163863, -0.23930418, ...,  0.13391261,  0.07521040, -0.08621951]],\n",
      "\n",
      "        ...,\n",
      "\n",
      "        [[-0.02446348,  0.11595841, -0.03591986, ...,  0.06288970,  0.02895011, -0.06532725],\n",
      "         [-0.05378424,  0.12607370, -0.09023033, ...,  0.09078894,  0.01035743,  0.03701983],\n",
      "         [-0.04566649,  0.14275314, -0.06686870, ...,  0.09890588, -0.00612222,  0.03439377],\n",
      "         ...,\n",
      "         [-0.31763062,  0.53700209, -0.26335421, ...,  0.39182857,  0.00337184, -0.18293698],\n",
      "         [-0.31763062,  0.53700209, -0.26335421, ...,  0.39182857,  0.00337184, -0.18293698],\n",
      "         [-0.31763062,  0.53700209, -0.26335421, ...,  0.39182857,  0.00337184, -0.18293698]],\n",
      "\n",
      "        [[-0.01012144,  0.03909408, -0.07077143, ...,  0.00452683, -0.01377654,  0.02897627],\n",
      "         [-0.00519154,  0.03594019, -0.06831125, ...,  0.05693541, -0.00406374,  0.04561640],\n",
      "         [-0.01762631,  0.00500899, -0.05886075, ...,  0.02112178, -0.00729015,  0.02782153],\n",
      "         ...,\n",
      "         [-0.31763062,  0.53700209, -0.26335421, ...,  0.39182857,  0.00337184, -0.18293698],\n",
      "         [-0.31763062,  0.53700209, -0.26335421, ...,  0.39182857,  0.00337184, -0.18293698],\n",
      "         [-0.31763062,  0.53700209, -0.26335421, ...,  0.39182857,  0.00337184, -0.18293698]],\n",
      "\n",
      "        [[-0.03411558, -0.04318277, -0.08497842, ..., -0.04886402,  0.04296734,  0.06151697],\n",
      "         [ 0.00263296, -0.06913657, -0.08993219, ..., -0.00149064,  0.05696633,  0.03304394],\n",
      "         [-0.01818341, -0.01178640, -0.09679577, ..., -0.00870231,  0.00362198,  0.01916483],\n",
      "         ...,\n",
      "         [-0.31763062,  0.53700209, -0.26335421, ...,  0.39182857,  0.00337184, -0.18293698],\n",
      "         [-0.31763062,  0.53700209, -0.26335421, ...,  0.39182857,  0.00337184, -0.18293698],\n",
      "         [-0.31763062,  0.53700209, -0.26335421, ...,  0.39182857,  0.00337184, -0.18293698]]])\n",
      "Tensor(shape=[16, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
      "       [[[-0.54821998,  2.28660274, -1.07501972, ...,  1.45036042,  0.28950194, -0.69454080],\n",
      "         [-0.80125421,  1.76875579, -1.66388774, ...,  1.83315802,  0.67914939, -0.19995420],\n",
      "         [-1.71124649,  2.70574546, -1.33634126, ...,  1.23364413,  0.18697014, -0.57351983],\n",
      "         ...,\n",
      "         [-0.96968573,  2.31294894, -0.87524825, ...,  0.85838526,  0.48533469, -0.41773027],\n",
      "         [-1.36094308,  2.17788029, -1.78127730, ...,  2.09278774,  0.25282228, -0.36496443],\n",
      "         [-1.69674826,  2.35438418, -1.74168527, ...,  1.36695099,  0.59511113, -0.74147725]],\n",
      "\n",
      "        [[-1.98284078,  2.31777000, -0.90785271, ...,  0.41170627,  0.50061619,  0.08721463],\n",
      "         [-0.76404583,  1.35577726, -1.36125672, ...,  0.73170459,  0.67842603,  0.16851945],\n",
      "         [-0.95044655,  1.60376561, -1.30299675, ...,  0.57544005,  0.26769355,  0.33433008],\n",
      "         ...,\n",
      "         [-1.47567701,  2.53171301, -1.23207152, ...,  1.29967308,  0.50191855, -0.10343577],\n",
      "         [-1.17308092,  2.31722355, -1.25421047, ...,  1.73911047,  0.21709818, -0.44447583],\n",
      "         [-1.26996231,  3.22289634, -0.88719147, ...,  1.64605021,  0.09731755, -0.76786882]],\n",
      "\n",
      "        [[-0.58725590,  1.42905438, -1.39500988, ...,  0.21024795,  0.10272825,  0.09179455],\n",
      "         [ 0.17428070,  1.78342295, -1.64217877, ...,  0.81127012,  0.31371105,  0.56344515],\n",
      "         [-0.34916472,  1.83103430, -1.06851172, ...,  0.69243336,  0.13782299,  0.45937473],\n",
      "         ...,\n",
      "         [-1.08686376,  2.30020404, -1.26384079, ...,  1.79982817,  0.51338923, -0.52227837],\n",
      "         [-1.26144814,  2.72396612, -1.37337780, ...,  1.44453299,  0.57420933, -0.33201432],\n",
      "         [-2.20676827,  4.34621811, -3.82886696, ...,  2.14260173,  1.20336640, -1.37951219]],\n",
      "\n",
      "        ...,\n",
      "\n",
      "        [[-0.39141566,  1.85533464, -0.57471782, ...,  1.00623512,  0.46320182, -1.04523599],\n",
      "         [-0.86054784,  2.01717925, -1.44368529, ...,  1.45262301,  0.16571884,  0.59231722],\n",
      "         [-0.73066384,  2.28405023, -1.06989920, ...,  1.58249414, -0.09795550,  0.55030036],\n",
      "         ...,\n",
      "         [-5.08208990,  8.59203339, -4.21366739, ...,  6.26925707,  0.05394945, -2.92699170],\n",
      "         [-5.08208990,  8.59203339, -4.21366739, ...,  6.26925707,  0.05394945, -2.92699170],\n",
      "         [-5.08208990,  8.59203339, -4.21366739, ...,  6.26925707,  0.05394945, -2.92699170]],\n",
      "\n",
      "        [[-0.16194311,  0.62550521, -1.13234293, ...,  0.07242929, -0.22042468,  0.46362036],\n",
      "         [-0.08306468,  0.57504302, -1.09298003, ...,  0.91096652, -0.06501988,  0.72986233],\n",
      "         [-0.28202093,  0.08014385, -0.94177192, ...,  0.33794850, -0.11664233,  0.44514441],\n",
      "         ...,\n",
      "         [-5.08208990,  8.59203339, -4.21366739, ...,  6.26925707,  0.05394945, -2.92699170],\n",
      "         [-5.08208990,  8.59203339, -4.21366739, ...,  6.26925707,  0.05394945, -2.92699170],\n",
      "         [-5.08208990,  8.59203339, -4.21366739, ...,  6.26925707,  0.05394945, -2.92699170]],\n",
      "\n",
      "        [[-0.54584920, -0.69092435, -1.35965478, ..., -0.78182435,  0.68747747,  0.98427159],\n",
      "         [ 0.04212743, -1.10618520, -1.43891501, ..., -0.02385022,  0.91146135,  0.52870303],\n",
      "         [-0.29093450, -0.18858244, -1.54873240, ..., -0.13923697,  0.05795169,  0.30663735],\n",
      "         ...,\n",
      "         [-5.08208990,  8.59203339, -4.21366739, ...,  6.26925707,  0.05394945, -2.92699170],\n",
      "         [-5.08208990,  8.59203339, -4.21366739, ...,  6.26925707,  0.05394945, -2.92699170],\n",
      "         [-5.08208990,  8.59203339, -4.21366739, ...,  6.26925707,  0.05394945, -2.92699170]]])\n",
      "Tensor(shape=[1, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[[ 0.        ,  1.        ,  0.        , ...,  1.        ,  0.        ,  1.        ],\n",
      "         [ 0.84147102,  0.54030228,  0.80196184, ...,  1.        ,  0.00010746,  1.        ],\n",
      "         [ 0.90929747, -0.41614681,  0.95814437, ...,  1.        ,  0.00021492,  1.        ],\n",
      "         ...,\n",
      "         [-0.76825470, -0.64014435,  0.63279730, ...,  0.99998462,  0.00515809,  0.99998671],\n",
      "         [-0.95375264,  0.30059254,  0.99899054, ...,  0.99998397,  0.00526555,  0.99998611],\n",
      "         [-0.26237485,  0.96496606,  0.56074661, ...,  0.99998331,  0.00537301,  0.99998558]]])\n"
     ]
    }
   ],
   "source": [
    "b, c, t, f = paddle.shape(x)\n",
    "x = model.encoder.embed.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))\n",
    "print(x)\n",
    "x, pos_emb = model.encoder.embed.pos_enc(x, 0)\n",
    "print(x)\n",
    "print(pos_emb)"
   ]
  },
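  {
   "cell_type": "markdown",
   "id": "posenc-scale-note",
   "metadata": {},
   "source": [
    "Two things happen here: `embed.out` flattens the `(b, c, t, f)` conv output to `(b, t, c*f)` and projects it to `d_model = 256`, and `pos_enc` scales the result by `sqrt(d_model)` (the second tensor above is exactly `sqrt(256) = 16` times the first) while returning a slice of a sinusoidal table. A NumPy sketch of that step, assuming the standard scaled sinusoidal `PositionalEncoding`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "posenc-scale-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged NumPy sketch of the pos_enc step: scale by sqrt(d_model) and\n",
    "# return the matching slice of a sinusoidal table.\n",
    "import math\n",
    "import numpy as np\n",
    "\n",
    "d_model, T = 256, 51\n",
    "x_np = np.random.randn(1, T, d_model).astype('float32')  # stand-in for embed.out(...)\n",
    "\n",
    "pos = np.arange(T, dtype='float32')[:, None]\n",
    "div = np.exp(np.arange(0, d_model, 2, dtype='float32') * -(math.log(10000.0) / d_model))\n",
    "pe = np.zeros((T, d_model), dtype='float32')\n",
    "pe[:, 0::2] = np.sin(pos * div)\n",
    "pe[:, 1::2] = np.cos(pos * div)\n",
    "\n",
    "x_scaled = x_np * math.sqrt(d_model)  # sqrt(256) = 16, the 16x jump above\n",
    "pos_emb_np = pe[None]                 # (1, T, d_model), like pos_emb above\n",
    "print(x_scaled.shape, pos_emb_np.shape)"
   ]
  },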
  {
   "cell_type": "code",
   "execution_count": 207,
   "id": "guilty-cache",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[1, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[[ 0.        ,  1.        ,  0.        , ...,  1.        ,  0.        ,  1.        ],\n",
      "         [ 0.84147102,  0.54030228,  0.80196184, ...,  1.        ,  0.00010746,  1.        ],\n",
      "         [ 0.90929747, -0.41614681,  0.95814437, ...,  1.        ,  0.00021492,  1.        ],\n",
      "         ...,\n",
      "         [-0.76825470, -0.64014435,  0.63279730, ...,  0.99998462,  0.00515809,  0.99998671],\n",
      "         [-0.95375264,  0.30059254,  0.99899054, ...,  0.99998397,  0.00526555,  0.99998611],\n",
      "         [-0.26237485,  0.96496606,  0.56074661, ...,  0.99998331,  0.00537301,  0.99998558]]])\n"
     ]
    }
   ],
   "source": [
    "print(pos_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 208,
   "id": "iraqi-payday",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[ 0.0000000e+00  1.0000000e+00  0.0000000e+00 ...  1.0000000e+00\n",
      "    0.0000000e+00  1.0000000e+00]\n",
      "  [ 8.4147096e-01  5.4030234e-01  8.0196178e-01 ...  1.0000000e+00\n",
      "    1.0746076e-04  1.0000000e+00]\n",
      "  [ 9.0929741e-01 -4.1614684e-01  9.5814437e-01 ...  1.0000000e+00\n",
      "    2.1492151e-04  1.0000000e+00]\n",
      "  ...\n",
      "  [ 9.5625257e-01 -2.9254240e-01  4.8925215e-01 ...  8.3807874e-01\n",
      "    5.1154459e-01  8.5925674e-01]\n",
      "  [ 2.7049953e-01 -9.6272010e-01  9.9170387e-01 ...  8.3801574e-01\n",
      "    5.1163691e-01  8.5920173e-01]\n",
      "  [-6.6394955e-01 -7.4777740e-01  6.9544029e-01 ...  8.3795273e-01\n",
      "    5.1172924e-01  8.5914677e-01]]]\n",
      "[1, 5000, 256]\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import math\n",
    "import numpy as np\n",
    "\n",
    "max_len=5000\n",
    "d_model=256\n",
    "\n",
    "pe = torch.zeros(max_len, d_model)\n",
    "position = torch.arange(0, max_len,\n",
    "                        dtype=torch.float32).unsqueeze(1)\n",
    "toruch_position = position\n",
    "div_term = torch.exp(\n",
    "    torch.arange(0, d_model, 2, dtype=torch.float32) *\n",
    "    -(math.log(10000.0) / d_model))\n",
    "tourch_div_term = div_term.cpu().detach().numpy()\n",
    "\n",
    "torhc_sin = torch.sin(position * div_term)\n",
    "torhc_cos = torch.cos(position * div_term)\n",
    "\n",
    "np_sin = np.sin((position * div_term).cpu().detach().numpy())\n",
    "np_cos = np.cos((position * div_term).cpu().detach().numpy())\n",
    "pe[:, 0::2] = torhc_sin\n",
    "pe[:, 1::2] = torhc_cos\n",
    "pe = pe.unsqueeze(0) \n",
    "tourch_pe = pe.cpu().detach().numpy()\n",
    "print(tourch_pe)\n",
    "bak_pe = model.encoder.embed.pos_enc.pe\n",
    "print(bak_pe.shape)\n",
    "model.encoder.embed.pos_enc.pe = paddle.to_tensor(tourch_pe)"
   ]
  },
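  {
   "cell_type": "markdown",
   "id": "pe-swap-note",
   "metadata": {},
   "source": [
    "A hedged follow-up (not part of the original run): before trusting the swap above, quantify how far apart the Paddle table and the torch-derived table actually are, and keep the backup around so the swap can be undone."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pe-swap-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged check: the gap should be at most a few float32 ulps.\n",
    "diff = np.abs(bak_pe.numpy() - torch_pe)\n",
    "print(diff.max())\n",
    "\n",
    "# to undo the swap later:\n",
    "# model.encoder.embed.pos_enc.pe = bak_pe"
   ]
  },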
  {
   "cell_type": "code",
   "execution_count": 210,
   "id": "exempt-cloud",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "xs = model.encoder.global_cmvn(feat)\n",
    "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n",
    "\n",
    "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n",
    "#print(xs)\n",
    "data = np.load(\".notebook/enc_embed.npz\")\n",
    "torch_pos_emb=data['pos_emb']\n",
    "torch_xs = data['embed_out']\n",
    "print(np.allclose(xs.numpy(), torch_xs))\n",
    "print(np.allclose(pos_emb.numpy(), torch_pos_emb))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 269,
   "id": "handed-harris",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "False\n",
      "True\n",
      "[256, 2048]\n",
      "[2048]\n",
      "[2048, 256]\n",
      "[256]\n",
      "--------ff-------\n",
      "True\n",
      "False\n",
      "False\n",
      "False\n",
      "False\n",
      "True\n",
      "linear_714.w_0 True\n",
      "linear_714.b_0 True\n",
      "linear_715.w_0 True\n",
      "linear_715.b_0 True\n",
      "False\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "xs = model.encoder.global_cmvn(feat)\n",
    "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n",
    "\n",
    "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n",
    "masks = masks.astype(paddle.bool)\n",
    "mask_pad = masks.logical_not()\n",
    "decoding_chunk_size=0\n",
    "num_decoding_left_chunks=-1\n",
    "chunk_masks = add_optional_chunk_mask(\n",
    "            xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n",
    "            decoding_chunk_size, model.encoder.static_chunk_size,\n",
    "            num_decoding_left_chunks)\n",
    "\n",
    "#print(chunk_masks)\n",
    "data = np.load(\".notebook/enc_embed.npz\")\n",
    "torch_pos_emb=data['pos_emb']\n",
    "torch_xs = data['embed_out']\n",
    "torch_chunk_masks = data['chunk_masks']\n",
    "torch_mask_pad = data['mask_pad']\n",
    "print(np.allclose(xs.numpy(), torch_xs))\n",
    "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n",
    "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n",
    "np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n",
    "\n",
    "for layer in model.encoder.encoders:\n",
    "    #xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n",
    "    print(layer.feed_forward_macaron is not None)\n",
    "    print(layer.normalize_before)\n",
    "    \n",
    "    data = np.load('.notebook/enc_0_norm_ff.npz')\n",
    "    t_norm_ff = data['norm_ff']\n",
    "    t_xs = data['xs']\n",
    "   \n",
    "    \n",
    "    x = xs\n",
    "    print(np.allclose(t_xs, x.numpy()))\n",
    "    residual = x\n",
    "    print(np.allclose(t_xs, residual.numpy()))\n",
    "    x_nrom = layer.norm_ff_macaron(x)\n",
    "    print(np.allclose(t.numpy(), x_nrom.numpy()))\n",
    "    print(np.allclose(t_norm_ff, x_nrom.numpy()))\n",
    "#     for n, p in layer.norm_ff_macaron.state_dict().items():\n",
    "#         print(n, p)\n",
    "#         pass\n",
    "\n",
    "    layer.eval()\n",
    "    x_nrom = paddle.to_tensor(t_norm_ff)\n",
    "    print(np.allclose(t_norm_ff, x_nrom.numpy()))\n",
    "    x = residual + layer.ff_scale * layer.feed_forward_macaron(x_nrom)\n",
    "    \n",
    "    ps=[]\n",
    "    for n, p in layer.feed_forward_macaron.state_dict().items():\n",
    "         #print(n, p)\n",
    "         ps.append(p)\n",
    "         print(p.shape)\n",
    "         pass\n",
    "\n",
    "    x_nrom = paddle.to_tensor(t_norm_ff)\n",
    "    ff_l_x = layer.feed_forward_macaron.w_1(x_nrom)\n",
    "    ff_l_a_x = layer.feed_forward_macaron.activation(ff_l_x)\n",
    "    ff_l_a_l_x = layer.feed_forward_macaron.w_2(ff_l_a_x)\n",
    "    data = np.load('.notebook/enc_0_ff_out.npz', allow_pickle=True)\n",
    "    t_norm_ff = data['norm_ff']\n",
    "    t_ff_out = data['ff_out']\n",
    "    t_ff_l_x = data['ff_l_x']\n",
    "    t_ff_l_a_x = data['ff_l_a_x']\n",
    "    t_ff_l_a_l_x = data['ff_l_a_l_x']\n",
    "    t_ps = data['ps']\n",
    "    \n",
    "    print(\"--------ff-------\")\n",
    "    print(np.allclose(x_nrom.numpy(), t_norm_ff))\n",
    "    print(np.allclose(x.numpy(), t_ff_out))\n",
    "    print(np.allclose(ff_l_x.numpy(), t_ff_l_x))\n",
    "    print(np.allclose(ff_l_a_x.numpy(), t_ff_l_a_x))\n",
    "    print(np.allclose(ff_l_a_l_x.numpy(), t_ff_l_a_l_x))\n",
    "    \n",
    "    print(np.allclose(ff_l_x.numpy(), t_ff_l_x, atol=1e-6))\n",
    "    for p, t_p in zip(ps, t_ps):\n",
    "        print(p.name, np.allclose(p.numpy(), t_p.T))\n",
    "    \n",
    "    \n",
    "#     residual = x\n",
    "#     x = layer.norm_mha(x)\n",
    "#     x_q = x\n",
    "    \n",
    "    data = np.load('.notebook/enc_0_selattn_out.npz', allow_pickle=True)\n",
    "    tx_q = data['x_q']\n",
    "    tx = data['x']\n",
    "    tpos_emb=data['pos_emb']\n",
    "    tmask=data['mask']\n",
    "    tt_x_att=data['x_att']\n",
    "    x_q = paddle.to_tensor(tx_q)\n",
    "    x = paddle.to_tensor(tx)\n",
    "    pos_emb = paddle.to_tensor(tpos_emb)\n",
    "    mask = paddle.to_tensor(tmask)\n",
    "    \n",
    "    x_att = layer.self_attn(x_q, x, x, pos_emb, mask)\n",
    "    print(np.allclose(x_att.numpy(), t_x_att))\n",
    "    print(np.allclose(x_att.numpy(), t_x_att, atol=1e-6))\n",
    "    \n",
    "    break"
   ]
  },
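  {
   "cell_type": "markdown",
   "id": "macaron-ffn-note",
   "metadata": {},
   "source": [
    "The cell above probes the macaron half-step feed-forward, `x + ff_scale * w_2(act(w_1(norm(x))))`, with the `[256, 2048]`/`[2048, 256]` weights it printed. Note the parameter check compares against `t_p.T`: `paddle.nn.Linear` stores weights as `(in, out)` while `torch.nn.Linear` stores `(out, in)`. A self-contained NumPy sketch with random stand-in weights, assuming `ff_scale = 0.5` and a Swish activation as in the Conformer paper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "macaron-ffn-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged NumPy sketch of the macaron half-step FFN probed above; the\n",
    "# weights are random stand-ins with the shapes the cell printed.\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "T, d_model, d_ff = 51, 256, 2048\n",
    "x = rng.standard_normal((T, d_model)).astype('float32')\n",
    "w1 = (0.02 * rng.standard_normal((d_model, d_ff))).astype('float32')  # [256, 2048]\n",
    "b1 = np.zeros(d_ff, dtype='float32')                                  # [2048]\n",
    "w2 = (0.02 * rng.standard_normal((d_ff, d_model))).astype('float32')  # [2048, 256]\n",
    "b2 = np.zeros(d_model, dtype='float32')                               # [256]\n",
    "\n",
    "def swish(v):\n",
    "    return v / (1.0 + np.exp(-v))\n",
    "\n",
    "ff_out = swish(x @ w1 + b1) @ w2 + b2\n",
    "y = x + 0.5 * ff_out  # ff_scale = 0.5: the half-step residual\n",
    "print(y.shape)"
   ]
  },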
  {
   "cell_type": "code",
   "execution_count": 270,
   "id": "sonic-thumb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n",
      "False\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "xs = model.encoder.global_cmvn(feat)\n",
    "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n",
    "\n",
    "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n",
    "masks = masks.astype(paddle.bool)\n",
    "mask_pad = masks.logical_not()\n",
    "decoding_chunk_size=0\n",
    "num_decoding_left_chunks=-1\n",
    "chunk_masks = add_optional_chunk_mask(\n",
    "            xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n",
    "            decoding_chunk_size, model.encoder.static_chunk_size,\n",
    "            num_decoding_left_chunks)\n",
    "\n",
    "#print(chunk_masks)\n",
    "data = np.load(\".notebook/enc_embed.npz\")\n",
    "torch_pos_emb=data['pos_emb']\n",
    "torch_xs = data['embed_out']\n",
    "torch_chunk_masks = data['chunk_masks']\n",
    "torch_mask_pad = data['mask_pad']\n",
    "print(np.allclose(xs.numpy(), torch_xs))\n",
    "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n",
    "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n",
    "np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n",
    "\n",
    "\n",
    "for layer in model.encoder.encoders:\n",
    "    xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n",
    "    break\n",
    "data = np.load('.notebook/enc_0.npz')\n",
    "torch_xs = data['enc_0']\n",
    "print(np.allclose(xs.numpy(), torch_xs))\n",
    "print(np.allclose(xs.numpy(), torch_xs, atol=1e-6))\n"
   ]
  },
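  {
   "cell_type": "markdown",
   "id": "enc0-gap-note",
   "metadata": {},
   "source": [
    "After a single encoder layer the default `np.allclose` (rtol=1e-5, atol=1e-8) already reports `False`, while `atol=1e-6` passes. The bare booleans hide how small the gap is; a hedged variant that reports the actual maximum difference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "enc0-gap-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged: report the actual per-element gap after one layer.\n",
    "d = np.abs(xs.numpy() - torch_xs)\n",
    "print(d.max(), d.mean())\n",
    "print(np.allclose(xs.numpy(), torch_xs, rtol=0.0, atol=1e-6))"
   ]
  },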
  {
   "cell_type": "code",
   "execution_count": 273,
   "id": "brave-latino",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n",
      "--------layers_______\n",
      "False\n",
      "True\n",
      "[[-0.70194244  0.56254214  0.6880346  ...  1.1237319   0.7803924\n",
      "   1.1369387 ]\n",
      " [-0.7787783   0.3912667   0.71887773 ...  1.251882    0.886168\n",
      "   1.3173451 ]\n",
      " [-0.95908964  0.6346029   0.87671334 ...  0.98183745  0.7440111\n",
      "   1.2903278 ]\n",
      " ...\n",
      " [-1.0732255   0.67236906  0.92303115 ...  0.9075458   0.8176712\n",
      "   1.3239655 ]\n",
      " [-1.1654118   0.6819967   0.6939453  ...  1.2238353   0.8028295\n",
      "   1.4506507 ]\n",
      " [-1.2732092   0.7145806   0.75819594 ...  0.94154835  0.8774845\n",
      "   1.2623049 ]]\n",
      "xxxxxx\n",
      "[[-0.7019424   0.56254166  0.6880345  ...  1.1237322   0.78039217\n",
      "   1.1369387 ]\n",
      " [-0.778778    0.39126638  0.7188779  ...  1.2518823   0.8861681\n",
      "   1.3173454 ]\n",
      " [-0.9590891   0.6346026   0.87671363 ...  0.9818373   0.74401116\n",
      "   1.2903274 ]\n",
      " ...\n",
      " [-1.0732253   0.6723689   0.9230311  ...  0.9075457   0.8176713\n",
      "   1.3239657 ]\n",
      " [-1.165412    0.6819976   0.69394535 ...  1.2238353   0.80282927\n",
      "   1.4506509 ]\n",
      " [-1.273209    0.71458095  0.75819623 ...  0.9415484   0.8774842\n",
      "   1.2623055 ]]\n"
     ]
    }
   ],
   "source": [
    "xs = model.encoder.global_cmvn(feat)\n",
    "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n",
    "\n",
    "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n",
    "masks = masks.astype(paddle.bool)\n",
    "mask_pad = masks.logical_not()\n",
    "decoding_chunk_size=0\n",
    "num_decoding_left_chunks=-1\n",
    "chunk_masks = add_optional_chunk_mask(\n",
    "            xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n",
    "            decoding_chunk_size, model.encoder.static_chunk_size,\n",
    "            num_decoding_left_chunks)\n",
    "\n",
    "#print(chunk_masks)\n",
    "data = np.load(\".notebook/enc_embed.npz\")\n",
    "torch_pos_emb=data['pos_emb']\n",
    "torch_xs = data['embed_out']\n",
    "torch_chunk_masks = data['chunk_masks']\n",
    "torch_mask_pad = data['mask_pad']\n",
    "print(np.allclose(xs.numpy(), torch_xs))\n",
    "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n",
    "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n",
    "np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n",
    "\n",
    "print(\"--------layers_______\")\n",
    "i =0\n",
    "for layer in model.encoder.encoders:\n",
    "    xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n",
    "    i+=1\n",
    "#     if i == 2:\n",
    "#         data = np.load('.notebook/enc_2.npz')\n",
    "#         torch_xs = data['enc_2']\n",
    "#         print(np.allclose(xs.numpy(), torch_xs))\n",
    "#         print(np.allclose(xs.numpy(), torch_xs, atol=1e-5))\n",
    "#         print(xs[0].numpy())\n",
    "#         print('xxxxxx')\n",
    "#         print(torch_xs[0])\n",
    "#         print('----i==2')\n",
    "data = np.load('.notebook/enc_all.npz')\n",
    "torch_xs = data['enc_all']\n",
    "print(np.allclose(xs.numpy(), torch_xs))\n",
    "print(np.allclose(xs.numpy(), torch_xs, atol=1e-5))\n",
    "print(xs[0].numpy())\n",
    "print('xxxxxx')\n",
    "print(torch_xs[0])"
   ]
  },
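  {
   "cell_type": "markdown",
   "id": "layer-drift-note",
   "metadata": {},
   "source": [
    "Across all encoder layers the gap grows to the 1e-5 range, consistent with float32 rounding accumulating layer by layer. A hedged per-layer drift check, assuming a reference dump exists for every layer following the `.notebook/enc_<i>.npz` pattern used above (the notebook only demonstrates `enc_0` and `enc_2`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "layer-drift-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged per-layer drift check; assumes .notebook/enc_<i>.npz exists for\n",
    "# every layer with key 'enc_<i>'.\n",
    "xs = model.encoder.global_cmvn(feat)\n",
    "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n",
    "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n",
    "masks = masks.astype(paddle.bool)\n",
    "mask_pad = masks.logical_not()\n",
    "chunk_masks = add_optional_chunk_mask(\n",
    "    xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n",
    "    0, model.encoder.static_chunk_size, -1)\n",
    "for i, layer in enumerate(model.encoder.encoders):\n",
    "    xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n",
    "    ref = np.load('.notebook/enc_{}.npz'.format(i))['enc_{}'.format(i)]\n",
    "    print(i, float(np.abs(xs.numpy() - ref).max()))"
   ]
  },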
  {
   "cell_type": "code",
   "execution_count": 278,
   "id": "macro-season",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-0.7019424   0.5625421   0.68803453 ...  1.1237317   0.7803923\n",
      "   1.1369386 ]\n",
      " [-0.7787783   0.39126673  0.71887773 ...  1.251882    0.886168\n",
      "   1.3173451 ]\n",
      " [-0.95908964  0.6346029   0.87671334 ...  0.98183745  0.7440111\n",
      "   1.2903278 ]\n",
      " ...\n",
      " [-1.0732255   0.67236906  0.92303115 ...  0.9075458   0.8176712\n",
      "   1.3239655 ]\n",
      " [-1.1654117   0.68199664  0.6939452  ...  1.2238352   0.8028294\n",
      "   1.4506506 ]\n",
      " [-1.2732091   0.71458054  0.7581958  ...  0.9415482   0.8774844\n",
      "   1.2623048 ]]\n",
      "---\n",
      "[[-0.7019424   0.56254166  0.6880345  ...  1.1237322   0.78039217\n",
      "   1.1369387 ]\n",
      " [-0.778778    0.39126638  0.7188779  ...  1.2518823   0.8861681\n",
      "   1.3173454 ]\n",
      " [-0.9590891   0.6346026   0.87671363 ...  0.9818373   0.74401116\n",
      "   1.2903274 ]\n",
      " ...\n",
      " [-1.0732253   0.6723689   0.9230311  ...  0.9075457   0.8176713\n",
      "   1.3239657 ]\n",
      " [-1.165412    0.6819976   0.69394535 ...  1.2238353   0.80282927\n",
      "   1.4506509 ]\n",
      " [-1.2732087   0.71458083  0.7581961  ...  0.9415482   0.877484\n",
      "   1.2623053 ]]\n",
      "False\n",
      "True\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "encoder_out, mask = model.encoder(feat, feat_len)\n",
    "print(encoder_out.numpy()[0])\n",
    "print(\"---\")\n",
    "print(torch_encoder_out[0])\n",
    "print(np.allclose(torch_encoder_out, encoder_out.numpy()))\n",
    "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-5))\n",
    "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-6))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}