{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "choice-grade", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/workspace/DeepSpeech-2.x\n" ] }, { "data": { "text/plain": [ "'/workspace/DeepSpeech-2.x'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%cd ..\n", "%pwd" ] }, { "cell_type": "code", "execution_count": 2, "id": "broke-broad", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " def convert_to_list(value, n, name, dtype=np.int):\n", "register user softmax to paddle, remove this when fixed!\n", "register user log_softmax to paddle, remove this when fixed!\n", "register user sigmoid to paddle, remove this when fixed!\n", "register user log_sigmoid to paddle, remove this when fixed!\n", "register user relu to paddle, remove this when fixed!\n", "override cat of paddle if exists or register, remove this when fixed!\n", "override item of paddle.Tensor if exists or register, remove this when fixed!\n", "override long of paddle.Tensor if exists or register, remove this when fixed!\n", "override new_full of paddle.Tensor if exists or register, remove this when fixed!\n", "override eq of paddle.Tensor if exists or register, remove this when fixed!\n", "override eq of paddle if exists or register, remove this when fixed!\n", "override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", "override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", "register user view to paddle.Tensor, remove this when fixed!\n", "register user view_as to paddle.Tensor, remove this when fixed!\n", "register user masked_fill to paddle.Tensor, remove this when fixed!\n", "register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", "register user fill_ to paddle.Tensor, remove this when fixed!\n", "register user repeat to paddle.Tensor, remove this when fixed!\n", "register user softmax to paddle.Tensor, remove this when fixed!\n", "register user sigmoid to paddle.Tensor, remove this when fixed!\n", "register user relu to paddle.Tensor, remove this when fixed!\n", "register user type_as to paddle.Tensor, remove this when fixed!\n", "register user to to paddle.Tensor, remove this when fixed!\n", "register user float to paddle.Tensor, remove this when fixed!\n", "register user tolist to paddle.Tensor, remove this when fixed!\n", "register user glu to paddle.nn.functional, remove this when fixed!\n", "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", "register user Module to paddle.nn, remove this when fixed!\n", "register user ModuleList to paddle.nn, remove this when fixed!\n", "register user GLU to paddle.nn, remove this when fixed!\n", "register user ConstantPad2d to paddle.nn, remove this when fixed!\n", "register user export to 
paddle.jit, remove this when fixed!\n" ] } ], "source": [ "import numpy as np\n", "import paddle\n", "from yacs.config import CfgNode as CN\n", "\n", "# importing deepspeech applies the paddle compatibility patches logged above\n", "from deepspeech.models.u2 import U2Model\n", "from deepspeech.utils.layer_tools import print_params\n", "from deepspeech.utils.layer_tools import summary" ] },
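{ "cell_type": "markdown", "id": "added-build-note", "metadata": {}, "source": [ "Build a U2 conformer from the aishell `s1` config and list every parameter as `name | shape | numel | trainable`. The overrides in the next cell (80-dim input features, 4233 output symbols, CMVN stats path) are set in code on top of the YAML." ] }, { "cell_type": "code", "execution_count": 3, "id": "permanent-summary", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", " and should_run_async(code)\n", "[INFO 2021/04/20 03:32:21 u2.py:834] U2 Encoder type: conformer\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n", "encoder.embed.conv.0.bias | [256] | 256 | True\n", "encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n", "encoder.embed.conv.2.bias | [256] | 256 | True\n", "encoder.embed.out.0.weight | [4864, 256] | 1245184 | True\n", "encoder.embed.out.0.bias | [256] | 256 | True\n", "encoder.after_norm.weight | [256] | 256 | True\n", "encoder.after_norm.bias | [256] | 256 | True\n", "encoder.encoders.0.self_attn.pos_bias_u | [4, 64] | 256 | True\n", "encoder.encoders.0.self_attn.pos_bias_v | [4, 64] | 256 | True\n", "encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "encoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", "encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "encoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", "encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "encoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", "encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "encoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", "encoder.encoders.0.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", "encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", "encoder.encoders.0.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.0.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.0.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.0.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", "encoder.encoders.0.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", "encoder.encoders.0.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", "encoder.encoders.0.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", "encoder.encoders.0.conv_module.depthwise_conv.bias | [256] | 256 | True\n", "encoder.encoders.0.conv_module.norm.weight | [256] | 256 | True\n", "encoder.encoders.0.conv_module.norm.bias | [256] | 256 | True\n", "encoder.encoders.0.conv_module.norm._mean | [256] | 256 | False\n", "encoder.encoders.0.conv_module.norm._variance | [256] | 256 | False\n", "encoder.encoders.0.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | 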
True\n", "encoder.encoders.0.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", "encoder.encoders.0.norm_ff.weight | [256] | 256 | True\n", "encoder.encoders.0.norm_ff.bias | [256] | 256 | True\n", "encoder.encoders.0.norm_mha.weight | [256] | 256 | True\n", "encoder.encoders.0.norm_mha.bias | [256] | 256 | True\n", "encoder.encoders.0.norm_ff_macaron.weight | [256] | 256 | True\n", "encoder.encoders.0.norm_ff_macaron.bias | [256] | 256 | True\n", "encoder.encoders.0.norm_conv.weight | [256] | 256 | True\n", "encoder.encoders.0.norm_conv.bias | [256] | 256 | True\n", "encoder.encoders.0.norm_final.weight | [256] | 256 | True\n", "encoder.encoders.0.norm_final.bias | [256] | 256 | True\n", "encoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n", "encoder.encoders.0.concat_linear.bias | [256] | 256 | True\n", "encoder.encoders.1.self_attn.pos_bias_u | [4, 64] | 256 | True\n", "encoder.encoders.1.self_attn.pos_bias_v | [4, 64] | 256 | True\n", "encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "encoder.encoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", "encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "encoder.encoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", "encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "encoder.encoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", "encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "encoder.encoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", "encoder.encoders.1.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", "encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", "encoder.encoders.1.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.1.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.1.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.1.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", "encoder.encoders.1.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", "encoder.encoders.1.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", "encoder.encoders.1.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", "encoder.encoders.1.conv_module.depthwise_conv.bias | [256] | 256 | True\n", "encoder.encoders.1.conv_module.norm.weight | [256] | 256 | True\n", "encoder.encoders.1.conv_module.norm.bias | [256] | 256 | True\n", "encoder.encoders.1.conv_module.norm._mean | [256] | 256 | False\n", "encoder.encoders.1.conv_module.norm._variance | [256] | 256 | False\n", "encoder.encoders.1.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", "encoder.encoders.1.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", "encoder.encoders.1.norm_ff.weight | [256] | 256 | True\n", "encoder.encoders.1.norm_ff.bias | [256] | 256 | True\n", "encoder.encoders.1.norm_mha.weight | [256] | 256 | True\n", "encoder.encoders.1.norm_mha.bias | [256] | 256 | True\n", "encoder.encoders.1.norm_ff_macaron.weight | [256] | 256 | True\n", "encoder.encoders.1.norm_ff_macaron.bias | [256] | 256 | True\n", "encoder.encoders.1.norm_conv.weight | [256] | 256 | True\n", "encoder.encoders.1.norm_conv.bias | [256] | 256 | True\n", 
"encoder.encoders.1.norm_final.weight | [256] | 256 | True\n", "encoder.encoders.1.norm_final.bias | [256] | 256 | True\n", "encoder.encoders.1.concat_linear.weight | [512, 256] | 131072 | True\n", "encoder.encoders.1.concat_linear.bias | [256] | 256 | True\n", "encoder.encoders.2.self_attn.pos_bias_u | [4, 64] | 256 | True\n", "encoder.encoders.2.self_attn.pos_bias_v | [4, 64] | 256 | True\n", "encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "encoder.encoders.2.self_attn.linear_q.bias | [256] | 256 | True\n", "encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "encoder.encoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", "encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "encoder.encoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", "encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "encoder.encoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", "encoder.encoders.2.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", "encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", "encoder.encoders.2.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.2.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.2.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.2.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", "encoder.encoders.2.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", "encoder.encoders.2.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", "encoder.encoders.2.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", "encoder.encoders.2.conv_module.depthwise_conv.bias | [256] | 256 | True\n", "encoder.encoders.2.conv_module.norm.weight | [256] | 256 | True\n", "encoder.encoders.2.conv_module.norm.bias | [256] | 256 | True\n", "encoder.encoders.2.conv_module.norm._mean | [256] | 256 | False\n", "encoder.encoders.2.conv_module.norm._variance | [256] | 256 | False\n", "encoder.encoders.2.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", "encoder.encoders.2.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", "encoder.encoders.2.norm_ff.weight | [256] | 256 | True\n", "encoder.encoders.2.norm_ff.bias | [256] | 256 | True\n", "encoder.encoders.2.norm_mha.weight | [256] | 256 | True\n", "encoder.encoders.2.norm_mha.bias | [256] | 256 | True\n", "encoder.encoders.2.norm_ff_macaron.weight | [256] | 256 | True\n", "encoder.encoders.2.norm_ff_macaron.bias | [256] | 256 | True\n", "encoder.encoders.2.norm_conv.weight | [256] | 256 | True\n", "encoder.encoders.2.norm_conv.bias | [256] | 256 | True\n", "encoder.encoders.2.norm_final.weight | [256] | 256 | True\n", "encoder.encoders.2.norm_final.bias | [256] | 256 | True\n", "encoder.encoders.2.concat_linear.weight | [512, 256] | 131072 | True\n", "encoder.encoders.2.concat_linear.bias | [256] | 256 | True\n", "encoder.encoders.3.self_attn.pos_bias_u | [4, 64] | 256 | True\n", "encoder.encoders.3.self_attn.pos_bias_v | [4, 64] | 256 | True\n", "encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "encoder.encoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", "encoder.encoders.3.self_attn.linear_k.weight | 
[256, 256] | 65536 | True\n", "encoder.encoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", "encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "encoder.encoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", "encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "encoder.encoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", "encoder.encoders.3.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", "encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", "encoder.encoders.3.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.3.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.3.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.3.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", "encoder.encoders.3.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", "encoder.encoders.3.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", "encoder.encoders.3.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", "encoder.encoders.3.conv_module.depthwise_conv.bias | [256] | 256 | True\n", "encoder.encoders.3.conv_module.norm.weight | [256] | 256 | True\n", "encoder.encoders.3.conv_module.norm.bias | [256] | 256 | True\n", "encoder.encoders.3.conv_module.norm._mean | [256] | 256 | False\n", "encoder.encoders.3.conv_module.norm._variance | [256] | 256 | False\n", "encoder.encoders.3.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", "encoder.encoders.3.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", "encoder.encoders.3.norm_ff.weight | [256] | 256 | True\n", "encoder.encoders.3.norm_ff.bias | [256] | 256 | True\n", "encoder.encoders.3.norm_mha.weight | [256] | 256 | True\n", "encoder.encoders.3.norm_mha.bias | [256] | 256 | True\n", "encoder.encoders.3.norm_ff_macaron.weight | [256] | 256 | True\n", "encoder.encoders.3.norm_ff_macaron.bias | [256] | 256 | True\n", "encoder.encoders.3.norm_conv.weight | [256] | 256 | True\n", "encoder.encoders.3.norm_conv.bias | [256] | 256 | True\n", "encoder.encoders.3.norm_final.weight | [256] | 256 | True\n", "encoder.encoders.3.norm_final.bias | [256] | 256 | True\n", "encoder.encoders.3.concat_linear.weight | [512, 256] | 131072 | True\n", "encoder.encoders.3.concat_linear.bias | [256] | 256 | True\n", "encoder.encoders.4.self_attn.pos_bias_u | [4, 64] | 256 | True\n", "encoder.encoders.4.self_attn.pos_bias_v | [4, 64] | 256 | True\n", "encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "encoder.encoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", "encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "encoder.encoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", "encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "encoder.encoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", "encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "encoder.encoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", "encoder.encoders.4.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", "encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", 
"encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", "encoder.encoders.4.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.4.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.4.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.4.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", "encoder.encoders.4.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", "encoder.encoders.4.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", "encoder.encoders.4.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", "encoder.encoders.4.conv_module.depthwise_conv.bias | [256] | 256 | True\n", "encoder.encoders.4.conv_module.norm.weight | [256] | 256 | True\n", "encoder.encoders.4.conv_module.norm.bias | [256] | 256 | True\n", "encoder.encoders.4.conv_module.norm._mean | [256] | 256 | False\n", "encoder.encoders.4.conv_module.norm._variance | [256] | 256 | False\n", "encoder.encoders.4.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", "encoder.encoders.4.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", "encoder.encoders.4.norm_ff.weight | [256] | 256 | True\n", "encoder.encoders.4.norm_ff.bias | [256] | 256 | True\n", "encoder.encoders.4.norm_mha.weight | [256] | 256 | True\n", "encoder.encoders.4.norm_mha.bias | [256] | 256 | True\n", "encoder.encoders.4.norm_ff_macaron.weight | [256] | 256 | True\n", "encoder.encoders.4.norm_ff_macaron.bias | [256] | 256 | True\n", "encoder.encoders.4.norm_conv.weight | [256] | 256 | True\n", "encoder.encoders.4.norm_conv.bias | [256] | 256 | True\n", "encoder.encoders.4.norm_final.weight | [256] | 256 | True\n", "encoder.encoders.4.norm_final.bias | [256] | 256 | True\n", "encoder.encoders.4.concat_linear.weight | [512, 256] | 131072 | True\n", "encoder.encoders.4.concat_linear.bias | [256] | 256 | True\n", "encoder.encoders.5.self_attn.pos_bias_u | [4, 64] | 256 | True\n", "encoder.encoders.5.self_attn.pos_bias_v | [4, 64] | 256 | True\n", "encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "encoder.encoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", "encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "encoder.encoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", "encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "encoder.encoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", "encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "encoder.encoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", "encoder.encoders.5.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", "encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", "encoder.encoders.5.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.5.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.5.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.5.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", "encoder.encoders.5.conv_module.pointwise_conv1.weight | 
[512, 256, 1] | 131072 | True\n", "encoder.encoders.5.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", "encoder.encoders.5.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", "encoder.encoders.5.conv_module.depthwise_conv.bias | [256] | 256 | True\n", "encoder.encoders.5.conv_module.norm.weight | [256] | 256 | True\n", "encoder.encoders.5.conv_module.norm.bias | [256] | 256 | True\n", "encoder.encoders.5.conv_module.norm._mean | [256] | 256 | False\n", "encoder.encoders.5.conv_module.norm._variance | [256] | 256 | False\n", "encoder.encoders.5.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", "encoder.encoders.5.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", "encoder.encoders.5.norm_ff.weight | [256] | 256 | True\n", "encoder.encoders.5.norm_ff.bias | [256] | 256 | True\n", "encoder.encoders.5.norm_mha.weight | [256] | 256 | True\n", "encoder.encoders.5.norm_mha.bias | [256] | 256 | True\n", "encoder.encoders.5.norm_ff_macaron.weight | [256] | 256 | True\n", "encoder.encoders.5.norm_ff_macaron.bias | [256] | 256 | True\n", "encoder.encoders.5.norm_conv.weight | [256] | 256 | True\n", "encoder.encoders.5.norm_conv.bias | [256] | 256 | True\n", "encoder.encoders.5.norm_final.weight | [256] | 256 | True\n", "encoder.encoders.5.norm_final.bias | [256] | 256 | True\n", "encoder.encoders.5.concat_linear.weight | [512, 256] | 131072 | True\n", "encoder.encoders.5.concat_linear.bias | [256] | 256 | True\n", "encoder.encoders.6.self_attn.pos_bias_u | [4, 64] | 256 | True\n", "encoder.encoders.6.self_attn.pos_bias_v | [4, 64] | 256 | True\n", "encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "encoder.encoders.6.self_attn.linear_q.bias | [256] | 256 | True\n", "encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "encoder.encoders.6.self_attn.linear_k.bias | [256] | 256 | True\n", "encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "encoder.encoders.6.self_attn.linear_v.bias | [256] | 256 | True\n", "encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "encoder.encoders.6.self_attn.linear_out.bias | [256] | 256 | True\n", "encoder.encoders.6.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", "encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.6.feed_forward.w_2.bias | [256] | 256 | True\n", "encoder.encoders.6.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.6.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.6.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.6.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", "encoder.encoders.6.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", "encoder.encoders.6.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", "encoder.encoders.6.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", "encoder.encoders.6.conv_module.depthwise_conv.bias | [256] | 256 | True\n", "encoder.encoders.6.conv_module.norm.weight | [256] | 256 | True\n", "encoder.encoders.6.conv_module.norm.bias | [256] | 256 | True\n", "encoder.encoders.6.conv_module.norm._mean | [256] | 256 | False\n", "encoder.encoders.6.conv_module.norm._variance | [256] | 256 | False\n", 
"encoder.encoders.6.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", "encoder.encoders.6.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", "encoder.encoders.6.norm_ff.weight | [256] | 256 | True\n", "encoder.encoders.6.norm_ff.bias | [256] | 256 | True\n", "encoder.encoders.6.norm_mha.weight | [256] | 256 | True\n", "encoder.encoders.6.norm_mha.bias | [256] | 256 | True\n", "encoder.encoders.6.norm_ff_macaron.weight | [256] | 256 | True\n", "encoder.encoders.6.norm_ff_macaron.bias | [256] | 256 | True\n", "encoder.encoders.6.norm_conv.weight | [256] | 256 | True\n", "encoder.encoders.6.norm_conv.bias | [256] | 256 | True\n", "encoder.encoders.6.norm_final.weight | [256] | 256 | True\n", "encoder.encoders.6.norm_final.bias | [256] | 256 | True\n", "encoder.encoders.6.concat_linear.weight | [512, 256] | 131072 | True\n", "encoder.encoders.6.concat_linear.bias | [256] | 256 | True\n", "encoder.encoders.7.self_attn.pos_bias_u | [4, 64] | 256 | True\n", "encoder.encoders.7.self_attn.pos_bias_v | [4, 64] | 256 | True\n", "encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "encoder.encoders.7.self_attn.linear_q.bias | [256] | 256 | True\n", "encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "encoder.encoders.7.self_attn.linear_k.bias | [256] | 256 | True\n", "encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "encoder.encoders.7.self_attn.linear_v.bias | [256] | 256 | True\n", "encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "encoder.encoders.7.self_attn.linear_out.bias | [256] | 256 | True\n", "encoder.encoders.7.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", "encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.7.feed_forward.w_2.bias | [256] | 256 | True\n", "encoder.encoders.7.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.7.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.7.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.7.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", "encoder.encoders.7.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", "encoder.encoders.7.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", "encoder.encoders.7.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", "encoder.encoders.7.conv_module.depthwise_conv.bias | [256] | 256 | True\n", "encoder.encoders.7.conv_module.norm.weight | [256] | 256 | True\n", "encoder.encoders.7.conv_module.norm.bias | [256] | 256 | True\n", "encoder.encoders.7.conv_module.norm._mean | [256] | 256 | False\n", "encoder.encoders.7.conv_module.norm._variance | [256] | 256 | False\n", "encoder.encoders.7.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", "encoder.encoders.7.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", "encoder.encoders.7.norm_ff.weight | [256] | 256 | True\n", "encoder.encoders.7.norm_ff.bias | [256] | 256 | True\n", "encoder.encoders.7.norm_mha.weight | [256] | 256 | True\n", "encoder.encoders.7.norm_mha.bias | [256] | 256 | True\n", "encoder.encoders.7.norm_ff_macaron.weight | [256] | 256 | True\n", "encoder.encoders.7.norm_ff_macaron.bias | [256] | 256 | True\n", "encoder.encoders.7.norm_conv.weight | [256] | 256 | 
True\n", "encoder.encoders.7.norm_conv.bias | [256] | 256 | True\n", "encoder.encoders.7.norm_final.weight | [256] | 256 | True\n", "encoder.encoders.7.norm_final.bias | [256] | 256 | True\n", "encoder.encoders.7.concat_linear.weight | [512, 256] | 131072 | True\n", "encoder.encoders.7.concat_linear.bias | [256] | 256 | True\n", "encoder.encoders.8.self_attn.pos_bias_u | [4, 64] | 256 | True\n", "encoder.encoders.8.self_attn.pos_bias_v | [4, 64] | 256 | True\n", "encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "encoder.encoders.8.self_attn.linear_q.bias | [256] | 256 | True\n", "encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "encoder.encoders.8.self_attn.linear_k.bias | [256] | 256 | True\n", "encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "encoder.encoders.8.self_attn.linear_v.bias | [256] | 256 | True\n", "encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "encoder.encoders.8.self_attn.linear_out.bias | [256] | 256 | True\n", "encoder.encoders.8.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", "encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.8.feed_forward.w_2.bias | [256] | 256 | True\n", "encoder.encoders.8.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.8.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.8.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.8.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", "encoder.encoders.8.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", "encoder.encoders.8.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", "encoder.encoders.8.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", "encoder.encoders.8.conv_module.depthwise_conv.bias | [256] | 256 | True\n", "encoder.encoders.8.conv_module.norm.weight | [256] | 256 | True\n", "encoder.encoders.8.conv_module.norm.bias | [256] | 256 | True\n", "encoder.encoders.8.conv_module.norm._mean | [256] | 256 | False\n", "encoder.encoders.8.conv_module.norm._variance | [256] | 256 | False\n", "encoder.encoders.8.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", "encoder.encoders.8.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", "encoder.encoders.8.norm_ff.weight | [256] | 256 | True\n", "encoder.encoders.8.norm_ff.bias | [256] | 256 | True\n", "encoder.encoders.8.norm_mha.weight | [256] | 256 | True\n", "encoder.encoders.8.norm_mha.bias | [256] | 256 | True\n", "encoder.encoders.8.norm_ff_macaron.weight | [256] | 256 | True\n", "encoder.encoders.8.norm_ff_macaron.bias | [256] | 256 | True\n", "encoder.encoders.8.norm_conv.weight | [256] | 256 | True\n", "encoder.encoders.8.norm_conv.bias | [256] | 256 | True\n", "encoder.encoders.8.norm_final.weight | [256] | 256 | True\n", "encoder.encoders.8.norm_final.bias | [256] | 256 | True\n", "encoder.encoders.8.concat_linear.weight | [512, 256] | 131072 | True\n", "encoder.encoders.8.concat_linear.bias | [256] | 256 | True\n", "encoder.encoders.9.self_attn.pos_bias_u | [4, 64] | 256 | True\n", "encoder.encoders.9.self_attn.pos_bias_v | [4, 64] | 256 | True\n", "encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "encoder.encoders.9.self_attn.linear_q.bias | 
[256] | 256 | True\n", "encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "encoder.encoders.9.self_attn.linear_k.bias | [256] | 256 | True\n", "encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "encoder.encoders.9.self_attn.linear_v.bias | [256] | 256 | True\n", "encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "encoder.encoders.9.self_attn.linear_out.bias | [256] | 256 | True\n", "encoder.encoders.9.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", "encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.9.feed_forward.w_2.bias | [256] | 256 | True\n", "encoder.encoders.9.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.9.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.9.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.9.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", "encoder.encoders.9.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", "encoder.encoders.9.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", "encoder.encoders.9.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", "encoder.encoders.9.conv_module.depthwise_conv.bias | [256] | 256 | True\n", "encoder.encoders.9.conv_module.norm.weight | [256] | 256 | True\n", "encoder.encoders.9.conv_module.norm.bias | [256] | 256 | True\n", "encoder.encoders.9.conv_module.norm._mean | [256] | 256 | False\n", "encoder.encoders.9.conv_module.norm._variance | [256] | 256 | False\n", "encoder.encoders.9.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", "encoder.encoders.9.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", "encoder.encoders.9.norm_ff.weight | [256] | 256 | True\n", "encoder.encoders.9.norm_ff.bias | [256] | 256 | True\n", "encoder.encoders.9.norm_mha.weight | [256] | 256 | True\n", "encoder.encoders.9.norm_mha.bias | [256] | 256 | True\n", "encoder.encoders.9.norm_ff_macaron.weight | [256] | 256 | True\n", "encoder.encoders.9.norm_ff_macaron.bias | [256] | 256 | True\n", "encoder.encoders.9.norm_conv.weight | [256] | 256 | True\n", "encoder.encoders.9.norm_conv.bias | [256] | 256 | True\n", "encoder.encoders.9.norm_final.weight | [256] | 256 | True\n", "encoder.encoders.9.norm_final.bias | [256] | 256 | True\n", "encoder.encoders.9.concat_linear.weight | [512, 256] | 131072 | True\n", "encoder.encoders.9.concat_linear.bias | [256] | 256 | True\n", "encoder.encoders.10.self_attn.pos_bias_u | [4, 64] | 256 | True\n", "encoder.encoders.10.self_attn.pos_bias_v | [4, 64] | 256 | True\n", "encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "encoder.encoders.10.self_attn.linear_q.bias | [256] | 256 | True\n", "encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "encoder.encoders.10.self_attn.linear_k.bias | [256] | 256 | True\n", "encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "encoder.encoders.10.self_attn.linear_v.bias | [256] | 256 | True\n", "encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "encoder.encoders.10.self_attn.linear_out.bias | [256] | 256 | True\n", "encoder.encoders.10.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", 
"encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.10.feed_forward.w_2.bias | [256] | 256 | True\n", "encoder.encoders.10.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.10.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.10.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.10.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", "encoder.encoders.10.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", "encoder.encoders.10.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", "encoder.encoders.10.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", "encoder.encoders.10.conv_module.depthwise_conv.bias | [256] | 256 | True\n", "encoder.encoders.10.conv_module.norm.weight | [256] | 256 | True\n", "encoder.encoders.10.conv_module.norm.bias | [256] | 256 | True\n", "encoder.encoders.10.conv_module.norm._mean | [256] | 256 | False\n", "encoder.encoders.10.conv_module.norm._variance | [256] | 256 | False\n", "encoder.encoders.10.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", "encoder.encoders.10.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", "encoder.encoders.10.norm_ff.weight | [256] | 256 | True\n", "encoder.encoders.10.norm_ff.bias | [256] | 256 | True\n", "encoder.encoders.10.norm_mha.weight | [256] | 256 | True\n", "encoder.encoders.10.norm_mha.bias | [256] | 256 | True\n", "encoder.encoders.10.norm_ff_macaron.weight | [256] | 256 | True\n", "encoder.encoders.10.norm_ff_macaron.bias | [256] | 256 | True\n", "encoder.encoders.10.norm_conv.weight | [256] | 256 | True\n", "encoder.encoders.10.norm_conv.bias | [256] | 256 | True\n", "encoder.encoders.10.norm_final.weight | [256] | 256 | True\n", "encoder.encoders.10.norm_final.bias | [256] | 256 | True\n", "encoder.encoders.10.concat_linear.weight | [512, 256] | 131072 | True\n", "encoder.encoders.10.concat_linear.bias | [256] | 256 | True\n", "encoder.encoders.11.self_attn.pos_bias_u | [4, 64] | 256 | True\n", "encoder.encoders.11.self_attn.pos_bias_v | [4, 64] | 256 | True\n", "encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "encoder.encoders.11.self_attn.linear_q.bias | [256] | 256 | True\n", "encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "encoder.encoders.11.self_attn.linear_k.bias | [256] | 256 | True\n", "encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "encoder.encoders.11.self_attn.linear_v.bias | [256] | 256 | True\n", "encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "encoder.encoders.11.self_attn.linear_out.bias | [256] | 256 | True\n", "encoder.encoders.11.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", "encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "encoder.encoders.11.feed_forward.w_2.bias | [256] | 256 | True\n", "encoder.encoders.11.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", "encoder.encoders.11.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", "encoder.encoders.11.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", 
"encoder.encoders.11.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", "encoder.encoders.11.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", "encoder.encoders.11.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", "encoder.encoders.11.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", "encoder.encoders.11.conv_module.depthwise_conv.bias | [256] | 256 | True\n", "encoder.encoders.11.conv_module.norm.weight | [256] | 256 | True\n", "encoder.encoders.11.conv_module.norm.bias | [256] | 256 | True\n", "encoder.encoders.11.conv_module.norm._mean | [256] | 256 | False\n", "encoder.encoders.11.conv_module.norm._variance | [256] | 256 | False\n", "encoder.encoders.11.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", "encoder.encoders.11.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", "encoder.encoders.11.norm_ff.weight | [256] | 256 | True\n", "encoder.encoders.11.norm_ff.bias | [256] | 256 | True\n", "encoder.encoders.11.norm_mha.weight | [256] | 256 | True\n", "encoder.encoders.11.norm_mha.bias | [256] | 256 | True\n", "encoder.encoders.11.norm_ff_macaron.weight | [256] | 256 | True\n", "encoder.encoders.11.norm_ff_macaron.bias | [256] | 256 | True\n", "encoder.encoders.11.norm_conv.weight | [256] | 256 | True\n", "encoder.encoders.11.norm_conv.bias | [256] | 256 | True\n", "encoder.encoders.11.norm_final.weight | [256] | 256 | True\n", "encoder.encoders.11.norm_final.bias | [256] | 256 | True\n", "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072 | True\n", "encoder.encoders.11.concat_linear.bias | [256] | 256 | True\n", "decoder.embed.0.weight | [4233, 256] | 1083648 | True\n", "decoder.after_norm.weight | [256] | 256 | True\n", "decoder.after_norm.bias | [256] | 256 | True\n", "decoder.output_layer.weight | [256, 4233] | 1083648 | True\n", "decoder.output_layer.bias | [4233] | 4233 | True\n", "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "decoder.decoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", "decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "decoder.decoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", "decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "decoder.decoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", "decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", "decoder.decoders.0.src_attn.linear_q.bias | [256] | 256 | True\n", "decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", "decoder.decoders.0.src_attn.linear_k.bias | [256] | 256 | True\n", "decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", "decoder.decoders.0.src_attn.linear_v.bias | [256] | 256 | True\n", "decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", "decoder.decoders.0.src_attn.linear_out.bias | [256] | 256 | True\n", "decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", "decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "decoder.decoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", "decoder.decoders.0.norm1.weight | [256] | 256 | True\n", "decoder.decoders.0.norm1.bias | [256] | 256 | True\n", "decoder.decoders.0.norm2.weight | [256] | 256 | True\n", 
"decoder.decoders.0.norm2.bias | [256] | 256 | True\n", "decoder.decoders.0.norm3.weight | [256] | 256 | True\n", "decoder.decoders.0.norm3.bias | [256] | 256 | True\n", "decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072 | True\n", "decoder.decoders.0.concat_linear1.bias | [256] | 256 | True\n", "decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072 | True\n", "decoder.decoders.0.concat_linear2.bias | [256] | 256 | True\n", "decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "decoder.decoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", "decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "decoder.decoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", "decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "decoder.decoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", "decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "decoder.decoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", "decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", "decoder.decoders.1.src_attn.linear_q.bias | [256] | 256 | True\n", "decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", "decoder.decoders.1.src_attn.linear_k.bias | [256] | 256 | True\n", "decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", "decoder.decoders.1.src_attn.linear_v.bias | [256] | 256 | True\n", "decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", "decoder.decoders.1.src_attn.linear_out.bias | [256] | 256 | True\n", "decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", "decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "decoder.decoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", "decoder.decoders.1.norm1.weight | [256] | 256 | True\n", "decoder.decoders.1.norm1.bias | [256] | 256 | True\n", "decoder.decoders.1.norm2.weight | [256] | 256 | True\n", "decoder.decoders.1.norm2.bias | [256] | 256 | True\n", "decoder.decoders.1.norm3.weight | [256] | 256 | True\n", "decoder.decoders.1.norm3.bias | [256] | 256 | True\n", "decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072 | True\n", "decoder.decoders.1.concat_linear1.bias | [256] | 256 | True\n", "decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072 | True\n", "decoder.decoders.1.concat_linear2.bias | [256] | 256 | True\n", "decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "decoder.decoders.2.self_attn.linear_q.bias | [256] | 256 | True\n", "decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "decoder.decoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", "decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "decoder.decoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", "decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "decoder.decoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", "decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", "decoder.decoders.2.src_attn.linear_q.bias | [256] | 256 | True\n", "decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", "decoder.decoders.2.src_attn.linear_k.bias | [256] | 256 | True\n", "decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", "decoder.decoders.2.src_attn.linear_v.bias | [256] | 
256 | True\n", "decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", "decoder.decoders.2.src_attn.linear_out.bias | [256] | 256 | True\n", "decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", "decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "decoder.decoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", "decoder.decoders.2.norm1.weight | [256] | 256 | True\n", "decoder.decoders.2.norm1.bias | [256] | 256 | True\n", "decoder.decoders.2.norm2.weight | [256] | 256 | True\n", "decoder.decoders.2.norm2.bias | [256] | 256 | True\n", "decoder.decoders.2.norm3.weight | [256] | 256 | True\n", "decoder.decoders.2.norm3.bias | [256] | 256 | True\n", "decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072 | True\n", "decoder.decoders.2.concat_linear1.bias | [256] | 256 | True\n", "decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072 | True\n", "decoder.decoders.2.concat_linear2.bias | [256] | 256 | True\n", "decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "decoder.decoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", "decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "decoder.decoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", "decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256 | True\n", "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256 | True\n", "decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", "decoder.decoders.3.src_attn.linear_v.bias | [256] | 256 | True\n", "decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", "decoder.decoders.3.src_attn.linear_out.bias | [256] | 256 | True\n", "decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", "decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "decoder.decoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", "decoder.decoders.3.norm1.weight | [256] | 256 | True\n", "decoder.decoders.3.norm1.bias | [256] | 256 | True\n", "decoder.decoders.3.norm2.weight | [256] | 256 | True\n", "decoder.decoders.3.norm2.bias | [256] | 256 | True\n", "decoder.decoders.3.norm3.weight | [256] | 256 | True\n", "decoder.decoders.3.norm3.bias | [256] | 256 | True\n", "decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072 | True\n", "decoder.decoders.3.concat_linear1.bias | [256] | 256 | True\n", "decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072 | True\n", "decoder.decoders.3.concat_linear2.bias | [256] | 256 | True\n", "decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "decoder.decoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", "decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "decoder.decoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", "decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", 
"decoder.decoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", "decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "decoder.decoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", "decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", "decoder.decoders.4.src_attn.linear_q.bias | [256] | 256 | True\n", "decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", "decoder.decoders.4.src_attn.linear_k.bias | [256] | 256 | True\n", "decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", "decoder.decoders.4.src_attn.linear_v.bias | [256] | 256 | True\n", "decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", "decoder.decoders.4.src_attn.linear_out.bias | [256] | 256 | True\n", "decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", "decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "decoder.decoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", "decoder.decoders.4.norm1.weight | [256] | 256 | True\n", "decoder.decoders.4.norm1.bias | [256] | 256 | True\n", "decoder.decoders.4.norm2.weight | [256] | 256 | True\n", "decoder.decoders.4.norm2.bias | [256] | 256 | True\n", "decoder.decoders.4.norm3.weight | [256] | 256 | True\n", "decoder.decoders.4.norm3.bias | [256] | 256 | True\n", "decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072 | True\n", "decoder.decoders.4.concat_linear1.bias | [256] | 256 | True\n", "decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072 | True\n", "decoder.decoders.4.concat_linear2.bias | [256] | 256 | True\n", "decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "decoder.decoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", "decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", "decoder.decoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", "decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", "decoder.decoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", "decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", "decoder.decoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", "decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", "decoder.decoders.5.src_attn.linear_q.bias | [256] | 256 | True\n", "decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", "decoder.decoders.5.src_attn.linear_k.bias | [256] | 256 | True\n", "decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", "decoder.decoders.5.src_attn.linear_v.bias | [256] | 256 | True\n", "decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", "decoder.decoders.5.src_attn.linear_out.bias | [256] | 256 | True\n", "decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", "decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", "decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", "decoder.decoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", "decoder.decoders.5.norm1.weight | [256] | 256 | True\n", "decoder.decoders.5.norm1.bias | [256] | 256 | True\n", "decoder.decoders.5.norm2.weight | [256] | 256 | True\n", "decoder.decoders.5.norm2.bias | [256] | 256 | True\n", "decoder.decoders.5.norm3.weight | [256] | 256 | True\n", "decoder.decoders.5.norm3.bias | [256] | 256 | True\n", 
"decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072 | True\n", "decoder.decoders.5.concat_linear1.bias | [256] | 256 | True\n", "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072 | True\n", "decoder.decoders.5.concat_linear2.bias | [256] | 256 | True\n", "ctc.ctc_lo.weight | [256, 4233] | 1083648 | True\n", "ctc.ctc_lo.bias | [4233] | 4233 | True\n", "Total parameters: 687.0, 49355282.0 elements.\n" ] } ], "source": [ "conf_str='examples/aishell/s1/conf/conformer.yaml'\n", "cfg = CN().load_cfg(open(conf_str))\n", "cfg.model.input_dim = 80\n", "cfg.model.output_dim = 4233\n", "cfg.model.cmvn_file = \"/workspace/wenet/examples/aishell/s0/raw_wav/train/global_cmvn\"\n", "cfg.model.cmvn_file_type = 'json'\n", "cfg.freeze()\n", "\n", "model = U2Model(cfg.model)\n", "print_params(model)\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "sapphire-agent", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "encoder.global_cmvn.mean | [80] | 80\n", "encoder.global_cmvn.istd | [80] | 80\n", "encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304\n", "encoder.embed.conv.0.bias | [256] | 256\n", "encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824\n", "encoder.embed.conv.2.bias | [256] | 256\n", "encoder.embed.out.0.weight | [4864, 256] | 1245184\n", "encoder.embed.out.0.bias | [256] | 256\n", "encoder.after_norm.weight | [256] | 256\n", "encoder.after_norm.bias | [256] | 256\n", "encoder.encoders.0.self_attn.pos_bias_u | [4, 64] | 256\n", "encoder.encoders.0.self_attn.pos_bias_v | [4, 64] | 256\n", "encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536\n", "encoder.encoders.0.self_attn.linear_q.bias | [256] | 256\n", "encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536\n", "encoder.encoders.0.self_attn.linear_k.bias | [256] | 256\n", "encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536\n", "encoder.encoders.0.self_attn.linear_v.bias | [256] | 256\n", "encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536\n", "encoder.encoders.0.self_attn.linear_out.bias | [256] | 256\n", "encoder.encoders.0.self_attn.linear_pos.weight | [256, 256] | 65536\n", "encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048\n", "encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.0.feed_forward.w_2.bias | [256] | 256\n", "encoder.encoders.0.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.0.feed_forward_macaron.w_1.bias | [2048] | 2048\n", "encoder.encoders.0.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.0.feed_forward_macaron.w_2.bias | [256] | 256\n", "encoder.encoders.0.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", "encoder.encoders.0.conv_module.pointwise_conv1.bias | [512] | 512\n", "encoder.encoders.0.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", "encoder.encoders.0.conv_module.depthwise_conv.bias | [256] | 256\n", "encoder.encoders.0.conv_module.norm.weight | [256] | 256\n", "encoder.encoders.0.conv_module.norm.bias | [256] | 256\n", "encoder.encoders.0.conv_module.norm._mean | [256] | 256\n", "encoder.encoders.0.conv_module.norm._variance | [256] | 256\n", "encoder.encoders.0.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", "encoder.encoders.0.conv_module.pointwise_conv2.bias | [256] | 256\n", "encoder.encoders.0.norm_ff.weight | [256] | 256\n", "encoder.encoders.0.norm_ff.bias | [256] | 
256\n", "encoder.encoders.0.norm_mha.weight | [256] | 256\n", "encoder.encoders.0.norm_mha.bias | [256] | 256\n", "encoder.encoders.0.norm_ff_macaron.weight | [256] | 256\n", "encoder.encoders.0.norm_ff_macaron.bias | [256] | 256\n", "encoder.encoders.0.norm_conv.weight | [256] | 256\n", "encoder.encoders.0.norm_conv.bias | [256] | 256\n", "encoder.encoders.0.norm_final.weight | [256] | 256\n", "encoder.encoders.0.norm_final.bias | [256] | 256\n", "encoder.encoders.0.concat_linear.weight | [512, 256] | 131072\n", "encoder.encoders.0.concat_linear.bias | [256] | 256\n", "encoder.encoders.1.self_attn.pos_bias_u | [4, 64] | 256\n", "encoder.encoders.1.self_attn.pos_bias_v | [4, 64] | 256\n", "encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536\n", "encoder.encoders.1.self_attn.linear_q.bias | [256] | 256\n", "encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536\n", "encoder.encoders.1.self_attn.linear_k.bias | [256] | 256\n", "encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536\n", "encoder.encoders.1.self_attn.linear_v.bias | [256] | 256\n", "encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536\n", "encoder.encoders.1.self_attn.linear_out.bias | [256] | 256\n", "encoder.encoders.1.self_attn.linear_pos.weight | [256, 256] | 65536\n", "encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048\n", "encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.1.feed_forward.w_2.bias | [256] | 256\n", "encoder.encoders.1.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.1.feed_forward_macaron.w_1.bias | [2048] | 2048\n", "encoder.encoders.1.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.1.feed_forward_macaron.w_2.bias | [256] | 256\n", "encoder.encoders.1.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", "encoder.encoders.1.conv_module.pointwise_conv1.bias | [512] | 512\n", "encoder.encoders.1.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", "encoder.encoders.1.conv_module.depthwise_conv.bias | [256] | 256\n", "encoder.encoders.1.conv_module.norm.weight | [256] | 256\n", "encoder.encoders.1.conv_module.norm.bias | [256] | 256\n", "encoder.encoders.1.conv_module.norm._mean | [256] | 256\n", "encoder.encoders.1.conv_module.norm._variance | [256] | 256\n", "encoder.encoders.1.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", "encoder.encoders.1.conv_module.pointwise_conv2.bias | [256] | 256\n", "encoder.encoders.1.norm_ff.weight | [256] | 256\n", "encoder.encoders.1.norm_ff.bias | [256] | 256\n", "encoder.encoders.1.norm_mha.weight | [256] | 256\n", "encoder.encoders.1.norm_mha.bias | [256] | 256\n", "encoder.encoders.1.norm_ff_macaron.weight | [256] | 256\n", "encoder.encoders.1.norm_ff_macaron.bias | [256] | 256\n", "encoder.encoders.1.norm_conv.weight | [256] | 256\n", "encoder.encoders.1.norm_conv.bias | [256] | 256\n", "encoder.encoders.1.norm_final.weight | [256] | 256\n", "encoder.encoders.1.norm_final.bias | [256] | 256\n", "encoder.encoders.1.concat_linear.weight | [512, 256] | 131072\n", "encoder.encoders.1.concat_linear.bias | [256] | 256\n", "encoder.encoders.2.self_attn.pos_bias_u | [4, 64] | 256\n", "encoder.encoders.2.self_attn.pos_bias_v | [4, 64] | 256\n", "encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536\n", "encoder.encoders.2.self_attn.linear_q.bias | [256] | 256\n", "encoder.encoders.2.self_attn.linear_k.weight | [256, 
256] | 65536\n", "encoder.encoders.2.self_attn.linear_k.bias | [256] | 256\n", "encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536\n", "encoder.encoders.2.self_attn.linear_v.bias | [256] | 256\n", "encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536\n", "encoder.encoders.2.self_attn.linear_out.bias | [256] | 256\n", "encoder.encoders.2.self_attn.linear_pos.weight | [256, 256] | 65536\n", "encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048\n", "encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.2.feed_forward.w_2.bias | [256] | 256\n", "encoder.encoders.2.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.2.feed_forward_macaron.w_1.bias | [2048] | 2048\n", "encoder.encoders.2.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.2.feed_forward_macaron.w_2.bias | [256] | 256\n", "encoder.encoders.2.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", "encoder.encoders.2.conv_module.pointwise_conv1.bias | [512] | 512\n", "encoder.encoders.2.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", "encoder.encoders.2.conv_module.depthwise_conv.bias | [256] | 256\n", "encoder.encoders.2.conv_module.norm.weight | [256] | 256\n", "encoder.encoders.2.conv_module.norm.bias | [256] | 256\n", "encoder.encoders.2.conv_module.norm._mean | [256] | 256\n", "encoder.encoders.2.conv_module.norm._variance | [256] | 256\n", "encoder.encoders.2.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", "encoder.encoders.2.conv_module.pointwise_conv2.bias | [256] | 256\n", "encoder.encoders.2.norm_ff.weight | [256] | 256\n", "encoder.encoders.2.norm_ff.bias | [256] | 256\n", "encoder.encoders.2.norm_mha.weight | [256] | 256\n", "encoder.encoders.2.norm_mha.bias | [256] | 256\n", "encoder.encoders.2.norm_ff_macaron.weight | [256] | 256\n", "encoder.encoders.2.norm_ff_macaron.bias | [256] | 256\n", "encoder.encoders.2.norm_conv.weight | [256] | 256\n", "encoder.encoders.2.norm_conv.bias | [256] | 256\n", "encoder.encoders.2.norm_final.weight | [256] | 256\n", "encoder.encoders.2.norm_final.bias | [256] | 256\n", "encoder.encoders.2.concat_linear.weight | [512, 256] | 131072\n", "encoder.encoders.2.concat_linear.bias | [256] | 256\n", "encoder.encoders.3.self_attn.pos_bias_u | [4, 64] | 256\n", "encoder.encoders.3.self_attn.pos_bias_v | [4, 64] | 256\n", "encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536\n", "encoder.encoders.3.self_attn.linear_q.bias | [256] | 256\n", "encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536\n", "encoder.encoders.3.self_attn.linear_k.bias | [256] | 256\n", "encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536\n", "encoder.encoders.3.self_attn.linear_v.bias | [256] | 256\n", "encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536\n", "encoder.encoders.3.self_attn.linear_out.bias | [256] | 256\n", "encoder.encoders.3.self_attn.linear_pos.weight | [256, 256] | 65536\n", "encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048\n", "encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.3.feed_forward.w_2.bias | [256] | 256\n", "encoder.encoders.3.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.3.feed_forward_macaron.w_1.bias | [2048] | 2048\n", 
"encoder.encoders.3.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.3.feed_forward_macaron.w_2.bias | [256] | 256\n", "encoder.encoders.3.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", "encoder.encoders.3.conv_module.pointwise_conv1.bias | [512] | 512\n", "encoder.encoders.3.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", "encoder.encoders.3.conv_module.depthwise_conv.bias | [256] | 256\n", "encoder.encoders.3.conv_module.norm.weight | [256] | 256\n", "encoder.encoders.3.conv_module.norm.bias | [256] | 256\n", "encoder.encoders.3.conv_module.norm._mean | [256] | 256\n", "encoder.encoders.3.conv_module.norm._variance | [256] | 256\n", "encoder.encoders.3.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", "encoder.encoders.3.conv_module.pointwise_conv2.bias | [256] | 256\n", "encoder.encoders.3.norm_ff.weight | [256] | 256\n", "encoder.encoders.3.norm_ff.bias | [256] | 256\n", "encoder.encoders.3.norm_mha.weight | [256] | 256\n", "encoder.encoders.3.norm_mha.bias | [256] | 256\n", "encoder.encoders.3.norm_ff_macaron.weight | [256] | 256\n", "encoder.encoders.3.norm_ff_macaron.bias | [256] | 256\n", "encoder.encoders.3.norm_conv.weight | [256] | 256\n", "encoder.encoders.3.norm_conv.bias | [256] | 256\n", "encoder.encoders.3.norm_final.weight | [256] | 256\n", "encoder.encoders.3.norm_final.bias | [256] | 256\n", "encoder.encoders.3.concat_linear.weight | [512, 256] | 131072\n", "encoder.encoders.3.concat_linear.bias | [256] | 256\n", "encoder.encoders.4.self_attn.pos_bias_u | [4, 64] | 256\n", "encoder.encoders.4.self_attn.pos_bias_v | [4, 64] | 256\n", "encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536\n", "encoder.encoders.4.self_attn.linear_q.bias | [256] | 256\n", "encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536\n", "encoder.encoders.4.self_attn.linear_k.bias | [256] | 256\n", "encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536\n", "encoder.encoders.4.self_attn.linear_v.bias | [256] | 256\n", "encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536\n", "encoder.encoders.4.self_attn.linear_out.bias | [256] | 256\n", "encoder.encoders.4.self_attn.linear_pos.weight | [256, 256] | 65536\n", "encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048\n", "encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.4.feed_forward.w_2.bias | [256] | 256\n", "encoder.encoders.4.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.4.feed_forward_macaron.w_1.bias | [2048] | 2048\n", "encoder.encoders.4.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.4.feed_forward_macaron.w_2.bias | [256] | 256\n", "encoder.encoders.4.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", "encoder.encoders.4.conv_module.pointwise_conv1.bias | [512] | 512\n", "encoder.encoders.4.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", "encoder.encoders.4.conv_module.depthwise_conv.bias | [256] | 256\n", "encoder.encoders.4.conv_module.norm.weight | [256] | 256\n", "encoder.encoders.4.conv_module.norm.bias | [256] | 256\n", "encoder.encoders.4.conv_module.norm._mean | [256] | 256\n", "encoder.encoders.4.conv_module.norm._variance | [256] | 256\n", "encoder.encoders.4.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", "encoder.encoders.4.conv_module.pointwise_conv2.bias | [256] | 256\n", 
"encoder.encoders.4.norm_ff.weight | [256] | 256\n", "encoder.encoders.4.norm_ff.bias | [256] | 256\n", "encoder.encoders.4.norm_mha.weight | [256] | 256\n", "encoder.encoders.4.norm_mha.bias | [256] | 256\n", "encoder.encoders.4.norm_ff_macaron.weight | [256] | 256\n", "encoder.encoders.4.norm_ff_macaron.bias | [256] | 256\n", "encoder.encoders.4.norm_conv.weight | [256] | 256\n", "encoder.encoders.4.norm_conv.bias | [256] | 256\n", "encoder.encoders.4.norm_final.weight | [256] | 256\n", "encoder.encoders.4.norm_final.bias | [256] | 256\n", "encoder.encoders.4.concat_linear.weight | [512, 256] | 131072\n", "encoder.encoders.4.concat_linear.bias | [256] | 256\n", "encoder.encoders.5.self_attn.pos_bias_u | [4, 64] | 256\n", "encoder.encoders.5.self_attn.pos_bias_v | [4, 64] | 256\n", "encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536\n", "encoder.encoders.5.self_attn.linear_q.bias | [256] | 256\n", "encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536\n", "encoder.encoders.5.self_attn.linear_k.bias | [256] | 256\n", "encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536\n", "encoder.encoders.5.self_attn.linear_v.bias | [256] | 256\n", "encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536\n", "encoder.encoders.5.self_attn.linear_out.bias | [256] | 256\n", "encoder.encoders.5.self_attn.linear_pos.weight | [256, 256] | 65536\n", "encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048\n", "encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.5.feed_forward.w_2.bias | [256] | 256\n", "encoder.encoders.5.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.5.feed_forward_macaron.w_1.bias | [2048] | 2048\n", "encoder.encoders.5.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.5.feed_forward_macaron.w_2.bias | [256] | 256\n", "encoder.encoders.5.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", "encoder.encoders.5.conv_module.pointwise_conv1.bias | [512] | 512\n", "encoder.encoders.5.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", "encoder.encoders.5.conv_module.depthwise_conv.bias | [256] | 256\n", "encoder.encoders.5.conv_module.norm.weight | [256] | 256\n", "encoder.encoders.5.conv_module.norm.bias | [256] | 256\n", "encoder.encoders.5.conv_module.norm._mean | [256] | 256\n", "encoder.encoders.5.conv_module.norm._variance | [256] | 256\n", "encoder.encoders.5.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", "encoder.encoders.5.conv_module.pointwise_conv2.bias | [256] | 256\n", "encoder.encoders.5.norm_ff.weight | [256] | 256\n", "encoder.encoders.5.norm_ff.bias | [256] | 256\n", "encoder.encoders.5.norm_mha.weight | [256] | 256\n", "encoder.encoders.5.norm_mha.bias | [256] | 256\n", "encoder.encoders.5.norm_ff_macaron.weight | [256] | 256\n", "encoder.encoders.5.norm_ff_macaron.bias | [256] | 256\n", "encoder.encoders.5.norm_conv.weight | [256] | 256\n", "encoder.encoders.5.norm_conv.bias | [256] | 256\n", "encoder.encoders.5.norm_final.weight | [256] | 256\n", "encoder.encoders.5.norm_final.bias | [256] | 256\n", "encoder.encoders.5.concat_linear.weight | [512, 256] | 131072\n", "encoder.encoders.5.concat_linear.bias | [256] | 256\n", "encoder.encoders.6.self_attn.pos_bias_u | [4, 64] | 256\n", "encoder.encoders.6.self_attn.pos_bias_v | [4, 64] | 256\n", "encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536\n", 
"encoder.encoders.6.self_attn.linear_q.bias | [256] | 256\n", "encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536\n", "encoder.encoders.6.self_attn.linear_k.bias | [256] | 256\n", "encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536\n", "encoder.encoders.6.self_attn.linear_v.bias | [256] | 256\n", "encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536\n", "encoder.encoders.6.self_attn.linear_out.bias | [256] | 256\n", "encoder.encoders.6.self_attn.linear_pos.weight | [256, 256] | 65536\n", "encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048\n", "encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.6.feed_forward.w_2.bias | [256] | 256\n", "encoder.encoders.6.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.6.feed_forward_macaron.w_1.bias | [2048] | 2048\n", "encoder.encoders.6.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.6.feed_forward_macaron.w_2.bias | [256] | 256\n", "encoder.encoders.6.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", "encoder.encoders.6.conv_module.pointwise_conv1.bias | [512] | 512\n", "encoder.encoders.6.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", "encoder.encoders.6.conv_module.depthwise_conv.bias | [256] | 256\n", "encoder.encoders.6.conv_module.norm.weight | [256] | 256\n", "encoder.encoders.6.conv_module.norm.bias | [256] | 256\n", "encoder.encoders.6.conv_module.norm._mean | [256] | 256\n", "encoder.encoders.6.conv_module.norm._variance | [256] | 256\n", "encoder.encoders.6.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", "encoder.encoders.6.conv_module.pointwise_conv2.bias | [256] | 256\n", "encoder.encoders.6.norm_ff.weight | [256] | 256\n", "encoder.encoders.6.norm_ff.bias | [256] | 256\n", "encoder.encoders.6.norm_mha.weight | [256] | 256\n", "encoder.encoders.6.norm_mha.bias | [256] | 256\n", "encoder.encoders.6.norm_ff_macaron.weight | [256] | 256\n", "encoder.encoders.6.norm_ff_macaron.bias | [256] | 256\n", "encoder.encoders.6.norm_conv.weight | [256] | 256\n", "encoder.encoders.6.norm_conv.bias | [256] | 256\n", "encoder.encoders.6.norm_final.weight | [256] | 256\n", "encoder.encoders.6.norm_final.bias | [256] | 256\n", "encoder.encoders.6.concat_linear.weight | [512, 256] | 131072\n", "encoder.encoders.6.concat_linear.bias | [256] | 256\n", "encoder.encoders.7.self_attn.pos_bias_u | [4, 64] | 256\n", "encoder.encoders.7.self_attn.pos_bias_v | [4, 64] | 256\n", "encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536\n", "encoder.encoders.7.self_attn.linear_q.bias | [256] | 256\n", "encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536\n", "encoder.encoders.7.self_attn.linear_k.bias | [256] | 256\n", "encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536\n", "encoder.encoders.7.self_attn.linear_v.bias | [256] | 256\n", "encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536\n", "encoder.encoders.7.self_attn.linear_out.bias | [256] | 256\n", "encoder.encoders.7.self_attn.linear_pos.weight | [256, 256] | 65536\n", "encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048\n", "encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.7.feed_forward.w_2.bias | [256] | 256\n", "encoder.encoders.7.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", 
"encoder.encoders.7.feed_forward_macaron.w_1.bias | [2048] | 2048\n", "encoder.encoders.7.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.7.feed_forward_macaron.w_2.bias | [256] | 256\n", "encoder.encoders.7.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", "encoder.encoders.7.conv_module.pointwise_conv1.bias | [512] | 512\n", "encoder.encoders.7.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", "encoder.encoders.7.conv_module.depthwise_conv.bias | [256] | 256\n", "encoder.encoders.7.conv_module.norm.weight | [256] | 256\n", "encoder.encoders.7.conv_module.norm.bias | [256] | 256\n", "encoder.encoders.7.conv_module.norm._mean | [256] | 256\n", "encoder.encoders.7.conv_module.norm._variance | [256] | 256\n", "encoder.encoders.7.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", "encoder.encoders.7.conv_module.pointwise_conv2.bias | [256] | 256\n", "encoder.encoders.7.norm_ff.weight | [256] | 256\n", "encoder.encoders.7.norm_ff.bias | [256] | 256\n", "encoder.encoders.7.norm_mha.weight | [256] | 256\n", "encoder.encoders.7.norm_mha.bias | [256] | 256\n", "encoder.encoders.7.norm_ff_macaron.weight | [256] | 256\n", "encoder.encoders.7.norm_ff_macaron.bias | [256] | 256\n", "encoder.encoders.7.norm_conv.weight | [256] | 256\n", "encoder.encoders.7.norm_conv.bias | [256] | 256\n", "encoder.encoders.7.norm_final.weight | [256] | 256\n", "encoder.encoders.7.norm_final.bias | [256] | 256\n", "encoder.encoders.7.concat_linear.weight | [512, 256] | 131072\n", "encoder.encoders.7.concat_linear.bias | [256] | 256\n", "encoder.encoders.8.self_attn.pos_bias_u | [4, 64] | 256\n", "encoder.encoders.8.self_attn.pos_bias_v | [4, 64] | 256\n", "encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536\n", "encoder.encoders.8.self_attn.linear_q.bias | [256] | 256\n", "encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536\n", "encoder.encoders.8.self_attn.linear_k.bias | [256] | 256\n", "encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536\n", "encoder.encoders.8.self_attn.linear_v.bias | [256] | 256\n", "encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536\n", "encoder.encoders.8.self_attn.linear_out.bias | [256] | 256\n", "encoder.encoders.8.self_attn.linear_pos.weight | [256, 256] | 65536\n", "encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048\n", "encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.8.feed_forward.w_2.bias | [256] | 256\n", "encoder.encoders.8.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.8.feed_forward_macaron.w_1.bias | [2048] | 2048\n", "encoder.encoders.8.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.8.feed_forward_macaron.w_2.bias | [256] | 256\n", "encoder.encoders.8.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", "encoder.encoders.8.conv_module.pointwise_conv1.bias | [512] | 512\n", "encoder.encoders.8.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", "encoder.encoders.8.conv_module.depthwise_conv.bias | [256] | 256\n", "encoder.encoders.8.conv_module.norm.weight | [256] | 256\n", "encoder.encoders.8.conv_module.norm.bias | [256] | 256\n", "encoder.encoders.8.conv_module.norm._mean | [256] | 256\n", "encoder.encoders.8.conv_module.norm._variance | [256] | 256\n", "encoder.encoders.8.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", 
"encoder.encoders.8.conv_module.pointwise_conv2.bias | [256] | 256\n", "encoder.encoders.8.norm_ff.weight | [256] | 256\n", "encoder.encoders.8.norm_ff.bias | [256] | 256\n", "encoder.encoders.8.norm_mha.weight | [256] | 256\n", "encoder.encoders.8.norm_mha.bias | [256] | 256\n", "encoder.encoders.8.norm_ff_macaron.weight | [256] | 256\n", "encoder.encoders.8.norm_ff_macaron.bias | [256] | 256\n", "encoder.encoders.8.norm_conv.weight | [256] | 256\n", "encoder.encoders.8.norm_conv.bias | [256] | 256\n", "encoder.encoders.8.norm_final.weight | [256] | 256\n", "encoder.encoders.8.norm_final.bias | [256] | 256\n", "encoder.encoders.8.concat_linear.weight | [512, 256] | 131072\n", "encoder.encoders.8.concat_linear.bias | [256] | 256\n", "encoder.encoders.9.self_attn.pos_bias_u | [4, 64] | 256\n", "encoder.encoders.9.self_attn.pos_bias_v | [4, 64] | 256\n", "encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536\n", "encoder.encoders.9.self_attn.linear_q.bias | [256] | 256\n", "encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536\n", "encoder.encoders.9.self_attn.linear_k.bias | [256] | 256\n", "encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536\n", "encoder.encoders.9.self_attn.linear_v.bias | [256] | 256\n", "encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536\n", "encoder.encoders.9.self_attn.linear_out.bias | [256] | 256\n", "encoder.encoders.9.self_attn.linear_pos.weight | [256, 256] | 65536\n", "encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048\n", "encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.9.feed_forward.w_2.bias | [256] | 256\n", "encoder.encoders.9.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.9.feed_forward_macaron.w_1.bias | [2048] | 2048\n", "encoder.encoders.9.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.9.feed_forward_macaron.w_2.bias | [256] | 256\n", "encoder.encoders.9.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", "encoder.encoders.9.conv_module.pointwise_conv1.bias | [512] | 512\n", "encoder.encoders.9.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", "encoder.encoders.9.conv_module.depthwise_conv.bias | [256] | 256\n", "encoder.encoders.9.conv_module.norm.weight | [256] | 256\n", "encoder.encoders.9.conv_module.norm.bias | [256] | 256\n", "encoder.encoders.9.conv_module.norm._mean | [256] | 256\n", "encoder.encoders.9.conv_module.norm._variance | [256] | 256\n", "encoder.encoders.9.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", "encoder.encoders.9.conv_module.pointwise_conv2.bias | [256] | 256\n", "encoder.encoders.9.norm_ff.weight | [256] | 256\n", "encoder.encoders.9.norm_ff.bias | [256] | 256\n", "encoder.encoders.9.norm_mha.weight | [256] | 256\n", "encoder.encoders.9.norm_mha.bias | [256] | 256\n", "encoder.encoders.9.norm_ff_macaron.weight | [256] | 256\n", "encoder.encoders.9.norm_ff_macaron.bias | [256] | 256\n", "encoder.encoders.9.norm_conv.weight | [256] | 256\n", "encoder.encoders.9.norm_conv.bias | [256] | 256\n", "encoder.encoders.9.norm_final.weight | [256] | 256\n", "encoder.encoders.9.norm_final.bias | [256] | 256\n", "encoder.encoders.9.concat_linear.weight | [512, 256] | 131072\n", "encoder.encoders.9.concat_linear.bias | [256] | 256\n", "encoder.encoders.10.self_attn.pos_bias_u | [4, 64] | 256\n", "encoder.encoders.10.self_attn.pos_bias_v | [4, 64] | 256\n", 
"encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536\n", "encoder.encoders.10.self_attn.linear_q.bias | [256] | 256\n", "encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536\n", "encoder.encoders.10.self_attn.linear_k.bias | [256] | 256\n", "encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536\n", "encoder.encoders.10.self_attn.linear_v.bias | [256] | 256\n", "encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536\n", "encoder.encoders.10.self_attn.linear_out.bias | [256] | 256\n", "encoder.encoders.10.self_attn.linear_pos.weight | [256, 256] | 65536\n", "encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048\n", "encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.10.feed_forward.w_2.bias | [256] | 256\n", "encoder.encoders.10.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.10.feed_forward_macaron.w_1.bias | [2048] | 2048\n", "encoder.encoders.10.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.10.feed_forward_macaron.w_2.bias | [256] | 256\n", "encoder.encoders.10.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", "encoder.encoders.10.conv_module.pointwise_conv1.bias | [512] | 512\n", "encoder.encoders.10.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", "encoder.encoders.10.conv_module.depthwise_conv.bias | [256] | 256\n", "encoder.encoders.10.conv_module.norm.weight | [256] | 256\n", "encoder.encoders.10.conv_module.norm.bias | [256] | 256\n", "encoder.encoders.10.conv_module.norm._mean | [256] | 256\n", "encoder.encoders.10.conv_module.norm._variance | [256] | 256\n", "encoder.encoders.10.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", "encoder.encoders.10.conv_module.pointwise_conv2.bias | [256] | 256\n", "encoder.encoders.10.norm_ff.weight | [256] | 256\n", "encoder.encoders.10.norm_ff.bias | [256] | 256\n", "encoder.encoders.10.norm_mha.weight | [256] | 256\n", "encoder.encoders.10.norm_mha.bias | [256] | 256\n", "encoder.encoders.10.norm_ff_macaron.weight | [256] | 256\n", "encoder.encoders.10.norm_ff_macaron.bias | [256] | 256\n", "encoder.encoders.10.norm_conv.weight | [256] | 256\n", "encoder.encoders.10.norm_conv.bias | [256] | 256\n", "encoder.encoders.10.norm_final.weight | [256] | 256\n", "encoder.encoders.10.norm_final.bias | [256] | 256\n", "encoder.encoders.10.concat_linear.weight | [512, 256] | 131072\n", "encoder.encoders.10.concat_linear.bias | [256] | 256\n", "encoder.encoders.11.self_attn.pos_bias_u | [4, 64] | 256\n", "encoder.encoders.11.self_attn.pos_bias_v | [4, 64] | 256\n", "encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536\n", "encoder.encoders.11.self_attn.linear_q.bias | [256] | 256\n", "encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536\n", "encoder.encoders.11.self_attn.linear_k.bias | [256] | 256\n", "encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536\n", "encoder.encoders.11.self_attn.linear_v.bias | [256] | 256\n", "encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536\n", "encoder.encoders.11.self_attn.linear_out.bias | [256] | 256\n", "encoder.encoders.11.self_attn.linear_pos.weight | [256, 256] | 65536\n", "encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048\n", "encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288\n", 
"encoder.encoders.11.feed_forward.w_2.bias | [256] | 256\n", "encoder.encoders.11.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", "encoder.encoders.11.feed_forward_macaron.w_1.bias | [2048] | 2048\n", "encoder.encoders.11.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", "encoder.encoders.11.feed_forward_macaron.w_2.bias | [256] | 256\n", "encoder.encoders.11.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", "encoder.encoders.11.conv_module.pointwise_conv1.bias | [512] | 512\n", "encoder.encoders.11.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", "encoder.encoders.11.conv_module.depthwise_conv.bias | [256] | 256\n", "encoder.encoders.11.conv_module.norm.weight | [256] | 256\n", "encoder.encoders.11.conv_module.norm.bias | [256] | 256\n", "encoder.encoders.11.conv_module.norm._mean | [256] | 256\n", "encoder.encoders.11.conv_module.norm._variance | [256] | 256\n", "encoder.encoders.11.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", "encoder.encoders.11.conv_module.pointwise_conv2.bias | [256] | 256\n", "encoder.encoders.11.norm_ff.weight | [256] | 256\n", "encoder.encoders.11.norm_ff.bias | [256] | 256\n", "encoder.encoders.11.norm_mha.weight | [256] | 256\n", "encoder.encoders.11.norm_mha.bias | [256] | 256\n", "encoder.encoders.11.norm_ff_macaron.weight | [256] | 256\n", "encoder.encoders.11.norm_ff_macaron.bias | [256] | 256\n", "encoder.encoders.11.norm_conv.weight | [256] | 256\n", "encoder.encoders.11.norm_conv.bias | [256] | 256\n", "encoder.encoders.11.norm_final.weight | [256] | 256\n", "encoder.encoders.11.norm_final.bias | [256] | 256\n", "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072\n", "encoder.encoders.11.concat_linear.bias | [256] | 256\n", "decoder.embed.0.weight | [4233, 256] | 1083648\n", "decoder.after_norm.weight | [256] | 256\n", "decoder.after_norm.bias | [256] | 256\n", "decoder.output_layer.weight | [256, 4233] | 1083648\n", "decoder.output_layer.bias | [4233] | 4233\n", "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536\n", "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256\n", "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536\n", "decoder.decoders.0.self_attn.linear_k.bias | [256] | 256\n", "decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536\n", "decoder.decoders.0.self_attn.linear_v.bias | [256] | 256\n", "decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536\n", "decoder.decoders.0.self_attn.linear_out.bias | [256] | 256\n", "decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536\n", "decoder.decoders.0.src_attn.linear_q.bias | [256] | 256\n", "decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536\n", "decoder.decoders.0.src_attn.linear_k.bias | [256] | 256\n", "decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536\n", "decoder.decoders.0.src_attn.linear_v.bias | [256] | 256\n", "decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536\n", "decoder.decoders.0.src_attn.linear_out.bias | [256] | 256\n", "decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288\n", "decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048\n", "decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288\n", "decoder.decoders.0.feed_forward.w_2.bias | [256] | 256\n", "decoder.decoders.0.norm1.weight | [256] | 256\n", "decoder.decoders.0.norm1.bias | [256] | 256\n", "decoder.decoders.0.norm2.weight | [256] | 256\n", "decoder.decoders.0.norm2.bias | [256] | 256\n", 
"decoder.decoders.0.norm3.weight | [256] | 256\n", "decoder.decoders.0.norm3.bias | [256] | 256\n", "decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072\n", "decoder.decoders.0.concat_linear1.bias | [256] | 256\n", "decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072\n", "decoder.decoders.0.concat_linear2.bias | [256] | 256\n", "decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536\n", "decoder.decoders.1.self_attn.linear_q.bias | [256] | 256\n", "decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536\n", "decoder.decoders.1.self_attn.linear_k.bias | [256] | 256\n", "decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536\n", "decoder.decoders.1.self_attn.linear_v.bias | [256] | 256\n", "decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536\n", "decoder.decoders.1.self_attn.linear_out.bias | [256] | 256\n", "decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536\n", "decoder.decoders.1.src_attn.linear_q.bias | [256] | 256\n", "decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536\n", "decoder.decoders.1.src_attn.linear_k.bias | [256] | 256\n", "decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536\n", "decoder.decoders.1.src_attn.linear_v.bias | [256] | 256\n", "decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536\n", "decoder.decoders.1.src_attn.linear_out.bias | [256] | 256\n", "decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288\n", "decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048\n", "decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288\n", "decoder.decoders.1.feed_forward.w_2.bias | [256] | 256\n", "decoder.decoders.1.norm1.weight | [256] | 256\n", "decoder.decoders.1.norm1.bias | [256] | 256\n", "decoder.decoders.1.norm2.weight | [256] | 256\n", "decoder.decoders.1.norm2.bias | [256] | 256\n", "decoder.decoders.1.norm3.weight | [256] | 256\n", "decoder.decoders.1.norm3.bias | [256] | 256\n", "decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072\n", "decoder.decoders.1.concat_linear1.bias | [256] | 256\n", "decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072\n", "decoder.decoders.1.concat_linear2.bias | [256] | 256\n", "decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536\n", "decoder.decoders.2.self_attn.linear_q.bias | [256] | 256\n", "decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536\n", "decoder.decoders.2.self_attn.linear_k.bias | [256] | 256\n", "decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536\n", "decoder.decoders.2.self_attn.linear_v.bias | [256] | 256\n", "decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536\n", "decoder.decoders.2.self_attn.linear_out.bias | [256] | 256\n", "decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536\n", "decoder.decoders.2.src_attn.linear_q.bias | [256] | 256\n", "decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536\n", "decoder.decoders.2.src_attn.linear_k.bias | [256] | 256\n", "decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536\n", "decoder.decoders.2.src_attn.linear_v.bias | [256] | 256\n", "decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536\n", "decoder.decoders.2.src_attn.linear_out.bias | [256] | 256\n", "decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288\n", "decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048\n", "decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288\n", "decoder.decoders.2.feed_forward.w_2.bias | [256] | 
256\n", "decoder.decoders.2.norm1.weight | [256] | 256\n", "decoder.decoders.2.norm1.bias | [256] | 256\n", "decoder.decoders.2.norm2.weight | [256] | 256\n", "decoder.decoders.2.norm2.bias | [256] | 256\n", "decoder.decoders.2.norm3.weight | [256] | 256\n", "decoder.decoders.2.norm3.bias | [256] | 256\n", "decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072\n", "decoder.decoders.2.concat_linear1.bias | [256] | 256\n", "decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072\n", "decoder.decoders.2.concat_linear2.bias | [256] | 256\n", "decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536\n", "decoder.decoders.3.self_attn.linear_q.bias | [256] | 256\n", "decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536\n", "decoder.decoders.3.self_attn.linear_k.bias | [256] | 256\n", "decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536\n", "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256\n", "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536\n", "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256\n", "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536\n", "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256\n", "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536\n", "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256\n", "decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536\n", "decoder.decoders.3.src_attn.linear_v.bias | [256] | 256\n", "decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536\n", "decoder.decoders.3.src_attn.linear_out.bias | [256] | 256\n", "decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288\n", "decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048\n", "decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288\n", "decoder.decoders.3.feed_forward.w_2.bias | [256] | 256\n", "decoder.decoders.3.norm1.weight | [256] | 256\n", "decoder.decoders.3.norm1.bias | [256] | 256\n", "decoder.decoders.3.norm2.weight | [256] | 256\n", "decoder.decoders.3.norm2.bias | [256] | 256\n", "decoder.decoders.3.norm3.weight | [256] | 256\n", "decoder.decoders.3.norm3.bias | [256] | 256\n", "decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072\n", "decoder.decoders.3.concat_linear1.bias | [256] | 256\n", "decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072\n", "decoder.decoders.3.concat_linear2.bias | [256] | 256\n", "decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536\n", "decoder.decoders.4.self_attn.linear_q.bias | [256] | 256\n", "decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536\n", "decoder.decoders.4.self_attn.linear_k.bias | [256] | 256\n", "decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536\n", "decoder.decoders.4.self_attn.linear_v.bias | [256] | 256\n", "decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536\n", "decoder.decoders.4.self_attn.linear_out.bias | [256] | 256\n", "decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536\n", "decoder.decoders.4.src_attn.linear_q.bias | [256] | 256\n", "decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536\n", "decoder.decoders.4.src_attn.linear_k.bias | [256] | 256\n", "decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536\n", "decoder.decoders.4.src_attn.linear_v.bias | [256] | 256\n", "decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536\n", "decoder.decoders.4.src_attn.linear_out.bias | [256] | 256\n", "decoder.decoders.4.feed_forward.w_1.weight | 
[256, 2048] | 524288\n", "decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048\n", "decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288\n", "decoder.decoders.4.feed_forward.w_2.bias | [256] | 256\n", "decoder.decoders.4.norm1.weight | [256] | 256\n", "decoder.decoders.4.norm1.bias | [256] | 256\n", "decoder.decoders.4.norm2.weight | [256] | 256\n", "decoder.decoders.4.norm2.bias | [256] | 256\n", "decoder.decoders.4.norm3.weight | [256] | 256\n", "decoder.decoders.4.norm3.bias | [256] | 256\n", "decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072\n", "decoder.decoders.4.concat_linear1.bias | [256] | 256\n", "decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072\n", "decoder.decoders.4.concat_linear2.bias | [256] | 256\n", "decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536\n", "decoder.decoders.5.self_attn.linear_q.bias | [256] | 256\n", "decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536\n", "decoder.decoders.5.self_attn.linear_k.bias | [256] | 256\n", "decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536\n", "decoder.decoders.5.self_attn.linear_v.bias | [256] | 256\n", "decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536\n", "decoder.decoders.5.self_attn.linear_out.bias | [256] | 256\n", "decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536\n", "decoder.decoders.5.src_attn.linear_q.bias | [256] | 256\n", "decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536\n", "decoder.decoders.5.src_attn.linear_k.bias | [256] | 256\n", "decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536\n", "decoder.decoders.5.src_attn.linear_v.bias | [256] | 256\n", "decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536\n", "decoder.decoders.5.src_attn.linear_out.bias | [256] | 256\n", "decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288\n", "decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048\n", "decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288\n", "decoder.decoders.5.feed_forward.w_2.bias | [256] | 256\n", "decoder.decoders.5.norm1.weight | [256] | 256\n", "decoder.decoders.5.norm1.bias | [256] | 256\n", "decoder.decoders.5.norm2.weight | [256] | 256\n", "decoder.decoders.5.norm2.bias | [256] | 256\n", "decoder.decoders.5.norm3.weight | [256] | 256\n", "decoder.decoders.5.norm3.bias | [256] | 256\n", "decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072\n", "decoder.decoders.5.concat_linear1.bias | [256] | 256\n", "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072\n", "decoder.decoders.5.concat_linear2.bias | [256] | 256\n", "ctc.ctc_lo.weight | [256, 4233] | 1083648\n", "ctc.ctc_lo.bias | [4233] | 4233\n", "Total parameters: 689, 49355442 elements.\n" ] } ], "source": [ "summary(model)" ] },
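{ "cell_type": "markdown", "id": "param-total-note", "metadata": {}, "source": [ "Quick sanity check on the `summary(model)` table above -- a sketch added for illustration, not part of the original run. It recomputes the tensor count and total element count from `model.state_dict()`, assuming the `model` built earlier is still in scope. Like the table, the state dict also carries the non-trainable `_mean`/`_variance` buffers, so both numbers should line up with the reported `Total parameters: 689, 49355442 elements.` (exact agreement assumes `state_dict()` enumerates the same tensors as `summary`)." ] }, { "cell_type": "code", "execution_count": null, "id": "param-total-check", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Sketch: recompute what summary(model) printed above. Assumes `model`\n", "# from the earlier cell; state_dict() also includes non-trainable\n", "# buffers (e.g. BatchNorm _mean/_variance), just like the summary table.\n", "sd = model.state_dict()\n", "n_elements = sum(int(np.prod(t.shape)) for t in sd.values())\n", "print(len(sd), n_elements)  # expected: 689 49355442" ] }, { "cell_type": "code", "execution_count": 5, "id": "ruled-invitation", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "U2Model(\n", " (encoder): ConformerEncoder(\n", " (global_cmvn): GlobalCMVN()\n", " (embed): Conv2dSubsampling4(\n", " (pos_enc): RelPositionalEncoding(\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " )\n", " (conv): Sequential(\n", " (0): Conv2D(1, 256, kernel_size=[3, 3], stride=[2, 2], data_format=NCHW)\n", " (1): ReLU()\n", " (2): Conv2D(256, 256, kernel_size=[3, 3], stride=[2, 2], data_format=NCHW)\n", " (3): ReLU()\n", " )\n", " (out): Sequential(\n", " (0): Linear(in_features=4864, out_features=256, dtype=float32)\n", " )\n", " )\n", " 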
(after_norm): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (encoders): LayerList(\n", " (0): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (1): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", " 
(pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (2): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (3): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, 
out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (4): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (5): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, 
axis=None, mode=upscale_in_train)\n", " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (6): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " 
(concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (7): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (8): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", " (pointwise_conv2): 
Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (9): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (10): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, 
dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (11): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " )\n", " )\n", " (decoder): TransformerDecoder(\n", " (embed): Sequential(\n", " (0): Embedding(4233, 256, sparse=False)\n", " (1): PositionalEncoding(\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " )\n", " )\n", " (after_norm): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (output_layer): Linear(in_features=256, out_features=4233, dtype=float32)\n", " (decoders): LayerList(\n", " (0): 
DecoderLayer(\n", " (self_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " )\n", " (src_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): ReLU()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (1): DecoderLayer(\n", " (self_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " )\n", " (src_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): ReLU()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (2): DecoderLayer(\n", " (self_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, 
mode=upscale_in_train)\n", " )\n", " (src_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): ReLU()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (3): DecoderLayer(\n", " (self_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " )\n", " (src_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): ReLU()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (4): DecoderLayer(\n", " (self_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " )\n", " (src_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): 
Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): ReLU()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " (5): DecoderLayer(\n", " (self_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " )\n", " (src_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", " (activation): ReLU()\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", " )\n", " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", " )\n", " )\n", " )\n", " (ctc): CTCDecoder(\n", " (ctc_lo): Linear(in_features=256, out_features=4233, dtype=float32)\n", " (criterion): CTCLoss(\n", " (loss): CTCLoss()\n", " )\n", " )\n", " (criterion_att): LabelSmoothingLoss(\n", " (criterion): KLDivLoss()\n", " )\n", ")\n" ] } ], "source": [ "print(model)" ] }, { "cell_type": "code", "execution_count": 6, "id": "fossil-means", "metadata": {}, "outputs": [], "source": [ "# load feat" ] }, { "cell_type": "code", "execution_count": 7, "id": "fleet-despite", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "compute_cmvn_loader_test.ipynb encoder.npz\r\n", "dataloader.ipynb hack_api_test.ipynb\r\n", "dataloader_with_tokens_tokenids.ipynb jit_infer.ipynb\r\n", "data.npz layer_norm_test.ipynb\r\n", "decoder.npz Linear_test.ipynb\r\n", "enc_0_ff_out.npz mask_and_masked_fill_test.ipynb\r\n", "enc_0_norm_ff.npz model.npz\r\n", "enc_0.npz position_embeding_check.ipynb\r\n", "enc_0_selattn_out.npz python_test.ipynb\r\n", "enc_2.npz train_test.ipynb\r\n", "enc_all.npz u2_model.ipynb\r\n", "enc_embed.npz\r\n" ] } ], "source": [ "%ls .notebook" ] }, { "cell_type": "code", "execution_count": 8, "id": 
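"npz-layout-sketch", "metadata": {}, "outputs": [], "source": [ "# A hedged sketch, not part of the original run: the next cell expects\n", "# .notebook/data.npz to hold one padded batch under the keys\n", "# keys/feat/feat_len/text/text_len. A file with that layout could be\n", "# written like this; names, shapes and values here are illustrative only.\n", "import numpy as np\n", "\n", "demo_keys = np.array(['utt1', 'utt2'])\n", "demo_feat = np.zeros((2, 207, 80), dtype=np.float32)  # padded fbank frames\n", "demo_feat_len = np.array([207, 180])                  # valid frames per utterance\n", "demo_text = np.full((2, 6), -1, dtype=np.int32)       # token ids, -1 = padding\n", "demo_text_len = np.array([4, 6])                      # valid tokens per utterance\n", "np.savez('/tmp/data_demo.npz', keys=demo_keys, feat=demo_feat,\n", "         feat_len=demo_feat_len, text=demo_text, text_len=demo_text_len)" ] }, { "cell_type": "code", "execution_count": 8, "id": 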
"abroad-oracle", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['BAC009S0739W0246' 'BAC009S0727W0424' 'BAC009S0753W0412'\n", " 'BAC009S0756W0206' 'BAC009S0740W0414' 'BAC009S0728W0426'\n", " 'BAC009S0739W0214' 'BAC009S0753W0423' 'BAC009S0734W0201'\n", " 'BAC009S0740W0427' 'BAC009S0730W0423' 'BAC009S0728W0367'\n", " 'BAC009S0730W0418' 'BAC009S0727W0157' 'BAC009S0749W0409'\n", " 'BAC009S0727W0418']\n", "(16, 207, 80)\n", "[[[ 8.994624 9.538309 9.191589 ... 10.507416 9.563305 8.256403 ]\n", " [ 9.798841 10.405224 9.26511 ... 10.251211 9.543982 8.873768 ]\n", " [10.6890745 10.395469 8.053548 ... 9.906749 10.064903 8.050915 ]\n", " ...\n", " [ 9.217986 9.65069 8.505259 ... 9.687183 8.742463 7.9865475]\n", " [10.129122 9.935194 9.37982 ... 9.563894 9.825992 8.979543 ]\n", " [ 9.095531 7.1338377 9.468001 ... 9.472748 9.021235 7.447914 ]]\n", "\n", " [[11.430976 10.671858 6.0841026 ... 9.382682 8.729745 7.5315614]\n", " [ 9.731717 7.8104815 7.5714607 ... 10.043035 9.243595 7.3540792]\n", " [10.65017 10.600604 8.467784 ... 9.281448 9.186885 8.070343 ]\n", " ...\n", " [ 9.096987 9.2637 8.075275 ... 8.431845 8.370505 8.002926 ]\n", " [10.461651 10.147784 6.7693496 ... 9.779426 9.577453 8.080652 ]\n", " [ 7.794432 5.621059 7.9750648 ... 9.997245 9.849678 8.031287 ]]\n", "\n", " [[ 7.3455667 7.896357 7.5795946 ... 11.631024 10.451254 9.123633 ]\n", " [ 8.628678 8.4630575 7.499242 ... 12.415986 10.975749 8.9425745]\n", " [ 9.831394 10.2812805 8.97241 ... 12.1386795 10.40175 9.005517 ]\n", " ...\n", " [ 7.089641 7.405548 6.8142557 ... 9.325196 9.273162 8.353427 ]\n", " [ 0. 0. 0. ... 0. 0. 0. ]\n", " [ 0. 0. 0. ... 0. 0. 0. ]]\n", "\n", " ...\n", "\n", " [[10.933237 10.464394 7.7202725 ... 10.348816 9.302338 7.1553144]\n", " [10.449866 9.907033 9.029272 ... 9.952465 9.414051 7.559279 ]\n", " [10.487655 9.81259 9.895244 ... 9.58662 9.341254 7.7849016]\n", " ...\n", " [ 0. 0. 0. ... 0. 0. 0. ]\n", " [ 0. 0. 0. ... 0. 0. 0. ]\n", " [ 0. 0. 0. ... 0. 0. 0. ]]\n", "\n", " [[ 9.944384 9.585867 8.220328 ... 11.588647 11.045029 8.817075 ]\n", " [ 7.678356 8.322397 7.533047 ... 11.055085 10.535685 9.27465 ]\n", " [ 8.626197 9.675917 9.841045 ... 11.378827 10.922112 8.991444 ]\n", " ...\n", " [ 0. 0. 0. ... 0. 0. 0. ]\n", " [ 0. 0. 0. ... 0. 0. 0. ]\n", " [ 0. 0. 0. ... 0. 0. 0. ]]\n", "\n", " [[ 8.107938 7.759043 6.710301 ... 12.650573 11.466156 11.061517 ]\n", " [11.380332 11.222007 8.658889 ... 12.810616 12.222216 11.689288 ]\n", " [10.677676 9.920579 8.046089 ... 13.572894 12.5624075 11.155033 ]\n", " ...\n", " [ 0. 0. 0. ... 0. 0. 0. ]\n", " [ 0. 0. 0. ... 0. 0. 0. ]\n", " [ 0. 0. 0. ... 0. 0. 0. 
]]]\n", "[207 207 205 205 203 203 198 197 195 188 186 186 185 180 166 163]\n", "[[2995 3116 1209 565 -1 -1]\n", " [ 236 1176 331 66 3925 4077]\n", " [2693 524 234 1145 366 -1]\n", " [3875 4211 3062 700 -1 -1]\n", " [ 272 987 1134 494 2959 -1]\n", " [1936 3715 120 2553 2695 2710]\n", " [ 25 1149 3930 -1 -1 -1]\n", " [1753 1778 1237 482 3925 110]\n", " [3703 2 565 3827 -1 -1]\n", " [1150 2734 10 2478 3490 -1]\n", " [ 426 811 95 489 144 -1]\n", " [2313 2006 489 975 -1 -1]\n", " [3702 3414 205 1488 2966 1347]\n", " [ 70 1741 702 1666 -1 -1]\n", " [ 703 1778 1030 849 -1 -1]\n", " [ 814 1674 115 3827 -1 -1]]\n", "[4 6 5 4 5 6 3 6 4 5 5 4 6 4 4 4]\n" ] } ], "source": [ "data = np.load('.notebook/data.npz', allow_pickle=True)\n", "keys=data['keys']\n", "feat=data['feat']\n", "feat_len=data['feat_len']\n", "text=data['text']\n", "text_len=data['text_len']\n", "print(keys)\n", "print(feat.shape)\n", "print(feat)\n", "print(feat_len)\n", "print(text)\n", "print(text_len)" ] }, { "cell_type": "code", "execution_count": null, "id": "false-instrument", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 9, "id": "arctic-proxy", "metadata": {}, "outputs": [], "source": [ "# ['BAC009S0739W0246', 'BAC009S0727W0424', 'BAC009S0753W0412', 'BAC009S0756W0206', 'BAC009S0740W0414', 'BAC009S0728W0426', 'BAC009S0739W0214', 'BAC009S0753W0423', 'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']\n", "# torch.Size([16, 207, 80])\n", "# tensor([[[ 8.9946, 9.5383, 9.1916, ..., 10.5074, 9.5633, 8.2564],\n", "# [ 9.7988, 10.4052, 9.2651, ..., 10.2512, 9.5440, 8.8738],\n", "# [10.6891, 10.3955, 8.0535, ..., 9.9067, 10.0649, 8.0509],\n", "# ...,\n", "# [ 9.2180, 9.6507, 8.5053, ..., 9.6872, 8.7425, 7.9865],\n", "# [10.1291, 9.9352, 9.3798, ..., 9.5639, 9.8260, 8.9795],\n", "# [ 9.0955, 7.1338, 9.4680, ..., 9.4727, 9.0212, 7.4479]],\n", "\n", "# [[11.4310, 10.6719, 6.0841, ..., 9.3827, 8.7297, 7.5316],\n", "# [ 9.7317, 7.8105, 7.5715, ..., 10.0430, 9.2436, 7.3541],\n", "# [10.6502, 10.6006, 8.4678, ..., 9.2814, 9.1869, 8.0703],\n", "# ...,\n", "# [ 9.0970, 9.2637, 8.0753, ..., 8.4318, 8.3705, 8.0029],\n", "# [10.4617, 10.1478, 6.7693, ..., 9.7794, 9.5775, 8.0807],\n", "# [ 7.7944, 5.6211, 7.9751, ..., 9.9972, 9.8497, 8.0313]],\n", "\n", "# [[ 7.3456, 7.8964, 7.5796, ..., 11.6310, 10.4513, 9.1236],\n", "# [ 8.6287, 8.4631, 7.4992, ..., 12.4160, 10.9757, 8.9426],\n", "# [ 9.8314, 10.2813, 8.9724, ..., 12.1387, 10.4017, 9.0055],\n", "# ...,\n", "# [ 7.0896, 7.4055, 6.8143, ..., 9.3252, 9.2732, 8.3534],\n", "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", "\n", "# ...,\n", "\n", "# [[10.9332, 10.4644, 7.7203, ..., 10.3488, 9.3023, 7.1553],\n", "# [10.4499, 9.9070, 9.0293, ..., 9.9525, 9.4141, 7.5593],\n", "# [10.4877, 9.8126, 9.8952, ..., 9.5866, 9.3413, 7.7849],\n", "# ...,\n", "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", "\n", "# [[ 9.9444, 9.5859, 8.2203, ..., 11.5886, 11.0450, 8.8171],\n", "# [ 7.6784, 8.3224, 7.5330, ..., 11.0551, 10.5357, 9.2746],\n", "# [ 8.6262, 9.6759, 9.8410, ..., 11.3788, 10.9221, 8.9914],\n", "# ...,\n", "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", "# [ 
0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", "\n", "# [[ 8.1079, 7.7590, 6.7103, ..., 12.6506, 11.4662, 11.0615],\n", "# [11.3803, 11.2220, 8.6589, ..., 12.8106, 12.2222, 11.6893],\n", "# [10.6777, 9.9206, 8.0461, ..., 13.5729, 12.5624, 11.1550],\n", "# ...,\n", "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]])\n", "# tensor([207, 207, 205, 205, 203, 203, 198, 197, 195, 188, 186, 186, 185, 180,\n", "# 166, 163], dtype=torch.int32)\n", "# tensor([[2995, 3116, 1209, 565, -1, -1],\n", "# [ 236, 1176, 331, 66, 3925, 4077],\n", "# [2693, 524, 234, 1145, 366, -1],\n", "# [3875, 4211, 3062, 700, -1, -1],\n", "# [ 272, 987, 1134, 494, 2959, -1],\n", "# [1936, 3715, 120, 2553, 2695, 2710],\n", "# [ 25, 1149, 3930, -1, -1, -1],\n", "# [1753, 1778, 1237, 482, 3925, 110],\n", "# [3703, 2, 565, 3827, -1, -1],\n", "# [1150, 2734, 10, 2478, 3490, -1],\n", "# [ 426, 811, 95, 489, 144, -1],\n", "# [2313, 2006, 489, 975, -1, -1],\n", "# [3702, 3414, 205, 1488, 2966, 1347],\n", "# [ 70, 1741, 702, 1666, -1, -1],\n", "# [ 703, 1778, 1030, 849, -1, -1],\n", "# [ 814, 1674, 115, 3827, -1, -1]], dtype=torch.int32)\n", "# tensor([4, 6, 5, 4, 5, 6, 3, 6, 4, 5, 5, 4, 6, 4, 4, 4], dtype=torch.int32)" ] }, { "cell_type": "code", "execution_count": null, "id": "seasonal-switch", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 10, "id": "defined-brooks", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "compute_cmvn_loader_test.ipynb\t encoder.npz\r\n", "dataloader.ipynb\t\t hack_api_test.ipynb\r\n", "dataloader_with_tokens_tokenids.ipynb jit_infer.ipynb\r\n", "data.npz\t\t\t layer_norm_test.ipynb\r\n", "decoder.npz\t\t\t Linear_test.ipynb\r\n", "enc_0_ff_out.npz\t\t mask_and_masked_fill_test.ipynb\r\n", "enc_0_norm_ff.npz\t\t model.npz\r\n", "enc_0.npz\t\t\t position_embeding_check.ipynb\r\n", "enc_0_selattn_out.npz\t\t python_test.ipynb\r\n", "enc_2.npz\t\t\t train_test.ipynb\r\n", "enc_all.npz\t\t\t u2_model.ipynb\r\n", "enc_embed.npz\r\n" ] } ], "source": [ "# load model param\n", "!ls .notebook\n", "data = np.load('.notebook/model.npz', allow_pickle=True)\n", "state_dict = data['state'].item()\n", "\n", "for key, _ in model.state_dict().items():\n", " if key not in state_dict:\n", " print(f\"{key} not found.\")\n", "\n", "model.set_state_dict(state_dict)\n", "\n", "now_state_dict = model.state_dict()\n", "for key, value in now_state_dict.items():\n", " if not np.allclose(value.numpy(), state_dict[key]):\n", " print(key)" ] }, { "cell_type": "code", "execution_count": null, "id": "exempt-viewer", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 11, "id": "confident-piano", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/framework.py:687: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. 
If you specifically wanted the numpy scalar type, use `np.bool_` here.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " elif dtype == np.bool:\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", " [142.48880005]) Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", " [41.84146118]) Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", " [377.33258057])\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dygraph/math_op_patch.py:238: UserWarning: The dtype of left and right variables are not the same, left dtype is VarType.FP32, but right dtype is VarType.INT32, the right dtype will convert to VarType.FP32\n", " format(lhs_dtype, rhs_dtype, lhs_dtype))\n" ] } ], "source": [ "# compute loss\n", "import paddle\n", "feat=paddle.to_tensor(feat)\n", "feat_len=paddle.to_tensor(feat_len, dtype='int64')\n", "text=paddle.to_tensor(text, dtype='int64')\n", "text_len=paddle.to_tensor(text_len, dtype='int64')\n", "\n", "model.eval()\n", "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n", " text, text_len)\n", "print(total_loss, attention_loss, ctc_loss)" ] }, { "cell_type": "code", "execution_count": 12, "id": "better-senator", "metadata": {}, "outputs": [], "source": [ "# tensor(142.4888, device='cuda:0', grad_fn=) \n", "# tensor(41.8415, device='cuda:0', grad_fn=) \n", "# tensor(377.3326, device='cuda:0', grad_fn=)\n", "# 142.4888 41.84146 377.33258" ] }, { "cell_type": "code", "execution_count": null, "id": "related-banking", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 13, "id": "olympic-problem", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[16, 51, 256]\n", "[16, 1, 51]\n", "Tensor(shape=[51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", " [[-0.70194179, 0.56254166, 0.68803459, ..., 1.12373221, 0.78039235, 1.13693869],\n", " [-0.77877808, 0.39126658, 0.71887815, ..., 1.25188220, 0.88616788, 1.31734526],\n", " [-0.95908946, 0.63460249, 0.87671334, ..., 0.98183727, 0.74401081, 1.29032660],\n", " ...,\n", " [-1.07322502, 0.67236906, 0.92303109, ..., 0.90754563, 0.81767166, 1.32396567],\n", " [-1.16541159, 0.68199694, 0.69394493, ..., 1.22383487, 0.80282891, 1.45065081],\n", " [-1.27320945, 0.71458030, 0.75819558, ..., 0.94154912, 0.87748396, 1.26230514]])\n" ] } ], "source": [ "# encoder\n", "encoder_out, encoder_mask = model.encoder(feat, feat_len)\n", "print(encoder_out.shape)\n", "print(encoder_mask.shape)\n", "print(encoder_out[0])" ] }, { "cell_type": "code", "execution_count": 14, "id": "shaped-alaska", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "deepspeech examples README_cn.md\tsetup.sh tools\r\n", "docs\t LICENSE README.md\t\ttests\t utils\r\n", "env.sh\t log requirements.txt\tthird_party\r\n" ] } ], "source": [ "!ls\n", "data = np.load('.notebook/encoder.npz', allow_pickle=True)\n", "torch_mask = data['mask']\n", "torch_encoder_out = data['out']" ] }, { "cell_type": "code", "execution_count": 15, "id": "federal-rover", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "None\n" ] } ], "source": [ "print(np.testing.assert_equal(torch_mask, encoder_mask.numpy()))" ] }, { "cell_type": "code", 
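"execution_count": null, "id": "diff-stats-sketch", "metadata": {}, "outputs": [], "source": [ "# A small helper sketch (not in the repo): the next cell probes\n", "# np.allclose at a few tolerances by hand; reporting the max absolute\n", "# and max relative difference directly shows which atol/rtol would pass.\n", "import numpy as np\n", "\n", "def diff_stats(a, b, eps=1e-12):\n", "    # max |a - b| and max |a - b| / |b|, the quantities atol/rtol bound\n", "    a, b = np.asarray(a), np.asarray(b)\n", "    abs_diff = np.abs(a - b)\n", "    return abs_diff.max(), (abs_diff / (np.abs(b) + eps)).max()\n", "\n", "# usage, e.g.: diff_stats(torch_encoder_out, encoder_out.numpy())" ] }, { "cell_type": "code", 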
"execution_count": 16, "id": "regulated-interstate", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "False\n", "[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n", " 1.1369387 ]\n", " [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n", " 1.3173454 ]\n", " [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n", " 1.2903274 ]\n", " ...\n", " [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n", " 1.3239657 ]\n", " [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n", " 1.4506509 ]\n", " [-1.2732087 0.71458083 0.7581961 ... 0.9415482 0.877484\n", " 1.2623053 ]]\n", "----\n", "[[-0.7019418 0.56254166 0.6880346 ... 1.1237322 0.78039235\n", " 1.1369387 ]\n", " [-0.7787781 0.39126658 0.71887815 ... 1.2518822 0.8861679\n", " 1.3173453 ]\n", " [-0.95908946 0.6346025 0.87671334 ... 0.9818373 0.7440108\n", " 1.2903266 ]\n", " ...\n", " [-1.073225 0.67236906 0.9230311 ... 0.9075456 0.81767166\n", " 1.3239657 ]\n", " [-1.1654116 0.68199694 0.69394493 ... 1.2238349 0.8028289\n", " 1.4506508 ]\n", " [-1.2732095 0.7145803 0.7581956 ... 0.9415491 0.87748396\n", " 1.2623051 ]]\n", "True\n", "False\n" ] } ], "source": [ "print(np.allclose(torch_encoder_out, encoder_out.numpy()))\n", "print(torch_encoder_out[0])\n", "print(\"----\")\n", "print(encoder_out.numpy()[0])\n", "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-5, rtol=1e-6))\n", "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-6, rtol=1e-6))" ] }, { "cell_type": "code", "execution_count": 17, "id": "proof-scheduling", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", " [377.33258057])\n", "[1.]\n", "[[ 3.16902876e+00 -1.51763987e-02 4.91095744e-02 ... -2.47971853e-03\n", " -5.93360700e-03 -7.26609165e-03]\n", " [-1.74184477e+00 7.75874173e-03 -4.49434854e-02 ... 9.92412097e-04\n", " 2.46337592e-03 2.31892057e-03]\n", " [-2.33343339e+00 1.30475955e-02 -2.66557075e-02 ... 2.27532350e-03\n", " 5.76924905e-03 7.48788286e-03]\n", " ...\n", " [-4.30358458e+00 2.46054661e-02 -9.00950655e-02 ... 4.43156436e-03\n", " 1.16122244e-02 1.44715561e-02]\n", " [-3.36921120e+00 1.73153952e-02 -6.36872873e-02 ... 3.28363618e-03\n", " 8.58010259e-03 1.07794888e-02]\n", " [-6.62045336e+00 3.49955931e-02 -1.23962618e-01 ... 6.36671018e-03\n", " 1.60814095e-02 2.03891303e-02]]\n", "[-4.3777819e+00 2.3245810e-02 -9.3339294e-02 ... 
4.2569344e-03\n", " 1.0919910e-02 1.3787797e-02]\n" ] } ], "source": [ "from paddle.nn import functional as F\n", "def ctc_loss(logits,\n", " labels,\n", " input_lengths,\n", " label_lengths,\n", " blank=0,\n", " reduction='mean',\n", " norm_by_times=False):\n", " loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,\n", " input_lengths, label_lengths)\n", " loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])\n", " assert reduction in ['mean', 'sum', 'none']\n", " if reduction == 'mean':\n", " loss_out = paddle.mean(loss_out / label_lengths)\n", " elif reduction == 'sum':\n", " loss_out = paddle.sum(loss_out)\n", " return loss_out\n", "\n", "F.ctc_loss = ctc_loss\n", "\n", "torch_mask_t = paddle.to_tensor(torch_mask, dtype='int64')\n", "encoder_out_lens = torch_mask_t.squeeze(1).sum(1)\n", "loss_ctc = model.ctc(paddle.to_tensor(torch_encoder_out), encoder_out_lens, text, text_len)\n", "print(loss_ctc)\n", "loss_ctc.backward()\n", "print(loss_ctc.grad)\n", "print(model.ctc.ctc_lo.weight.grad)\n", "print(model.ctc.ctc_lo.bias.grad)\n", "\n", "\n", "# tensor(377.3326, device='cuda:0', grad_fn=)\n", "# None\n", "# [[ 3.16902351e+00 -1.51765049e-02 4.91097234e-02 ... -2.47973716e-03\n", "# -5.93366381e-03 -7.26613170e-03]\n", "# [-1.74185038e+00 7.75875803e-03 -4.49435972e-02 ... 9.92415240e-04\n", "# 2.46338220e-03 2.31891591e-03]\n", "# [-2.33343077e+00 1.30476682e-02 -2.66557615e-02 ... 2.27533933e-03\n", "# 5.76929189e-03 7.48792710e-03]\n", "# ...\n", "# [-4.30356789e+00 2.46056803e-02 -9.00955945e-02 ... 4.43160534e-03\n", "# 1.16123557e-02 1.44716976e-02]\n", "# [-3.36919212e+00 1.73155665e-02 -6.36875406e-02 ... 3.28367390e-03\n", "# 8.58021621e-03 1.07796099e-02]\n", "# [-6.62039661e+00 3.49958315e-02 -1.23963736e-01 ... 6.36674836e-03\n", "# 1.60815325e-02 2.03892551e-02]]\n", "# [-4.3777566e+00 2.3245990e-02 -9.3339972e-02 ... 4.2569702e-03\n", "# 1.0920014e-02 1.3787906e-02]" ] }, { "cell_type": "code", "execution_count": null, "id": "enclosed-consolidation", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 18, "id": "synthetic-hungarian", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", " [41.84146118]) 0.0\n" ] } ], "source": [ "loss_att, acc_att = model._calc_att_loss(paddle.to_tensor(torch_encoder_out), paddle.to_tensor(torch_mask),\n", " text, text_len)\n", "print(loss_att, acc_att)\n", "#tensor(41.8416, device='cuda:0', grad_fn=) 0.0" ] }, { "cell_type": "code", "execution_count": 19, "id": "indian-sweden", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 202, "id": "marine-cuisine", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-3.7638968e-01 -8.2272053e-01 7.4276292e-01 ... 3.4200522e-01\n", " 1.5034772e-02 4.0337229e-01]\n", " [-8.7386459e-01 -3.1389427e-01 4.1987866e-01 ... 3.7723729e-01\n", " -1.4352810e-01 -1.0023664e+00]\n", " [-4.3505096e-01 3.4504786e-02 -2.8710306e-01 ... 7.7274129e-02\n", " -1.1672243e+00 -2.6848501e-01]\n", " ...\n", " [ 4.2471480e-01 5.8885634e-01 2.0203922e-02 ... 3.7405500e-01\n", " 4.5470044e-02 -3.7139410e-01]\n", " [-3.7978446e-01 -8.1084180e-01 7.5725085e-01 ... 2.6038891e-01\n", " -7.9347193e-04 4.2537671e-01]\n", " [-3.8279903e-01 -8.1206715e-01 7.4943429e-01 ... 
2.6173013e-01\n", " -1.0499060e-03 4.2678756e-01]]\n" ] } ], "source": [ "data = np.load(\".notebook/decoder.npz\", allow_pickle=True)\n", "torch_decoder_out = data['decoder_out']\n", "print(torch_decoder_out[0])" ] }, { "cell_type": "code", "execution_count": 180, "id": "several-result", "metadata": {}, "outputs": [], "source": [ "def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,\n", " ignore_id: int):\n", " \"\"\"Add and labels.\n", " Args:\n", " ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)\n", " sos (int): index of \n", " eos (int): index of \n", " ignore_id (int): index of padding\n", " Returns:\n", " ys_in (paddle.Tensor) : (B, Lmax + 1)\n", " ys_out (paddle.Tensor) : (B, Lmax + 1)\n", " Examples:\n", " >>> sos_id = 10\n", " >>> eos_id = 11\n", " >>> ignore_id = -1\n", " >>> ys_pad\n", " tensor([[ 1, 2, 3, 4, 5],\n", " [ 4, 5, 6, -1, -1],\n", " [ 7, 8, 9, -1, -1]], dtype=paddle.int32)\n", " >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)\n", " >>> ys_in\n", " tensor([[10, 1, 2, 3, 4, 5],\n", " [10, 4, 5, 6, 11, 11],\n", " [10, 7, 8, 9, 11, 11]])\n", " >>> ys_out\n", " tensor([[ 1, 2, 3, 4, 5, 11],\n", " [ 4, 5, 6, 11, -1, -1],\n", " [ 7, 8, 9, 11, -1, -1]])\n", " \"\"\"\n", " # TODO(Hui Zhang): using comment code, \n", " #_sos = paddle.to_tensor(\n", " # [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)\n", " #_eos = paddle.to_tensor(\n", " # [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)\n", " #ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys\n", " #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]\n", " #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]\n", " #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)\n", " B = ys_pad.size(0)\n", " _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos\n", " _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos\n", " ys_in = paddle.cat([_sos, ys_pad], dim=1)\n", " mask_pad = (ys_in == ignore_id)\n", " ys_in = ys_in.masked_fill(mask_pad, eos)\n", " \n", "\n", " ys_out = paddle.cat([ys_pad, _eos], dim=1)\n", " ys_out = ys_out.masked_fill(mask_pad, eos)\n", " mask_eos = (ys_out == ignore_id)\n", " ys_out = ys_out.masked_fill(mask_eos, eos)\n", " ys_out = ys_out.masked_fill(mask_pad, ignore_id)\n", " return ys_in, ys_out" ] }, { "cell_type": "code", "execution_count": 181, "id": "possible-bulgaria", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensor(shape=[16, 7], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", " [[4232, 2995, 3116, 1209, 565 , 4232, 4232],\n", " [4232, 236 , 1176, 331 , 66 , 3925, 4077],\n", " [4232, 2693, 524 , 234 , 1145, 366 , 4232],\n", " [4232, 3875, 4211, 3062, 700 , 4232, 4232],\n", " [4232, 272 , 987 , 1134, 494 , 2959, 4232],\n", " [4232, 1936, 3715, 120 , 2553, 2695, 2710],\n", " [4232, 25 , 1149, 3930, 4232, 4232, 4232],\n", " [4232, 1753, 1778, 1237, 482 , 3925, 110 ],\n", " [4232, 3703, 2 , 565 , 3827, 4232, 4232],\n", " [4232, 1150, 2734, 10 , 2478, 3490, 4232],\n", " [4232, 426 , 811 , 95 , 489 , 144 , 4232],\n", " [4232, 2313, 2006, 489 , 975 , 4232, 4232],\n", " [4232, 3702, 3414, 205 , 1488, 2966, 1347],\n", " [4232, 70 , 1741, 702 , 1666, 4232, 4232],\n", " [4232, 703 , 1778, 1030, 849 , 4232, 4232],\n", " [4232, 814 , 1674, 115 , 3827, 4232, 4232]])\n", "Tensor(shape=[16, 7], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", " [[2995, 3116, 1209, 565, 4232, -1 , -1 ],\n", " [ 236, 1176, 331, 66 , 3925, 4077, 4232],\n", 
" [2693, 524, 234, 1145, 366, 4232, -1 ],\n", " [3875, 4211, 3062, 700, 4232, -1 , -1 ],\n", " [ 272, 987, 1134, 494, 2959, 4232, -1 ],\n", " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", " [ 25 , 1149, 3930, 4232, -1 , -1 , -1 ],\n", " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", " [3703, 2 , 565, 3827, 4232, -1 , -1 ],\n", " [1150, 2734, 10 , 2478, 3490, 4232, -1 ],\n", " [ 426, 811, 95 , 489, 144, 4232, -1 ],\n", " [2313, 2006, 489, 975, 4232, -1 , -1 ],\n", " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", " [ 70 , 1741, 702, 1666, 4232, -1 , -1 ],\n", " [ 703, 1778, 1030, 849, 4232, -1 , -1 ],\n", " [ 814, 1674, 115, 3827, 4232, -1 , -1 ]])\n" ] } ], "source": [ "ys_pad = text\n", "ys_pad_lens = text_len\n", "ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,\n", " model.ignore_id)\n", "ys_in_lens = ys_pad_lens + 1\n", "print(ys_in_pad)\n", "print(ys_out_pad)" ] }, { "cell_type": "code", "execution_count": 285, "id": "north-walter", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "False\n", "True\n", "False\n", "[[-3.76389682e-01 -8.22720408e-01 7.42762923e-01 ... 3.42005253e-01\n", " 1.50350705e-02 4.03372347e-01]\n", " [-8.73864174e-01 -3.13894272e-01 4.19878662e-01 ... 3.77237231e-01\n", " -1.43528014e-01 -1.00236630e+00]\n", " [-4.35050905e-01 3.45046446e-02 -2.87102997e-01 ... 7.72742853e-02\n", " -1.16722476e+00 -2.68485069e-01]\n", " ...\n", " [ 4.24714804e-01 5.88856399e-01 2.02039629e-02 ... 3.74054879e-01\n", " 4.54700664e-02 -3.71394157e-01]\n", " [-3.79784584e-01 -8.10841978e-01 7.57250786e-01 ... 2.60389000e-01\n", " -7.93404877e-04 4.25376773e-01]\n", " [-3.82798851e-01 -8.12067091e-01 7.49434292e-01 ... 2.61730075e-01\n", " -1.04988366e-03 4.26787734e-01]]\n", "---\n", "[[-3.7638968e-01 -8.2272053e-01 7.4276292e-01 ... 3.4200522e-01\n", " 1.5034772e-02 4.0337229e-01]\n", " [-8.7386459e-01 -3.1389427e-01 4.1987866e-01 ... 3.7723729e-01\n", " -1.4352810e-01 -1.0023664e+00]\n", " [-4.3505096e-01 3.4504786e-02 -2.8710306e-01 ... 7.7274129e-02\n", " -1.1672243e+00 -2.6848501e-01]\n", " ...\n", " [ 4.2471480e-01 5.8885634e-01 2.0203922e-02 ... 3.7405500e-01\n", " 4.5470044e-02 -3.7139410e-01]\n", " [-3.7978446e-01 -8.1084180e-01 7.5725085e-01 ... 2.6038891e-01\n", " -7.9347193e-04 4.2537671e-01]\n", " [-3.8279903e-01 -8.1206715e-01 7.4943429e-01 ... 
2.6173013e-01\n", " -1.0499060e-03 4.2678756e-01]]\n" ] } ], "source": [ "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n", " ys_in_lens)\n", "\n", "print(np.allclose(decoder_out.numpy(), torch_decoder_out))\n", "print(np.allclose(decoder_out.numpy(), torch_decoder_out, atol=1e-6))\n", "print(np.allclose(decoder_out.numpy(), torch_decoder_out, atol=1e-7))\n", "print(decoder_out.numpy()[0])\n", "print('---')\n", "print(torch_decoder_out[0])" ] }, { "cell_type": "code", "execution_count": null, "id": "armed-cowboy", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "fifty-earth", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "proud-commonwealth", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 183, "id": "assisted-fortune", "metadata": {}, "outputs": [], "source": [ "from paddle import nn\n", "import paddle\n", "from paddle.nn import functional as F\n", "\n", "class LabelSmoothingLoss(nn.Layer):\n", "\n", " def __init__(self,\n", " size: int,\n", " padding_idx: int,\n", " smoothing: float,\n", " normalize_length: bool=False):\n", " super().__init__()\n", " self.size = size\n", " self.padding_idx = padding_idx\n", " self.smoothing = smoothing\n", " self.confidence = 1.0 - smoothing\n", " self.normalize_length = normalize_length\n", " self.criterion = nn.KLDivLoss(reduction=\"none\")\n", "\n", " def forward(self, x: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor:\n", " \"\"\"Compute loss between x and target.\n", " The model outputs and data labels tensors are flatten to\n", " (batch*seqlen, class) shape and a mask is applied to the\n", " padding part which should not be calculated for loss.\n", " \n", " Args:\n", " x (paddle.Tensor): prediction (batch, seqlen, class)\n", " target (paddle.Tensor):\n", " target signal masked with self.padding_id (batch, seqlen)\n", " Returns:\n", " loss (paddle.Tensor) : The KL loss, scalar float value\n", " \"\"\"\n", " B, T, D = paddle.shape(x)\n", " assert D == self.size\n", " x = x.reshape((-1, self.size))\n", " target = target.reshape([-1])\n", "\n", " # use zeros_like instead of torch.no_grad() for true_dist,\n", " # since no_grad() can not be exported by JIT\n", " true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))\n", " ignore = target == self.padding_idx # (B,)\n", " print(self.smoothing / (self.size - 1))\n", " print(true_dist)\n", "\n", " #target = target * (1 - ignore) # avoid -1 index\n", " target = target.masked_fill(ignore, 0) # avoid -1 index\n", " \n", " \n", " #true_dist += F.one_hot(target, self.size) * self.confidence\n", " target_mask = F.one_hot(target, self.size)\n", " true_dist *= (1 - target_mask)\n", " true_dist += target_mask * self.confidence\n", " \n", "\n", " kl = self.criterion(F.log_softmax(x, axis=1), true_dist)\n", " \n", " #TODO(Hui Zhang): sum not support bool type\n", " #total = len(target) - int(ignore.sum())\n", " total = len(target) - int(ignore.type_as(target).sum())\n", " denom = total if self.normalize_length else B\n", "\n", " #numer = (kl * (1 - ignore)).sum()\n", " numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()\n", " return numer / denom\n" ] }, { "cell_type": "code", "execution_count": 184, "id": "weighted-delight", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.3629489603024576e-05\n", "Tensor(shape=[112, 4233], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " 
[[0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", " ...,\n", " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363]])\n", "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [41.84146118])\n", "VarType.INT64\n" ] } ], "source": [ "criteron = LabelSmoothingLoss(4233, -1, 0.1, False)\n", "loss_att = criteron(paddle.to_tensor(torch_decoder_out), ys_out_pad.astype('int64'))\n", "print(loss_att)\n", "print(ys_out_pad.dtype)\n", "# tensor(41.8416, device='cuda:0', grad_fn=)" ] }, { "cell_type": "code", "execution_count": 286, "id": "dress-shelter", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [41.84146118])\n", "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", " [41.84146118])\n", "4233\n", "-1\n", "0.1\n", "False\n" ] } ], "source": [ "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n", " ys_in_lens)\n", "\n", "loss_att = model.criterion_att(paddle.to_tensor(torch_decoder_out), ys_out_pad)\n", "print(loss_att)\n", "\n", "loss_att = model.criterion_att(decoder_out, ys_out_pad)\n", "print(loss_att)\n", "\n", "print(model.criterion_att.size)\n", "print(model.criterion_att.padding_idx)\n", "print(model.criterion_att.smoothing)\n", "print(model.criterion_att.normalize_length)" ] }, { "cell_type": "code", "execution_count": null, "id": "growing-tooth", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "going-hungary", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "naughty-citizenship", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "experimental-emerald", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "adverse-saskatchewan", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 27, "id": "speaking-shelf", "metadata": {}, "outputs": [], "source": [ "from typing import List\n", "from typing import Optional\n", "from typing import Tuple\n", "\n", "import paddle\n", "from paddle import nn\n", "from typeguard import check_argument_types\n", "\n", "from deepspeech.modules.activation import get_activation\n", "from deepspeech.modules.attention import MultiHeadedAttention\n", "from deepspeech.modules.attention import RelPositionMultiHeadedAttention\n", "from deepspeech.modules.conformer_convolution import ConvolutionModule\n", "from deepspeech.modules.embedding import PositionalEncoding\n", "from deepspeech.modules.embedding import RelPositionalEncoding\n", "from deepspeech.modules.encoder_layer import ConformerEncoderLayer\n", "from deepspeech.modules.encoder_layer import TransformerEncoderLayer\n", "from deepspeech.modules.mask import add_optional_chunk_mask\n", "from deepspeech.modules.mask import make_non_pad_mask\n", "from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward\n", "from deepspeech.modules.subsampling import Conv2dSubsampling4\n", "from 
deepspeech.modules.subsampling import Conv2dSubsampling6\n", "from deepspeech.modules.subsampling import Conv2dSubsampling8\n", "from deepspeech.modules.subsampling import LinearNoSubsampling\n", "\n", "class BaseEncoder(nn.Layer):\n", " def __init__(\n", " self,\n", " input_size: int,\n", " output_size: int=256,\n", " attention_heads: int=4,\n", " linear_units: int=2048,\n", " num_blocks: int=6,\n", " dropout_rate: float=0.1,\n", " positional_dropout_rate: float=0.1,\n", " attention_dropout_rate: float=0.0,\n", " input_layer: str=\"conv2d\",\n", " pos_enc_layer_type: str=\"abs_pos\",\n", " normalize_before: bool=True,\n", " concat_after: bool=False,\n", " static_chunk_size: int=0,\n", " use_dynamic_chunk: bool=False,\n", " global_cmvn: paddle.nn.Layer=None,\n", " use_dynamic_left_chunk: bool=False, ):\n", " \"\"\"\n", " Args:\n", " input_size (int): input dim, d_feature\n", " output_size (int): dimension of attention, d_model\n", " attention_heads (int): the number of heads of multi head attention\n", " linear_units (int): the hidden units number of position-wise feed\n", " forward\n", " num_blocks (int): the number of encoder blocks\n", " dropout_rate (float): dropout rate\n", " attention_dropout_rate (float): dropout rate in attention\n", " positional_dropout_rate (float): dropout rate after adding\n", " positional encoding\n", " input_layer (str): input layer type.\n", " optional [linear, conv2d, conv2d6, conv2d8]\n", " pos_enc_layer_type (str): Encoder positional encoding layer type.\n", " optional [abs_pos, scaled_abs_pos, rel_pos]\n", " normalize_before (bool):\n", " True: use layer_norm before each sub-block of a layer.\n", " False: use layer_norm after each sub-block of a layer.\n", " concat_after (bool): whether to concat attention layer's input\n", " and output.\n", " True: x -> x + linear(concat(x, att(x)))\n", " False: x -> x + att(x)\n", " static_chunk_size (int): chunk size for static chunk training and\n", " decoding\n", " use_dynamic_chunk (bool): whether to use dynamic chunk size for\n", " training or not. You can only use fixed chunk (chunk_size > 0)\n", " or dynamic chunk size (use_dynamic_chunk = True)\n", " global_cmvn (Optional[paddle.nn.Layer]): Optional GlobalCMVN layer\n", " use_dynamic_left_chunk (bool): whether to use dynamic left chunk in\n", " dynamic chunk training\n", " \"\"\"\n", " assert check_argument_types()\n", " super().__init__()\n", " self._output_size = output_size\n", "\n", " if pos_enc_layer_type == \"abs_pos\":\n", " pos_enc_class = PositionalEncoding\n", " elif pos_enc_layer_type == \"rel_pos\":\n", " pos_enc_class = RelPositionalEncoding\n", " else:\n", " raise ValueError(\"unknown pos_enc_layer: \" + pos_enc_layer_type)\n", "\n", " if input_layer == \"linear\":\n", " subsampling_class = LinearNoSubsampling\n", " elif input_layer == \"conv2d\":\n", " subsampling_class = Conv2dSubsampling4\n", " elif input_layer == \"conv2d6\":\n", " subsampling_class = Conv2dSubsampling6\n", " elif input_layer == \"conv2d8\":\n", " subsampling_class = Conv2dSubsampling8\n", " else:\n", " raise ValueError(\"unknown input_layer: \" + input_layer)\n", "\n", " self.global_cmvn = global_cmvn\n", " self.embed = subsampling_class(\n", " idim=input_size,\n", " odim=output_size,\n", " dropout_rate=dropout_rate,\n", " pos_enc_class=pos_enc_class(\n", " d_model=output_size, dropout_rate=positional_dropout_rate), )\n", "\n", " self.normalize_before = normalize_before\n", " self.after_norm = nn.LayerNorm(output_size, epsilon=1e-12)\n", " self.static_chunk_size = 
static_chunk_size\n", " self.use_dynamic_chunk = use_dynamic_chunk\n", " self.use_dynamic_left_chunk = use_dynamic_left_chunk\n", "\n", " def output_size(self) -> int:\n", " return self._output_size\n", "\n", " def forward(\n", " self,\n", " xs: paddle.Tensor,\n", " xs_lens: paddle.Tensor,\n", " decoding_chunk_size: int=0,\n", " num_decoding_left_chunks: int=-1,\n", " ) -> Tuple[paddle.Tensor, paddle.Tensor]:\n", " \"\"\"Embed positions in tensor.\n", " Args:\n", " xs: padded input tensor (B, L, D)\n", " xs_lens: input length (B)\n", " decoding_chunk_size: decoding chunk size for dynamic chunk\n", " 0: default for training, use random dynamic chunk.\n", " <0: for decoding, use full chunk.\n", " >0: for decoding, use fixed chunk size as set.\n", " num_decoding_left_chunks: number of left chunks, this is for decoding,\n", " the chunk size is decoding_chunk_size.\n", " >=0: use num_decoding_left_chunks\n", " <0: use all left chunks\n", " Returns:\n", " encoder output tensor, lens and mask\n", " \"\"\"\n", " masks = make_non_pad_mask(xs_lens).unsqueeze(1) # (B, 1, L)\n", "\n", " if self.global_cmvn is not None:\n", " xs = self.global_cmvn(xs)\n", " #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor\n", " xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)\n", " #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor\n", " masks = masks.astype(paddle.bool)\n", " #TODO(Hui Zhang): mask_pad = ~masks\n", " mask_pad = masks.logical_not()\n", " chunk_masks = add_optional_chunk_mask(\n", " xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,\n", " decoding_chunk_size, self.static_chunk_size,\n", " num_decoding_left_chunks)\n", " for layer in self.encoders:\n", " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", " if self.normalize_before:\n", " xs = self.after_norm(xs)\n", " # Here we assume the mask is not changed in encoder layers, so just\n", " # return the masks before encoder layers, and the masks will be used\n", " # for cross attention with decoder later\n", " return xs, masks" ] }, { "cell_type": "code", "execution_count": 28, "id": "sharp-municipality", "metadata": {}, "outputs": [], "source": [ "\n", "class ConformerEncoder(BaseEncoder):\n", " \"\"\"Conformer encoder module.\"\"\"\n", "\n", " def __init__(\n", " self,\n", " input_size: int,\n", " output_size: int=256,\n", " attention_heads: int=4,\n", " linear_units: int=2048,\n", " num_blocks: int=6,\n", " dropout_rate: float=0.1,\n", " positional_dropout_rate: float=0.1,\n", " attention_dropout_rate: float=0.0,\n", " input_layer: str=\"conv2d\",\n", " pos_enc_layer_type: str=\"rel_pos\",\n", " normalize_before: bool=True,\n", " concat_after: bool=False,\n", " static_chunk_size: int=0,\n", " use_dynamic_chunk: bool=False,\n", " global_cmvn: nn.Layer=None,\n", " use_dynamic_left_chunk: bool=False,\n", " positionwise_conv_kernel_size: int=1,\n", " macaron_style: bool=True,\n", " selfattention_layer_type: str=\"rel_selfattn\",\n", " activation_type: str=\"swish\",\n", " use_cnn_module: bool=True,\n", " cnn_module_kernel: int=15,\n", " causal: bool=False,\n", " cnn_module_norm: str=\"batch_norm\", ):\n", " \"\"\"Construct ConformerEncoder\n", " Args:\n", " input_size to use_dynamic_chunk, see in BaseEncoder\n", " positionwise_conv_kernel_size (int): Kernel size of positionwise\n", " conv1d layer.\n", " macaron_style (bool): Whether to use macaron style for\n", " positionwise layer.\n", " selfattention_layer_type (str): Encoder attention layer 
type,\n", " the parameter has no effect now, it's just for configure\n", " compatibility.\n", " activation_type (str): Encoder activation function type.\n", " use_cnn_module (bool): Whether to use convolution module.\n", " cnn_module_kernel (int): Kernel size of convolution module.\n", " causal (bool): whether to use causal convolution or not.\n", " cnn_module_norm (str): cnn conv norm type, Optional['batch_norm','layer_norm']\n", " \"\"\"\n", " assert check_argument_types()\n", " super().__init__(input_size, output_size, attention_heads, linear_units,\n", " num_blocks, dropout_rate, positional_dropout_rate,\n", " attention_dropout_rate, input_layer,\n", " pos_enc_layer_type, normalize_before, concat_after,\n", " static_chunk_size, use_dynamic_chunk, global_cmvn,\n", " use_dynamic_left_chunk)\n", " activation = get_activation(activation_type)\n", "\n", " # self-attention module definition\n", " encoder_selfattn_layer = RelPositionMultiHeadedAttention\n", " encoder_selfattn_layer_args = (attention_heads, output_size,\n", " attention_dropout_rate)\n", " # feed-forward module definition\n", " positionwise_layer = PositionwiseFeedForward\n", " positionwise_layer_args = (output_size, linear_units, dropout_rate,\n", " activation)\n", " # convolution module definition\n", " convolution_layer = ConvolutionModule\n", " convolution_layer_args = (output_size, cnn_module_kernel, activation,\n", " cnn_module_norm, causal)\n", "\n", " self.encoders = nn.ModuleList([\n", " ConformerEncoderLayer(\n", " size=output_size,\n", " self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),\n", " feed_forward=positionwise_layer(*positionwise_layer_args),\n", " feed_forward_macaron=positionwise_layer(\n", " *positionwise_layer_args) if macaron_style else None,\n", " conv_module=convolution_layer(*convolution_layer_args)\n", " if use_cnn_module else None,\n", " dropout_rate=dropout_rate,\n", " normalize_before=normalize_before,\n", " concat_after=concat_after) for _ in range(num_blocks)\n", " ])\n" ] }, { "cell_type": "code", "execution_count": 29, "id": "tutorial-syndication", "metadata": {}, "outputs": [], "source": [ "from deepspeech.frontend.utility import load_cmvn\n", "from deepspeech.modules.cmvn import GlobalCMVN\n", "\n", "configs=cfg.model\n", "mean, istd = load_cmvn(configs['cmvn_file'],\n", " configs['cmvn_file_type'])\n", "global_cmvn = GlobalCMVN(\n", " paddle.to_tensor(mean, dtype=paddle.float),\n", " paddle.to_tensor(istd, dtype=paddle.float))\n", "\n", "\n", "input_dim = configs['input_dim']\n", "vocab_size = configs['output_dim']\n", "encoder_type = configs.get('encoder', 'transformer')\n", " \n", "encoder = ConformerEncoder(\n", " input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])" ] }, { "cell_type": "code", "execution_count": 30, "id": "fuzzy-register", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [ "o = global_cmvn(feat)\n", "o2 = model.encoder.global_cmvn(feat)\n", "print(np.allclose(o.numpy(), o2.numpy()))" ] }, { "cell_type": "code", "execution_count": null, "id": "explicit-triumph", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "humanitarian-belgium", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "dying-proposal", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "honest-quick", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", 
"execution_count": null, "id": "bound-cholesterol", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "viral-packaging", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 203, "id": "balanced-locator", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensor(shape=[16, 1, 207], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", " [[[True , True , True , ..., True , True , True ]],\n", "\n", " [[True , True , True , ..., True , True , True ]],\n", "\n", " [[True , True , True , ..., True , False, False]],\n", "\n", " ...,\n", "\n", " [[True , True , True , ..., False, False, False]],\n", "\n", " [[True , True , True , ..., False, False, False]],\n", "\n", " [[True , True , True , ..., False, False, False]]])\n" ] } ], "source": [ "from deepspeech.modules.mask import make_non_pad_mask\n", "from deepspeech.modules.mask import make_pad_mask\n", "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", "print(masks)" ] }, { "cell_type": "code", "execution_count": 204, "id": "induced-proposition", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensor(shape=[16, 207, 80], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[[-0.53697914, -0.19910523, -0.34997201, ..., -0.82427669, -1.02650309, -0.96300691],\n", " [-0.04464225, 0.23176001, -0.32538742, ..., -0.90158713, -1.03248465, -0.75986791],\n", " [ 0.50035292, 0.22691160, -0.73052198, ..., -1.00552964, -0.87123060, -1.03062117],\n", " ...,\n", " [-0.40023831, -0.14325078, -0.57947433, ..., -1.07178426, -1.28059900, -1.05180073],\n", " [ 0.15755332, -0.00184949, -0.28702953, ..., -1.10898709, -0.94518697, -0.72506356],\n", " [-0.47520429, -1.39415145, -0.25754252, ..., -1.13649082, -1.19430351, -1.22903371]],\n", "\n", " [[ 0.95454037, 0.36427975, -1.38908529, ..., -1.16366839, -1.28453600, -1.20151031],\n", " [-0.08573537, -1.05785275, -0.89172721, ..., -0.96440506, -1.12547100, -1.25990939],\n", " [ 0.47653601, 0.32886592, -0.59200549, ..., -1.19421589, -1.14302588, -1.02422845],\n", " ...,\n", " [-0.47431335, -0.33558893, -0.72325647, ..., -1.45058632, -1.39574063, -1.04641151],\n", " [ 0.36112556, 0.10380996, -1.15994537, ..., -1.04394984, -1.02212358, -1.02083635],\n", " [-1.27172923, -2.14601755, -0.75676596, ..., -0.97822225, -0.93785471, -1.03707945]],\n", "\n", " [[-1.54652190, -1.01517177, -0.88900733, ..., -0.48522446, -0.75163364, -0.67765164],\n", " [-0.76100892, -0.73351598, -0.91587651, ..., -0.24835993, -0.58927339, -0.73722762],\n", " [-0.02471367, 0.17015894, -0.42326337, ..., -0.33203802, -0.76695800, -0.71651691],\n", " ...,\n", " [-1.70319796, -1.25910866, -1.14492917, ..., -1.18101490, -1.11631835, -0.93108195],\n", " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n", "\n", " ...,\n", "\n", " [[ 0.64982772, 0.26116797, -0.84196597, ..., -0.87213463, -1.10728693, -1.32531130],\n", " [ 0.35391113, -0.01584581, -0.40424931, ..., -0.99173468, -1.07270539, -1.19239008],\n", " [ 0.37704495, -0.06278508, -0.11467686, ..., -1.10212946, -1.09524000, -1.11815071],\n", " ...,\n", " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, 
-3.67971063]],\n", "\n", " [[ 0.04445776, -0.17546852, -0.67475224, ..., -0.49801198, -0.56782746, -0.77852231],\n", " [-1.34279025, -0.80342549, -0.90457231, ..., -0.65901577, -0.72549772, -0.62796098],\n", " [-0.76252806, -0.13071291, -0.13280024, ..., -0.56132573, -0.60587686, -0.72114766],\n", " ...,\n", " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n", "\n", " [[-1.07980299, -1.08341801, -1.17969072, ..., -0.17757270, -0.43746525, -0.04000654],\n", " [ 0.92353648, 0.63770926, -0.52810186, ..., -0.12927933, -0.20342292, 0.16655664],\n", " [ 0.49337494, -0.00911332, -0.73301607, ..., 0.10074048, -0.09811471, -0.00923573],\n", " ...,\n", " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]]])\n" ] } ], "source": [ "xs = model.encoder.global_cmvn(feat)\n", "print(xs)" ] }, { "cell_type": "code", "execution_count": 205, "id": "cutting-julian", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensor(shape=[16, 256, 51, 19], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", " [[[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " [[0. , 0. , 0. , ..., 0. , 0. , 0.00209083],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0.01194306, 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0.04610471, 0. ],\n", " [0. , 0. , 0. , ..., 0.00967231, 0.04613467, 0. ]],\n", "\n", " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " ...,\n", "\n", " [[0.22816099, 0.24614786, 0.25304127, ..., 0.20401822, 0.23248228, 0.31190544],\n", " [0.13587360, 0.28877240, 0.27991283, ..., 0.19210319, 0.20346391, 0.19934426],\n", " [0.25739068, 0.39348233, 0.27877361, ..., 0.27482539, 0.19302306, 0.23810163],\n", " ...,\n", " [0.11939213, 0.28473237, 0.33082074, ..., 0.23838061, 0.22104350, 0.23905794],\n", " [0.17387670, 0.20402060, 0.40263173, ..., 0.24782266, 0.26742202, 0.15426503],\n", " [0. , 0.29080707, 0.27725950, ..., 0.17539823, 0.18478745, 0.22483408]],\n", "\n", " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. 
]],\n", "\n", " [[0.35446781, 0.38861471, 0.39724261, ..., 0.38680089, 0.33568040, 0.34552398],\n", " [0.41739127, 0.51038563, 0.41729912, ..., 0.33992639, 0.37081629, 0.35109508],\n", " [0.36116859, 0.40744874, 0.48490953, ..., 0.34848654, 0.32321057, 0.35188958],\n", " ...,\n", " [0.23143977, 0.38021481, 0.51526314, ..., 0.36499465, 0.37411752, 0.39986172],\n", " [0.34678638, 0.40238205, 0.50076538, ..., 0.36184520, 0.31596646, 0.36334658],\n", " [0.36498138, 0.37943166, 0.51718897, ..., 0.31798238, 0.33656698, 0.34130475]]],\n", "\n", "\n", " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " [[0.01456045, 0.09447514, 0. , ..., 0. , 0. , 0. ],\n", " [0.01500242, 0.02963220, 0. , ..., 0. , 0. , 0. ],\n", " [0.03295187, 0. , 0. , ..., 0.04584959, 0.02043908, 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0.04425837],\n", " [0. , 0. , 0.02556529, ..., 0. , 0.00900441, 0.04908358]],\n", "\n", " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0.11141267, 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " ...,\n", "\n", " [[0.33696529, 0.38526866, 0.32900479, ..., 0.28703830, 0.23351061, 0.19004467],\n", " [0.13575366, 0.35783342, 0.33573425, ..., 0.22081660, 0.15854910, 0.13587447],\n", " [0.21928655, 0.28900093, 0.28255141, ..., 0.20602837, 0.23927397, 0.21909429],\n", " ...,\n", " [0.23291890, 0.39096734, 0.36399242, ..., 0.20598020, 0.25373828, 0.23137446],\n", " [0.18739152, 0.30793777, 0.30296701, ..., 0.27250600, 0.25191751, 0.20836820],\n", " [0.22454213, 0.41402060, 0.54082996, ..., 0.31874508, 0.25079906, 0.25938687]],\n", "\n", " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " [[0.26456982, 0.49519050, 0.56702250, ..., 0.30954638, 0.35292268, 0.32668519],\n", " [0.21576807, 0.51833367, 0.49183372, ..., 0.36043224, 0.38523889, 0.36154741],\n", " [0.20067888, 0.42784205, 0.52817714, ..., 0.31871423, 0.32452232, 0.31036487],\n", " ...,\n", " [0.49855131, 0.51001430, 0.52278662, ..., 0.36450142, 0.34338164, 0.33602941],\n", " [0.41233343, 0.55517823, 0.52827710, ..., 0.40675971, 0.33873138, 0.36724189],\n", " [0.40820011, 0.46187383, 0.47338152, ..., 0.38690975, 0.36039269, 0.38022059]]],\n", "\n", "\n", " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " [[0. , 0.00578516, 0. , ..., 0.00748384, 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0.03035110, 0. , 0.00026720],\n", " [0.00094807, 0. , 0. , ..., 0.00795512, 0. , 0. ],\n", " ...,\n", " [0.02032628, 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0.01080076, 0. ],\n", " [0.18470290, 0. , 0. , ..., 0.05058352, 0.09475817, 0.05914564]],\n", "\n", " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. 
],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " ...,\n", "\n", " [[0.38708323, 0.28021947, 0.35892880, ..., 0.16595127, 0.16031364, 0.21136315],\n", " [0.15595171, 0.30544323, 0.24666184, ..., 0.22675267, 0.25765014, 0.19682154],\n", " [0.29517862, 0.41209796, 0.20063159, ..., 0.17595036, 0.22536841, 0.22214051],\n", " ...,\n", " [0.24744980, 0.26258564, 0.38654143, ..., 0.23620218, 0.23157144, 0.18514194],\n", " [0.25714791, 0.29592845, 0.47744542, ..., 0.23545510, 0.25072727, 0.20976165],\n", " [1.20154655, 0.84644288, 0.73385584, ..., 1.02517247, 0.95309550, 1.00134516]],\n", "\n", " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " [[0.45013186, 0.47484034, 0.40540054, ..., 0.19346163, 0.17825794, 0.14776605],\n", " [0.47545874, 0.48186573, 0.36760187, ..., 0.27809089, 0.32997063, 0.32337096],\n", " [0.46160024, 0.40050328, 0.39060861, ..., 0.36612910, 0.35242686, 0.29738861],\n", " ...,\n", " [0.55148494, 0.51017821, 0.40132499, ..., 0.38948193, 0.35737294, 0.33088297],\n", " [0.41972569, 0.45475486, 0.45320493, ..., 0.38343129, 0.40125814, 0.36180776],\n", " [0.34279808, 0.31606171, 0.44701228, ..., 0.21665487, 0.23984617, 0.23903391]]],\n", "\n", "\n", " ...,\n", "\n", "\n", " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0.04178291, 0. , 0.01580476, ..., 0. , 0.02250817, 0. ],\n", " [0.04323414, 0.07786420, 0. , ..., 0.01634724, 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " [[0.03209178, 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0.13563479, 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " ...,\n", "\n", " [[0. , 0.25187218, 0.24979387, ..., 0.24774717, 0.22354351, 0.19149347],\n", " [0.16540922, 0.19585510, 0.19812922, ..., 0.27344131, 0.20928150, 0.26150429],\n", " [0.10494646, 0.06329897, 0.33843631, ..., 0.25138417, 0.12470355, 0.23926635],\n", " ...,\n", " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n", "\n", " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. 
]],\n", "\n", " [[0.11428106, 0.45667490, 0.46820879, ..., 0.32057840, 0.33578536, 0.39012644],\n", " [0.10441341, 0.45739070, 0.46107352, ..., 0.38467997, 0.38291249, 0.36685589],\n", " [0.19867736, 0.35519636, 0.44313061, ..., 0.40679252, 0.38067645, 0.30645671],\n", " ...,\n", " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]],\n", "\n", "\n", " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " [[0.02465414, 0. , 0. , ..., 0. , 0. , 0.03390232],\n", " [0. , 0. , 0.01830704, ..., 0.05166877, 0.00948385, 0.07453502],\n", " [0.09921519, 0. , 0.01587192, ..., 0.01620276, 0.05140074, 0.00192392],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " ...,\n", "\n", " [[0.40034360, 0.25306445, 0.20217699, ..., 0.09816189, 0.07064310, 0.04974059],\n", " [0.12567598, 0.21030979, 0.11181555, ..., 0.04278110, 0.11968569, 0.12005232],\n", " [0.28786880, 0.24030517, 0.22565845, ..., 0. , 0.06418110, 0.05872961],\n", " ...,\n", " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n", "\n", " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " [[0.38404641, 0.30990323, 0.37156230, ..., 0.18125033, 0.15050662, 0.19619957],\n", " [0.47285745, 0.40528792, 0.39718056, ..., 0.24709940, 0.04565683, 0.11500744],\n", " [0.32620737, 0.30072594, 0.30477354, ..., 0.23529193, 0.21356541, 0.16985542],\n", " ...,\n", " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]],\n", "\n", "\n", " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " [[0.03343770, 0.00123780, 0.05297198, ..., 0.07271163, 0.08656286, 0.14493589],\n", " [0.11043239, 0.06143146, 0.06362963, ..., 0.08127750, 0.06259022, 0.08315435],\n", " [0.01767678, 0.00201111, 0.07875030, ..., 0.06963293, 0.08979890, 0.05326346],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " [[0.10033827, 0. , 0. , ..., 0. , 0. , 0. 
],\n", " [0.15627117, 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0.05144687, 0. , 0. , ..., 0. , 0. , 0.00436414],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " ...,\n", "\n", " [[0.25142455, 0.45964020, 0.37346074, ..., 0.04763087, 0. , 0. ],\n", " [0.19760093, 0.26626948, 0.11190540, ..., 0.03044968, 0. , 0. ],\n", " [0.16340607, 0.32938001, 0.25689697, ..., 0.05569421, 0. , 0. ],\n", " ...,\n", " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n", "\n", " [[0. , 0. , 0. , ..., 0. , 0.02218930, 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0.02848953],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " ...,\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", "\n", " [[0.25810039, 0.63016868, 0.37037861, ..., 0.18704373, 0.08269356, 0.09912672],\n", " [0.17292863, 0.50678611, 0.40738991, ..., 0.16006103, 0.11725381, 0.09940521],\n", " [0.24175072, 0.41616210, 0.41256818, ..., 0.13519743, 0.07912572, 0.12846369],\n", " ...,\n", " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]]])\n" ] } ], "source": [ "xs = model.encoder.global_cmvn(feat)\n", "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", "\n", "\n", "#xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", "# print(xs)\n", "\n", "x = xs.unsqueeze(1)\n", "x = model.encoder.embed.conv(x)\n", "print(x)" ] }, { "cell_type": "code", "execution_count": 206, "id": "friendly-nightlife", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensor(shape=[16, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", " [[[-0.03426375, 0.14291267, -0.06718873, ..., 0.09064753, 0.01809387, -0.04340880],\n", " [-0.05007839, 0.11054724, -0.10399298, ..., 0.11457238, 0.04244684, -0.01249714],\n", " [-0.10695291, 0.16910909, -0.08352133, ..., 0.07710276, 0.01168563, -0.03584499],\n", " ...,\n", " [-0.06060536, 0.14455931, -0.05470302, ..., 0.05364908, 0.03033342, -0.02610814],\n", " [-0.08505894, 0.13611752, -0.11132983, ..., 0.13079923, 0.01580139, -0.02281028],\n", " [-0.10604677, 0.14714901, -0.10885533, ..., 0.08543444, 0.03719445, -0.04634233]],\n", "\n", " [[-0.12392755, 0.14486063, -0.05674079, ..., 0.02573164, 0.03128851, 0.00545091],\n", " [-0.04775286, 0.08473608, -0.08507854, ..., 0.04573154, 0.04240163, 0.01053247],\n", " [-0.05940291, 0.10023535, -0.08143730, ..., 0.03596500, 0.01673085, 0.02089563],\n", " ...,\n", " [-0.09222981, 0.15823206, -0.07700447, ..., 0.08122957, 0.03136991, -0.00646474],\n", " [-0.07331756, 0.14482647, -0.07838815, ..., 0.10869440, 0.01356864, -0.02777974],\n", " [-0.07937264, 0.20143102, -0.05544947, ..., 0.10287814, 0.00608235, -0.04799180]],\n", "\n", " [[-0.03670349, 0.08931590, -0.08718812, ..., 0.01314050, 0.00642052, 0.00573716],\n", " [ 0.01089254, 0.11146393, -0.10263617, ..., 0.05070438, 0.01960694, 0.03521532],\n", " [-0.02182280, 0.11443964, -0.06678198, ..., 0.04327708, 0.00861394, 0.02871092],\n", " ...,\n", " [-0.06792898, 0.14376275, -0.07899005, ..., 
0.11248926, 0.03208683, -0.03264240],\n", " [-0.07884051, 0.17024788, -0.08583611, ..., 0.09028331, 0.03588808, -0.02075090],\n", " [-0.13792302, 0.27163863, -0.23930418, ..., 0.13391261, 0.07521040, -0.08621951]],\n", "\n", " ...,\n", "\n", " [[-0.02446348, 0.11595841, -0.03591986, ..., 0.06288970, 0.02895011, -0.06532725],\n", " [-0.05378424, 0.12607370, -0.09023033, ..., 0.09078894, 0.01035743, 0.03701983],\n", " [-0.04566649, 0.14275314, -0.06686870, ..., 0.09890588, -0.00612222, 0.03439377],\n", " ...,\n", " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]],\n", "\n", " [[-0.01012144, 0.03909408, -0.07077143, ..., 0.00452683, -0.01377654, 0.02897627],\n", " [-0.00519154, 0.03594019, -0.06831125, ..., 0.05693541, -0.00406374, 0.04561640],\n", " [-0.01762631, 0.00500899, -0.05886075, ..., 0.02112178, -0.00729015, 0.02782153],\n", " ...,\n", " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]],\n", "\n", " [[-0.03411558, -0.04318277, -0.08497842, ..., -0.04886402, 0.04296734, 0.06151697],\n", " [ 0.00263296, -0.06913657, -0.08993219, ..., -0.00149064, 0.05696633, 0.03304394],\n", " [-0.01818341, -0.01178640, -0.09679577, ..., -0.00870231, 0.00362198, 0.01916483],\n", " ...,\n", " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]]])\n", "Tensor(shape=[16, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", " [[[-0.54821998, 2.28660274, -1.07501972, ..., 1.45036042, 0.28950194, -0.69454080],\n", " [-0.80125421, 1.76875579, -1.66388774, ..., 1.83315802, 0.67914939, -0.19995420],\n", " [-1.71124649, 2.70574546, -1.33634126, ..., 1.23364413, 0.18697014, -0.57351983],\n", " ...,\n", " [-0.96968573, 2.31294894, -0.87524825, ..., 0.85838526, 0.48533469, -0.41773027],\n", " [-1.36094308, 2.17788029, -1.78127730, ..., 2.09278774, 0.25282228, -0.36496443],\n", " [-1.69674826, 2.35438418, -1.74168527, ..., 1.36695099, 0.59511113, -0.74147725]],\n", "\n", " [[-1.98284078, 2.31777000, -0.90785271, ..., 0.41170627, 0.50061619, 0.08721463],\n", " [-0.76404583, 1.35577726, -1.36125672, ..., 0.73170459, 0.67842603, 0.16851945],\n", " [-0.95044655, 1.60376561, -1.30299675, ..., 0.57544005, 0.26769355, 0.33433008],\n", " ...,\n", " [-1.47567701, 2.53171301, -1.23207152, ..., 1.29967308, 0.50191855, -0.10343577],\n", " [-1.17308092, 2.31722355, -1.25421047, ..., 1.73911047, 0.21709818, -0.44447583],\n", " [-1.26996231, 3.22289634, -0.88719147, ..., 1.64605021, 0.09731755, -0.76786882]],\n", "\n", " [[-0.58725590, 1.42905438, -1.39500988, ..., 0.21024795, 0.10272825, 0.09179455],\n", " [ 0.17428070, 1.78342295, -1.64217877, ..., 0.81127012, 0.31371105, 0.56344515],\n", " [-0.34916472, 1.83103430, -1.06851172, ..., 0.69243336, 0.13782299, 0.45937473],\n", " ...,\n", " [-1.08686376, 2.30020404, -1.26384079, ..., 1.79982817, 0.51338923, -0.52227837],\n", " [-1.26144814, 2.72396612, -1.37337780, ..., 1.44453299, 0.57420933, -0.33201432],\n", " [-2.20676827, 4.34621811, -3.82886696, 
..., 2.14260173, 1.20336640, -1.37951219]],\n", "\n", " ...,\n", "\n", " [[-0.39141566, 1.85533464, -0.57471782, ..., 1.00623512, 0.46320182, -1.04523599],\n", " [-0.86054784, 2.01717925, -1.44368529, ..., 1.45262301, 0.16571884, 0.59231722],\n", " [-0.73066384, 2.28405023, -1.06989920, ..., 1.58249414, -0.09795550, 0.55030036],\n", " ...,\n", " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]],\n", "\n", " [[-0.16194311, 0.62550521, -1.13234293, ..., 0.07242929, -0.22042468, 0.46362036],\n", " [-0.08306468, 0.57504302, -1.09298003, ..., 0.91096652, -0.06501988, 0.72986233],\n", " [-0.28202093, 0.08014385, -0.94177192, ..., 0.33794850, -0.11664233, 0.44514441],\n", " ...,\n", " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]],\n", "\n", " [[-0.54584920, -0.69092435, -1.35965478, ..., -0.78182435, 0.68747747, 0.98427159],\n", " [ 0.04212743, -1.10618520, -1.43891501, ..., -0.02385022, 0.91146135, 0.52870303],\n", " [-0.29093450, -0.18858244, -1.54873240, ..., -0.13923697, 0.05795169, 0.30663735],\n", " ...,\n", " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]]])\n", "Tensor(shape=[1, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[[ 0. , 1. , 0. , ..., 1. , 0. , 1. ],\n", " [ 0.84147102, 0.54030228, 0.80196184, ..., 1. , 0.00010746, 1. ],\n", " [ 0.90929747, -0.41614681, 0.95814437, ..., 1. , 0.00021492, 1. ],\n", " ...,\n", " [-0.76825470, -0.64014435, 0.63279730, ..., 0.99998462, 0.00515809, 0.99998671],\n", " [-0.95375264, 0.30059254, 0.99899054, ..., 0.99998397, 0.00526555, 0.99998611],\n", " [-0.26237485, 0.96496606, 0.56074661, ..., 0.99998331, 0.00537301, 0.99998558]]])\n" ] } ], "source": [ "b, c, t, f = paddle.shape(x)\n", "x = model.encoder.embed.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))\n", "print(x)\n", "x, pos_emb = model.encoder.embed.pos_enc(x, 0)\n", "print(x)\n", "print(pos_emb)" ] }, { "cell_type": "code", "execution_count": 207, "id": "guilty-cache", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensor(shape=[1, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[[ 0. , 1. , 0. , ..., 1. , 0. , 1. ],\n", " [ 0.84147102, 0.54030228, 0.80196184, ..., 1. , 0.00010746, 1. ],\n", " [ 0.90929747, -0.41614681, 0.95814437, ..., 1. , 0.00021492, 1. ],\n", " ...,\n", " [-0.76825470, -0.64014435, 0.63279730, ..., 0.99998462, 0.00515809, 0.99998671],\n", " [-0.95375264, 0.30059254, 0.99899054, ..., 0.99998397, 0.00526555, 0.99998611],\n", " [-0.26237485, 0.96496606, 0.56074661, ..., 0.99998331, 0.00537301, 0.99998558]]])\n" ] } ], "source": [ "print(pos_emb)" ] }, { "cell_type": "code", "execution_count": 208, "id": "iraqi-payday", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[[ 0.0000000e+00 1.0000000e+00 0.0000000e+00 ... 1.0000000e+00\n", " 0.0000000e+00 1.0000000e+00]\n", " [ 8.4147096e-01 5.4030234e-01 8.0196178e-01 ... 
1.0000000e+00\n", " 1.0746076e-04 1.0000000e+00]\n", " [ 9.0929741e-01 -4.1614684e-01 9.5814437e-01 ... 1.0000000e+00\n", " 2.1492151e-04 1.0000000e+00]\n", " ...\n", " [ 9.5625257e-01 -2.9254240e-01 4.8925215e-01 ... 8.3807874e-01\n", " 5.1154459e-01 8.5925674e-01]\n", " [ 2.7049953e-01 -9.6272010e-01 9.9170387e-01 ... 8.3801574e-01\n", " 5.1163691e-01 8.5920173e-01]\n", " [-6.6394955e-01 -7.4777740e-01 6.9544029e-01 ... 8.3795273e-01\n", " 5.1172924e-01 8.5914677e-01]]]\n", "[1, 5000, 256]\n" ] } ], "source": [ "import torch\n", "import math\n", "import numpy as np\n", "\n", "max_len=5000\n", "d_model=256\n", "\n", "pe = torch.zeros(max_len, d_model)\n", "position = torch.arange(0, max_len,\n", " dtype=torch.float32).unsqueeze(1)\n", "toruch_position = position\n", "div_term = torch.exp(\n", " torch.arange(0, d_model, 2, dtype=torch.float32) *\n", " -(math.log(10000.0) / d_model))\n", "tourch_div_term = div_term.cpu().detach().numpy()\n", "\n", "torhc_sin = torch.sin(position * div_term)\n", "torhc_cos = torch.cos(position * div_term)\n", "\n", "np_sin = np.sin((position * div_term).cpu().detach().numpy())\n", "np_cos = np.cos((position * div_term).cpu().detach().numpy())\n", "pe[:, 0::2] = torhc_sin\n", "pe[:, 1::2] = torhc_cos\n", "pe = pe.unsqueeze(0) \n", "tourch_pe = pe.cpu().detach().numpy()\n", "print(tourch_pe)\n", "bak_pe = model.encoder.embed.pos_enc.pe\n", "print(bak_pe.shape)\n", "model.encoder.embed.pos_enc.pe = paddle.to_tensor(tourch_pe)" ] }, { "cell_type": "code", "execution_count": 210, "id": "exempt-cloud", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "True\n" ] } ], "source": [ "xs = model.encoder.global_cmvn(feat)\n", "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", "\n", "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", "#print(xs)\n", "data = np.load(\".notebook/enc_embed.npz\")\n", "torch_pos_emb=data['pos_emb']\n", "torch_xs = data['embed_out']\n", "print(np.allclose(xs.numpy(), torch_xs))\n", "print(np.allclose(pos_emb.numpy(), torch_pos_emb))" ] }, { "cell_type": "code", "execution_count": null, "id": "composite-involvement", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 269, "id": "handed-harris", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "True\n", "True\n", "True\n", "True\n", "True\n", "True\n", "False\n", "True\n", "[256, 2048]\n", "[2048]\n", "[2048, 256]\n", "[256]\n", "--------ff-------\n", "True\n", "False\n", "False\n", "False\n", "False\n", "True\n", "linear_714.w_0 True\n", "linear_714.b_0 True\n", "linear_715.w_0 True\n", "linear_715.b_0 True\n", "False\n", "True\n" ] } ], "source": [ "xs = model.encoder.global_cmvn(feat)\n", "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", "\n", "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", "masks = masks.astype(paddle.bool)\n", "mask_pad = masks.logical_not()\n", "decoding_chunk_size=0\n", "num_decoding_left_chunks=-1\n", "chunk_masks = add_optional_chunk_mask(\n", " xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n", " decoding_chunk_size, model.encoder.static_chunk_size,\n", " num_decoding_left_chunks)\n", "\n", "#print(chunk_masks)\n", "data = np.load(\".notebook/enc_embed.npz\")\n", "torch_pos_emb=data['pos_emb']\n", "torch_xs = data['embed_out']\n", "torch_chunk_masks = data['chunk_masks']\n", "torch_mask_pad = data['mask_pad']\n", "print(np.allclose(xs.numpy(), 
torch_xs))\n", "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n", "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n", "np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n", "\n", "for layer in model.encoder.encoders:\n", " #xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", " print(layer.feed_forward_macaron is not None)\n", " print(layer.normalize_before)\n", " \n", " data = np.load('.notebook/enc_0_norm_ff.npz')\n", " t_norm_ff = data['norm_ff']\n", " t_xs = data['xs']\n", " \n", " \n", " x = xs\n", " print(np.allclose(t_xs, x.numpy()))\n", " residual = x\n", " print(np.allclose(t_xs, residual.numpy()))\n", " x_nrom = layer.norm_ff_macaron(x)\n", " print(np.allclose(t.numpy(), x_nrom.numpy()))\n", " print(np.allclose(t_norm_ff, x_nrom.numpy()))\n", "# for n, p in layer.norm_ff_macaron.state_dict().items():\n", "# print(n, p)\n", "# pass\n", "\n", " layer.eval()\n", " x_nrom = paddle.to_tensor(t_norm_ff)\n", " print(np.allclose(t_norm_ff, x_nrom.numpy()))\n", " x = residual + layer.ff_scale * layer.feed_forward_macaron(x_nrom)\n", " \n", " ps=[]\n", " for n, p in layer.feed_forward_macaron.state_dict().items():\n", " #print(n, p)\n", " ps.append(p)\n", " print(p.shape)\n", " pass\n", "\n", " x_nrom = paddle.to_tensor(t_norm_ff)\n", " ff_l_x = layer.feed_forward_macaron.w_1(x_nrom)\n", " ff_l_a_x = layer.feed_forward_macaron.activation(ff_l_x)\n", " ff_l_a_l_x = layer.feed_forward_macaron.w_2(ff_l_a_x)\n", " data = np.load('.notebook/enc_0_ff_out.npz', allow_pickle=True)\n", " t_norm_ff = data['norm_ff']\n", " t_ff_out = data['ff_out']\n", " t_ff_l_x = data['ff_l_x']\n", " t_ff_l_a_x = data['ff_l_a_x']\n", " t_ff_l_a_l_x = data['ff_l_a_l_x']\n", " t_ps = data['ps']\n", " \n", " print(\"--------ff-------\")\n", " print(np.allclose(x_nrom.numpy(), t_norm_ff))\n", " print(np.allclose(x.numpy(), t_ff_out))\n", " print(np.allclose(ff_l_x.numpy(), t_ff_l_x))\n", " print(np.allclose(ff_l_a_x.numpy(), t_ff_l_a_x))\n", " print(np.allclose(ff_l_a_l_x.numpy(), t_ff_l_a_l_x))\n", " \n", " print(np.allclose(ff_l_x.numpy(), t_ff_l_x, atol=1e-6))\n", " for p, t_p in zip(ps, t_ps):\n", " print(p.name, np.allclose(p.numpy(), t_p.T))\n", " \n", " \n", "# residual = x\n", "# x = layer.norm_mha(x)\n", "# x_q = x\n", " \n", " data = np.load('.notebook/enc_0_selattn_out.npz', allow_pickle=True)\n", " tx_q = data['x_q']\n", " tx = data['x']\n", " tpos_emb=data['pos_emb']\n", " tmask=data['mask']\n", " tt_x_att=data['x_att']\n", " x_q = paddle.to_tensor(tx_q)\n", " x = paddle.to_tensor(tx)\n", " pos_emb = paddle.to_tensor(tpos_emb)\n", " mask = paddle.to_tensor(tmask)\n", " \n", " x_att = layer.self_attn(x_q, x, x, pos_emb, mask)\n", " print(np.allclose(x_att.numpy(), t_x_att))\n", " print(np.allclose(x_att.numpy(), t_x_att, atol=1e-6))\n", " \n", " break" ] }, { "cell_type": "code", "execution_count": 270, "id": "sonic-thumb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "True\n", "False\n", "True\n" ] } ], "source": [ "xs = model.encoder.global_cmvn(feat)\n", "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", "\n", "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", "masks = masks.astype(paddle.bool)\n", "mask_pad = masks.logical_not()\n", "decoding_chunk_size=0\n", "num_decoding_left_chunks=-1\n", "chunk_masks = add_optional_chunk_mask(\n", " xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n", " decoding_chunk_size, model.encoder.static_chunk_size,\n", 
" num_decoding_left_chunks)\n", "\n", "#print(chunk_masks)\n", "data = np.load(\".notebook/enc_embed.npz\")\n", "torch_pos_emb=data['pos_emb']\n", "torch_xs = data['embed_out']\n", "torch_chunk_masks = data['chunk_masks']\n", "torch_mask_pad = data['mask_pad']\n", "print(np.allclose(xs.numpy(), torch_xs))\n", "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n", "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n", "np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n", "\n", "\n", "for layer in model.encoder.encoders:\n", " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", " break\n", "data = np.load('.notebook/enc_0.npz')\n", "torch_xs = data['enc_0']\n", "print(np.allclose(xs.numpy(), torch_xs))\n", "print(np.allclose(xs.numpy(), torch_xs, atol=1e-6))\n" ] }, { "cell_type": "code", "execution_count": 273, "id": "brave-latino", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "True\n", "--------layers_______\n", "False\n", "True\n", "[[-0.70194244 0.56254214 0.6880346 ... 1.1237319 0.7803924\n", " 1.1369387 ]\n", " [-0.7787783 0.3912667 0.71887773 ... 1.251882 0.886168\n", " 1.3173451 ]\n", " [-0.95908964 0.6346029 0.87671334 ... 0.98183745 0.7440111\n", " 1.2903278 ]\n", " ...\n", " [-1.0732255 0.67236906 0.92303115 ... 0.9075458 0.8176712\n", " 1.3239655 ]\n", " [-1.1654118 0.6819967 0.6939453 ... 1.2238353 0.8028295\n", " 1.4506507 ]\n", " [-1.2732092 0.7145806 0.75819594 ... 0.94154835 0.8774845\n", " 1.2623049 ]]\n", "xxxxxx\n", "[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n", " 1.1369387 ]\n", " [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n", " 1.3173454 ]\n", " [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n", " 1.2903274 ]\n", " ...\n", " [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n", " 1.3239657 ]\n", " [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n", " 1.4506509 ]\n", " [-1.273209 0.71458095 0.75819623 ... 
0.9415484 0.8774842\n", " 1.2623055 ]]\n" ] } ], "source": [ "xs = model.encoder.global_cmvn(feat)\n", "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", "\n", "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", "masks = masks.astype(paddle.bool)\n", "mask_pad = masks.logical_not()\n", "decoding_chunk_size=0\n", "num_decoding_left_chunks=-1\n", "chunk_masks = add_optional_chunk_mask(\n", " xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n", " decoding_chunk_size, model.encoder.static_chunk_size,\n", " num_decoding_left_chunks)\n", "\n", "#print(chunk_masks)\n", "data = np.load(\".notebook/enc_embed.npz\")\n", "torch_pos_emb=data['pos_emb']\n", "torch_xs = data['embed_out']\n", "torch_chunk_masks = data['chunk_masks']\n", "torch_mask_pad = data['mask_pad']\n", "print(np.allclose(xs.numpy(), torch_xs))\n", "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n", "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n", "np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n", "\n", "print(\"--------layers_______\")\n", "i =0\n", "for layer in model.encoder.encoders:\n", " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", " i+=1\n", "# if i == 2:\n", "# data = np.load('.notebook/enc_2.npz')\n", "# torch_xs = data['enc_2']\n", "# print(np.allclose(xs.numpy(), torch_xs))\n", "# print(np.allclose(xs.numpy(), torch_xs, atol=1e-5))\n", "# print(xs[0].numpy())\n", "# print('xxxxxx')\n", "# print(torch_xs[0])\n", "# print('----i==2')\n", "data = np.load('.notebook/enc_all.npz')\n", "torch_xs = data['enc_all']\n", "print(np.allclose(xs.numpy(), torch_xs))\n", "print(np.allclose(xs.numpy(), torch_xs, atol=1e-5))\n", "print(xs[0].numpy())\n", "print('xxxxxx')\n", "print(torch_xs[0])" ] }, { "cell_type": "code", "execution_count": 64, "id": "municipal-stock", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 278, "id": "macro-season", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-0.7019424 0.5625421 0.68803453 ... 1.1237317 0.7803923\n", " 1.1369386 ]\n", " [-0.7787783 0.39126673 0.71887773 ... 1.251882 0.886168\n", " 1.3173451 ]\n", " [-0.95908964 0.6346029 0.87671334 ... 0.98183745 0.7440111\n", " 1.2903278 ]\n", " ...\n", " [-1.0732255 0.67236906 0.92303115 ... 0.9075458 0.8176712\n", " 1.3239655 ]\n", " [-1.1654117 0.68199664 0.6939452 ... 1.2238352 0.8028294\n", " 1.4506506 ]\n", " [-1.2732091 0.71458054 0.7581958 ... 0.9415482 0.8774844\n", " 1.2623048 ]]\n", "---\n", "[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n", " 1.1369387 ]\n", " [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n", " 1.3173454 ]\n", " [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n", " 1.2903274 ]\n", " ...\n", " [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n", " 1.3239657 ]\n", " [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n", " 1.4506509 ]\n", " [-1.2732087 0.71458083 0.7581961 ... 
0.9415482 0.877484\n", " 1.2623053 ]]\n", "False\n", "True\n", "False\n" ] } ], "source": [ "encoder_out, mask = model.encoder(feat, feat_len)\n", "print(encoder_out.numpy()[0])\n", "print(\"---\")\n", "print(torch_encoder_out[0])\n", "print(np.allclose(torch_encoder_out, encoder_out.numpy()))\n", "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-5))\n", "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-6))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 5 }