You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/.notebook/u2_model.ipynb

4609 lines
257 KiB

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "choice-grade",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x\n"
]
},
{
"data": {
"text/plain": [
"'/workspace/DeepSpeech-2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "broke-broad",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n",
"register user softmax to paddle, remove this when fixed!\n",
"register user log_softmax to paddle, remove this when fixed!\n",
"register user sigmoid to paddle, remove this when fixed!\n",
"register user log_sigmoid to paddle, remove this when fixed!\n",
"register user relu to paddle, remove this when fixed!\n",
"override cat of paddle if exists or register, remove this when fixed!\n",
"override item of paddle.Tensor if exists or register, remove this when fixed!\n",
"override long of paddle.Tensor if exists or register, remove this when fixed!\n",
"override new_full of paddle.Tensor if exists or register, remove this when fixed!\n",
"override eq of paddle.Tensor if exists or register, remove this when fixed!\n",
"override eq of paddle if exists or register, remove this when fixed!\n",
"override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n",
"override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n",
"register user view to paddle.Tensor, remove this when fixed!\n",
"register user view_as to paddle.Tensor, remove this when fixed!\n",
"register user masked_fill to paddle.Tensor, remove this when fixed!\n",
"register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
"register user fill_ to paddle.Tensor, remove this when fixed!\n",
"register user repeat to paddle.Tensor, remove this when fixed!\n",
"register user softmax to paddle.Tensor, remove this when fixed!\n",
"register user sigmoid to paddle.Tensor, remove this when fixed!\n",
"register user relu to paddle.Tensor, remove this when fixed!\n",
"register user type_as to paddle.Tensor, remove this when fixed!\n",
"register user to to paddle.Tensor, remove this when fixed!\n",
"register user float to paddle.Tensor, remove this when fixed!\n",
"register user tolist to paddle.Tensor, remove this when fixed!\n",
"register user glu to paddle.nn.functional, remove this when fixed!\n",
"override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n",
"register user Module to paddle.nn, remove this when fixed!\n",
"register user ModuleList to paddle.nn, remove this when fixed!\n",
"register user GLU to paddle.nn, remove this when fixed!\n",
"register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
"register user export to paddle.jit, remove this when fixed!\n"
]
}
],
"source": [
"import numpy as np\n",
"import paddle\n",
"from yacs.config import CfgNode as CN\n",
"\n",
"from deepspeech.models.u2 import U2Model\n",
"from deepspeech.utils.layer_tools import print_params\n",
"from deepspeech.utils.layer_tools import summary"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "permanent-summary",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n",
"[INFO 2021/04/20 03:32:21 u2.py:834] U2 Encoder type: conformer\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n",
"encoder.embed.conv.0.bias | [256] | 256 | True\n",
"encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n",
"encoder.embed.conv.2.bias | [256] | 256 | True\n",
"encoder.embed.out.0.weight | [4864, 256] | 1245184 | True\n",
"encoder.embed.out.0.bias | [256] | 256 | True\n",
"encoder.after_norm.weight | [256] | 256 | True\n",
"encoder.after_norm.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
"encoder.encoders.0.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.0.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.0.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.0.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.0.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.0.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.0.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
"encoder.encoders.0.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
"encoder.encoders.0.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
"encoder.encoders.0.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
"encoder.encoders.0.conv_module.norm.weight | [256] | 256 | True\n",
"encoder.encoders.0.conv_module.norm.bias | [256] | 256 | True\n",
"encoder.encoders.0.conv_module.norm._mean | [256] | 256 | False\n",
"encoder.encoders.0.conv_module.norm._variance | [256] | 256 | False\n",
"encoder.encoders.0.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
"encoder.encoders.0.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
"encoder.encoders.0.norm_ff.weight | [256] | 256 | True\n",
"encoder.encoders.0.norm_ff.bias | [256] | 256 | True\n",
"encoder.encoders.0.norm_mha.weight | [256] | 256 | True\n",
"encoder.encoders.0.norm_mha.bias | [256] | 256 | True\n",
"encoder.encoders.0.norm_ff_macaron.weight | [256] | 256 | True\n",
"encoder.encoders.0.norm_ff_macaron.bias | [256] | 256 | True\n",
"encoder.encoders.0.norm_conv.weight | [256] | 256 | True\n",
"encoder.encoders.0.norm_conv.bias | [256] | 256 | True\n",
"encoder.encoders.0.norm_final.weight | [256] | 256 | True\n",
"encoder.encoders.0.norm_final.bias | [256] | 256 | True\n",
"encoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.0.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
"encoder.encoders.1.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.1.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.1.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.1.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.1.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.1.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.1.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.1.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
"encoder.encoders.1.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
"encoder.encoders.1.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
"encoder.encoders.1.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
"encoder.encoders.1.conv_module.norm.weight | [256] | 256 | True\n",
"encoder.encoders.1.conv_module.norm.bias | [256] | 256 | True\n",
"encoder.encoders.1.conv_module.norm._mean | [256] | 256 | False\n",
"encoder.encoders.1.conv_module.norm._variance | [256] | 256 | False\n",
"encoder.encoders.1.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
"encoder.encoders.1.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
"encoder.encoders.1.norm_ff.weight | [256] | 256 | True\n",
"encoder.encoders.1.norm_ff.bias | [256] | 256 | True\n",
"encoder.encoders.1.norm_mha.weight | [256] | 256 | True\n",
"encoder.encoders.1.norm_mha.bias | [256] | 256 | True\n",
"encoder.encoders.1.norm_ff_macaron.weight | [256] | 256 | True\n",
"encoder.encoders.1.norm_ff_macaron.bias | [256] | 256 | True\n",
"encoder.encoders.1.norm_conv.weight | [256] | 256 | True\n",
"encoder.encoders.1.norm_conv.bias | [256] | 256 | True\n",
"encoder.encoders.1.norm_final.weight | [256] | 256 | True\n",
"encoder.encoders.1.norm_final.bias | [256] | 256 | True\n",
"encoder.encoders.1.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.1.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
"encoder.encoders.2.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.2.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.2.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.2.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.2.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.2.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.2.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.2.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
"encoder.encoders.2.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
"encoder.encoders.2.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
"encoder.encoders.2.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
"encoder.encoders.2.conv_module.norm.weight | [256] | 256 | True\n",
"encoder.encoders.2.conv_module.norm.bias | [256] | 256 | True\n",
"encoder.encoders.2.conv_module.norm._mean | [256] | 256 | False\n",
"encoder.encoders.2.conv_module.norm._variance | [256] | 256 | False\n",
"encoder.encoders.2.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
"encoder.encoders.2.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
"encoder.encoders.2.norm_ff.weight | [256] | 256 | True\n",
"encoder.encoders.2.norm_ff.bias | [256] | 256 | True\n",
"encoder.encoders.2.norm_mha.weight | [256] | 256 | True\n",
"encoder.encoders.2.norm_mha.bias | [256] | 256 | True\n",
"encoder.encoders.2.norm_ff_macaron.weight | [256] | 256 | True\n",
"encoder.encoders.2.norm_ff_macaron.bias | [256] | 256 | True\n",
"encoder.encoders.2.norm_conv.weight | [256] | 256 | True\n",
"encoder.encoders.2.norm_conv.bias | [256] | 256 | True\n",
"encoder.encoders.2.norm_final.weight | [256] | 256 | True\n",
"encoder.encoders.2.norm_final.bias | [256] | 256 | True\n",
"encoder.encoders.2.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.2.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
"encoder.encoders.3.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.3.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.3.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.3.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.3.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.3.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.3.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.3.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
"encoder.encoders.3.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
"encoder.encoders.3.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
"encoder.encoders.3.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
"encoder.encoders.3.conv_module.norm.weight | [256] | 256 | True\n",
"encoder.encoders.3.conv_module.norm.bias | [256] | 256 | True\n",
"encoder.encoders.3.conv_module.norm._mean | [256] | 256 | False\n",
"encoder.encoders.3.conv_module.norm._variance | [256] | 256 | False\n",
"encoder.encoders.3.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
"encoder.encoders.3.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
"encoder.encoders.3.norm_ff.weight | [256] | 256 | True\n",
"encoder.encoders.3.norm_ff.bias | [256] | 256 | True\n",
"encoder.encoders.3.norm_mha.weight | [256] | 256 | True\n",
"encoder.encoders.3.norm_mha.bias | [256] | 256 | True\n",
"encoder.encoders.3.norm_ff_macaron.weight | [256] | 256 | True\n",
"encoder.encoders.3.norm_ff_macaron.bias | [256] | 256 | True\n",
"encoder.encoders.3.norm_conv.weight | [256] | 256 | True\n",
"encoder.encoders.3.norm_conv.bias | [256] | 256 | True\n",
"encoder.encoders.3.norm_final.weight | [256] | 256 | True\n",
"encoder.encoders.3.norm_final.bias | [256] | 256 | True\n",
"encoder.encoders.3.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.3.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
"encoder.encoders.4.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.4.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.4.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.4.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.4.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.4.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.4.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.4.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
"encoder.encoders.4.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
"encoder.encoders.4.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
"encoder.encoders.4.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
"encoder.encoders.4.conv_module.norm.weight | [256] | 256 | True\n",
"encoder.encoders.4.conv_module.norm.bias | [256] | 256 | True\n",
"encoder.encoders.4.conv_module.norm._mean | [256] | 256 | False\n",
"encoder.encoders.4.conv_module.norm._variance | [256] | 256 | False\n",
"encoder.encoders.4.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
"encoder.encoders.4.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
"encoder.encoders.4.norm_ff.weight | [256] | 256 | True\n",
"encoder.encoders.4.norm_ff.bias | [256] | 256 | True\n",
"encoder.encoders.4.norm_mha.weight | [256] | 256 | True\n",
"encoder.encoders.4.norm_mha.bias | [256] | 256 | True\n",
"encoder.encoders.4.norm_ff_macaron.weight | [256] | 256 | True\n",
"encoder.encoders.4.norm_ff_macaron.bias | [256] | 256 | True\n",
"encoder.encoders.4.norm_conv.weight | [256] | 256 | True\n",
"encoder.encoders.4.norm_conv.bias | [256] | 256 | True\n",
"encoder.encoders.4.norm_final.weight | [256] | 256 | True\n",
"encoder.encoders.4.norm_final.bias | [256] | 256 | True\n",
"encoder.encoders.4.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.4.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
"encoder.encoders.5.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.5.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.5.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.5.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.5.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.5.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.5.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.5.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
"encoder.encoders.5.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
"encoder.encoders.5.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
"encoder.encoders.5.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
"encoder.encoders.5.conv_module.norm.weight | [256] | 256 | True\n",
"encoder.encoders.5.conv_module.norm.bias | [256] | 256 | True\n",
"encoder.encoders.5.conv_module.norm._mean | [256] | 256 | False\n",
"encoder.encoders.5.conv_module.norm._variance | [256] | 256 | False\n",
"encoder.encoders.5.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
"encoder.encoders.5.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
"encoder.encoders.5.norm_ff.weight | [256] | 256 | True\n",
"encoder.encoders.5.norm_ff.bias | [256] | 256 | True\n",
"encoder.encoders.5.norm_mha.weight | [256] | 256 | True\n",
"encoder.encoders.5.norm_mha.bias | [256] | 256 | True\n",
"encoder.encoders.5.norm_ff_macaron.weight | [256] | 256 | True\n",
"encoder.encoders.5.norm_ff_macaron.bias | [256] | 256 | True\n",
"encoder.encoders.5.norm_conv.weight | [256] | 256 | True\n",
"encoder.encoders.5.norm_conv.bias | [256] | 256 | True\n",
"encoder.encoders.5.norm_final.weight | [256] | 256 | True\n",
"encoder.encoders.5.norm_final.bias | [256] | 256 | True\n",
"encoder.encoders.5.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.5.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
"encoder.encoders.6.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.6.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.6.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.6.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.6.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.6.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.6.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.6.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
"encoder.encoders.6.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
"encoder.encoders.6.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
"encoder.encoders.6.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
"encoder.encoders.6.conv_module.norm.weight | [256] | 256 | True\n",
"encoder.encoders.6.conv_module.norm.bias | [256] | 256 | True\n",
"encoder.encoders.6.conv_module.norm._mean | [256] | 256 | False\n",
"encoder.encoders.6.conv_module.norm._variance | [256] | 256 | False\n",
"encoder.encoders.6.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
"encoder.encoders.6.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
"encoder.encoders.6.norm_ff.weight | [256] | 256 | True\n",
"encoder.encoders.6.norm_ff.bias | [256] | 256 | True\n",
"encoder.encoders.6.norm_mha.weight | [256] | 256 | True\n",
"encoder.encoders.6.norm_mha.bias | [256] | 256 | True\n",
"encoder.encoders.6.norm_ff_macaron.weight | [256] | 256 | True\n",
"encoder.encoders.6.norm_ff_macaron.bias | [256] | 256 | True\n",
"encoder.encoders.6.norm_conv.weight | [256] | 256 | True\n",
"encoder.encoders.6.norm_conv.bias | [256] | 256 | True\n",
"encoder.encoders.6.norm_final.weight | [256] | 256 | True\n",
"encoder.encoders.6.norm_final.bias | [256] | 256 | True\n",
"encoder.encoders.6.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.6.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
"encoder.encoders.7.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.7.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.7.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.7.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.7.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.7.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.7.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.7.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
"encoder.encoders.7.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
"encoder.encoders.7.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
"encoder.encoders.7.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
"encoder.encoders.7.conv_module.norm.weight | [256] | 256 | True\n",
"encoder.encoders.7.conv_module.norm.bias | [256] | 256 | True\n",
"encoder.encoders.7.conv_module.norm._mean | [256] | 256 | False\n",
"encoder.encoders.7.conv_module.norm._variance | [256] | 256 | False\n",
"encoder.encoders.7.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
"encoder.encoders.7.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
"encoder.encoders.7.norm_ff.weight | [256] | 256 | True\n",
"encoder.encoders.7.norm_ff.bias | [256] | 256 | True\n",
"encoder.encoders.7.norm_mha.weight | [256] | 256 | True\n",
"encoder.encoders.7.norm_mha.bias | [256] | 256 | True\n",
"encoder.encoders.7.norm_ff_macaron.weight | [256] | 256 | True\n",
"encoder.encoders.7.norm_ff_macaron.bias | [256] | 256 | True\n",
"encoder.encoders.7.norm_conv.weight | [256] | 256 | True\n",
"encoder.encoders.7.norm_conv.bias | [256] | 256 | True\n",
"encoder.encoders.7.norm_final.weight | [256] | 256 | True\n",
"encoder.encoders.7.norm_final.bias | [256] | 256 | True\n",
"encoder.encoders.7.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.7.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
"encoder.encoders.8.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.8.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.8.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.8.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.8.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.8.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.8.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.8.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
"encoder.encoders.8.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
"encoder.encoders.8.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
"encoder.encoders.8.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
"encoder.encoders.8.conv_module.norm.weight | [256] | 256 | True\n",
"encoder.encoders.8.conv_module.norm.bias | [256] | 256 | True\n",
"encoder.encoders.8.conv_module.norm._mean | [256] | 256 | False\n",
"encoder.encoders.8.conv_module.norm._variance | [256] | 256 | False\n",
"encoder.encoders.8.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
"encoder.encoders.8.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
"encoder.encoders.8.norm_ff.weight | [256] | 256 | True\n",
"encoder.encoders.8.norm_ff.bias | [256] | 256 | True\n",
"encoder.encoders.8.norm_mha.weight | [256] | 256 | True\n",
"encoder.encoders.8.norm_mha.bias | [256] | 256 | True\n",
"encoder.encoders.8.norm_ff_macaron.weight | [256] | 256 | True\n",
"encoder.encoders.8.norm_ff_macaron.bias | [256] | 256 | True\n",
"encoder.encoders.8.norm_conv.weight | [256] | 256 | True\n",
"encoder.encoders.8.norm_conv.bias | [256] | 256 | True\n",
"encoder.encoders.8.norm_final.weight | [256] | 256 | True\n",
"encoder.encoders.8.norm_final.bias | [256] | 256 | True\n",
"encoder.encoders.8.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.8.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
"encoder.encoders.9.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.9.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.9.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.9.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.9.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.9.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.9.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.9.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
"encoder.encoders.9.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
"encoder.encoders.9.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
"encoder.encoders.9.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
"encoder.encoders.9.conv_module.norm.weight | [256] | 256 | True\n",
"encoder.encoders.9.conv_module.norm.bias | [256] | 256 | True\n",
"encoder.encoders.9.conv_module.norm._mean | [256] | 256 | False\n",
"encoder.encoders.9.conv_module.norm._variance | [256] | 256 | False\n",
"encoder.encoders.9.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
"encoder.encoders.9.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
"encoder.encoders.9.norm_ff.weight | [256] | 256 | True\n",
"encoder.encoders.9.norm_ff.bias | [256] | 256 | True\n",
"encoder.encoders.9.norm_mha.weight | [256] | 256 | True\n",
"encoder.encoders.9.norm_mha.bias | [256] | 256 | True\n",
"encoder.encoders.9.norm_ff_macaron.weight | [256] | 256 | True\n",
"encoder.encoders.9.norm_ff_macaron.bias | [256] | 256 | True\n",
"encoder.encoders.9.norm_conv.weight | [256] | 256 | True\n",
"encoder.encoders.9.norm_conv.bias | [256] | 256 | True\n",
"encoder.encoders.9.norm_final.weight | [256] | 256 | True\n",
"encoder.encoders.9.norm_final.bias | [256] | 256 | True\n",
"encoder.encoders.9.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.9.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
"encoder.encoders.10.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.10.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.10.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.10.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.10.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.10.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.10.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.10.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
"encoder.encoders.10.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
"encoder.encoders.10.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
"encoder.encoders.10.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
"encoder.encoders.10.conv_module.norm.weight | [256] | 256 | True\n",
"encoder.encoders.10.conv_module.norm.bias | [256] | 256 | True\n",
"encoder.encoders.10.conv_module.norm._mean | [256] | 256 | False\n",
"encoder.encoders.10.conv_module.norm._variance | [256] | 256 | False\n",
"encoder.encoders.10.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
"encoder.encoders.10.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
"encoder.encoders.10.norm_ff.weight | [256] | 256 | True\n",
"encoder.encoders.10.norm_ff.bias | [256] | 256 | True\n",
"encoder.encoders.10.norm_mha.weight | [256] | 256 | True\n",
"encoder.encoders.10.norm_mha.bias | [256] | 256 | True\n",
"encoder.encoders.10.norm_ff_macaron.weight | [256] | 256 | True\n",
"encoder.encoders.10.norm_ff_macaron.bias | [256] | 256 | True\n",
"encoder.encoders.10.norm_conv.weight | [256] | 256 | True\n",
"encoder.encoders.10.norm_conv.bias | [256] | 256 | True\n",
"encoder.encoders.10.norm_final.weight | [256] | 256 | True\n",
"encoder.encoders.10.norm_final.bias | [256] | 256 | True\n",
"encoder.encoders.10.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.10.concat_linear.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.pos_bias_u | [4, 64] | 256 | True\n",
"encoder.encoders.11.self_attn.pos_bias_v | [4, 64] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_q.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_k.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_v.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_out.bias | [256] | 256 | True\n",
"encoder.encoders.11.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n",
"encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.11.feed_forward.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.11.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n",
"encoder.encoders.11.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n",
"encoder.encoders.11.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n",
"encoder.encoders.11.feed_forward_macaron.w_2.bias | [256] | 256 | True\n",
"encoder.encoders.11.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n",
"encoder.encoders.11.conv_module.pointwise_conv1.bias | [512] | 512 | True\n",
"encoder.encoders.11.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n",
"encoder.encoders.11.conv_module.depthwise_conv.bias | [256] | 256 | True\n",
"encoder.encoders.11.conv_module.norm.weight | [256] | 256 | True\n",
"encoder.encoders.11.conv_module.norm.bias | [256] | 256 | True\n",
"encoder.encoders.11.conv_module.norm._mean | [256] | 256 | False\n",
"encoder.encoders.11.conv_module.norm._variance | [256] | 256 | False\n",
"encoder.encoders.11.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n",
"encoder.encoders.11.conv_module.pointwise_conv2.bias | [256] | 256 | True\n",
"encoder.encoders.11.norm_ff.weight | [256] | 256 | True\n",
"encoder.encoders.11.norm_ff.bias | [256] | 256 | True\n",
"encoder.encoders.11.norm_mha.weight | [256] | 256 | True\n",
"encoder.encoders.11.norm_mha.bias | [256] | 256 | True\n",
"encoder.encoders.11.norm_ff_macaron.weight | [256] | 256 | True\n",
"encoder.encoders.11.norm_ff_macaron.bias | [256] | 256 | True\n",
"encoder.encoders.11.norm_conv.weight | [256] | 256 | True\n",
"encoder.encoders.11.norm_conv.bias | [256] | 256 | True\n",
"encoder.encoders.11.norm_final.weight | [256] | 256 | True\n",
"encoder.encoders.11.norm_final.bias | [256] | 256 | True\n",
"encoder.encoders.11.concat_linear.weight | [512, 256] | 131072 | True\n",
"encoder.encoders.11.concat_linear.bias | [256] | 256 | True\n",
"decoder.embed.0.weight | [4233, 256] | 1083648 | True\n",
"decoder.after_norm.weight | [256] | 256 | True\n",
"decoder.after_norm.bias | [256] | 256 | True\n",
"decoder.output_layer.weight | [256, 4233] | 1083648 | True\n",
"decoder.output_layer.bias | [4233] | 4233 | True\n",
"decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.0.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.0.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.0.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.0.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.0.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.0.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.0.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.0.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.0.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.1.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.1.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.1.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.1.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.1.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.1.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.1.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.1.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.1.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.2.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.2.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.2.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.2.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.2.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.2.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.2.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.2.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.2.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.3.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.3.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.3.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.3.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.3.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.3.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.3.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.3.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.3.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.4.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.4.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.4.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.4.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.4.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.4.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.4.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.4.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.4.concat_linear2.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_q.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_k.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_v.bias | [256] | 256 | True\n",
"decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_out.bias | [256] | 256 | True\n",
"decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n",
"decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n",
"decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n",
"decoder.decoders.5.feed_forward.w_2.bias | [256] | 256 | True\n",
"decoder.decoders.5.norm1.weight | [256] | 256 | True\n",
"decoder.decoders.5.norm1.bias | [256] | 256 | True\n",
"decoder.decoders.5.norm2.weight | [256] | 256 | True\n",
"decoder.decoders.5.norm2.bias | [256] | 256 | True\n",
"decoder.decoders.5.norm3.weight | [256] | 256 | True\n",
"decoder.decoders.5.norm3.bias | [256] | 256 | True\n",
"decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.5.concat_linear1.bias | [256] | 256 | True\n",
"decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072 | True\n",
"decoder.decoders.5.concat_linear2.bias | [256] | 256 | True\n",
"ctc.ctc_lo.weight | [256, 4233] | 1083648 | True\n",
"ctc.ctc_lo.bias | [4233] | 4233 | True\n",
"Total parameters: 687.0, 49355282.0 elements.\n"
]
}
],
"source": [
"conf_str='examples/aishell/s1/conf/conformer.yaml'\n",
"cfg = CN().load_cfg(open(conf_str))\n",
"cfg.model.input_dim = 80\n",
"cfg.model.output_dim = 4233\n",
"cfg.model.cmvn_file = \"/workspace/wenet/examples/aishell/s0/raw_wav/train/global_cmvn\"\n",
"cfg.model.cmvn_file_type = 'json'\n",
"cfg.freeze()\n",
"\n",
"model = U2Model(cfg.model)\n",
"print_params(model)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "sapphire-agent",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"encoder.global_cmvn.mean | [80] | 80\n",
"encoder.global_cmvn.istd | [80] | 80\n",
"encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304\n",
"encoder.embed.conv.0.bias | [256] | 256\n",
"encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824\n",
"encoder.embed.conv.2.bias | [256] | 256\n",
"encoder.embed.out.0.weight | [4864, 256] | 1245184\n",
"encoder.embed.out.0.bias | [256] | 256\n",
"encoder.after_norm.weight | [256] | 256\n",
"encoder.after_norm.bias | [256] | 256\n",
"encoder.encoders.0.self_attn.pos_bias_u | [4, 64] | 256\n",
"encoder.encoders.0.self_attn.pos_bias_v | [4, 64] | 256\n",
"encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536\n",
"encoder.encoders.0.self_attn.linear_q.bias | [256] | 256\n",
"encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536\n",
"encoder.encoders.0.self_attn.linear_k.bias | [256] | 256\n",
"encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536\n",
"encoder.encoders.0.self_attn.linear_v.bias | [256] | 256\n",
"encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536\n",
"encoder.encoders.0.self_attn.linear_out.bias | [256] | 256\n",
"encoder.encoders.0.self_attn.linear_pos.weight | [256, 256] | 65536\n",
"encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048\n",
"encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.0.feed_forward.w_2.bias | [256] | 256\n",
"encoder.encoders.0.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.0.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
"encoder.encoders.0.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.0.feed_forward_macaron.w_2.bias | [256] | 256\n",
"encoder.encoders.0.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
"encoder.encoders.0.conv_module.pointwise_conv1.bias | [512] | 512\n",
"encoder.encoders.0.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
"encoder.encoders.0.conv_module.depthwise_conv.bias | [256] | 256\n",
"encoder.encoders.0.conv_module.norm.weight | [256] | 256\n",
"encoder.encoders.0.conv_module.norm.bias | [256] | 256\n",
"encoder.encoders.0.conv_module.norm._mean | [256] | 256\n",
"encoder.encoders.0.conv_module.norm._variance | [256] | 256\n",
"encoder.encoders.0.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
"encoder.encoders.0.conv_module.pointwise_conv2.bias | [256] | 256\n",
"encoder.encoders.0.norm_ff.weight | [256] | 256\n",
"encoder.encoders.0.norm_ff.bias | [256] | 256\n",
"encoder.encoders.0.norm_mha.weight | [256] | 256\n",
"encoder.encoders.0.norm_mha.bias | [256] | 256\n",
"encoder.encoders.0.norm_ff_macaron.weight | [256] | 256\n",
"encoder.encoders.0.norm_ff_macaron.bias | [256] | 256\n",
"encoder.encoders.0.norm_conv.weight | [256] | 256\n",
"encoder.encoders.0.norm_conv.bias | [256] | 256\n",
"encoder.encoders.0.norm_final.weight | [256] | 256\n",
"encoder.encoders.0.norm_final.bias | [256] | 256\n",
"encoder.encoders.0.concat_linear.weight | [512, 256] | 131072\n",
"encoder.encoders.0.concat_linear.bias | [256] | 256\n",
"encoder.encoders.1.self_attn.pos_bias_u | [4, 64] | 256\n",
"encoder.encoders.1.self_attn.pos_bias_v | [4, 64] | 256\n",
"encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536\n",
"encoder.encoders.1.self_attn.linear_q.bias | [256] | 256\n",
"encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536\n",
"encoder.encoders.1.self_attn.linear_k.bias | [256] | 256\n",
"encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536\n",
"encoder.encoders.1.self_attn.linear_v.bias | [256] | 256\n",
"encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536\n",
"encoder.encoders.1.self_attn.linear_out.bias | [256] | 256\n",
"encoder.encoders.1.self_attn.linear_pos.weight | [256, 256] | 65536\n",
"encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048\n",
"encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.1.feed_forward.w_2.bias | [256] | 256\n",
"encoder.encoders.1.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.1.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
"encoder.encoders.1.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.1.feed_forward_macaron.w_2.bias | [256] | 256\n",
"encoder.encoders.1.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
"encoder.encoders.1.conv_module.pointwise_conv1.bias | [512] | 512\n",
"encoder.encoders.1.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
"encoder.encoders.1.conv_module.depthwise_conv.bias | [256] | 256\n",
"encoder.encoders.1.conv_module.norm.weight | [256] | 256\n",
"encoder.encoders.1.conv_module.norm.bias | [256] | 256\n",
"encoder.encoders.1.conv_module.norm._mean | [256] | 256\n",
"encoder.encoders.1.conv_module.norm._variance | [256] | 256\n",
"encoder.encoders.1.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
"encoder.encoders.1.conv_module.pointwise_conv2.bias | [256] | 256\n",
"encoder.encoders.1.norm_ff.weight | [256] | 256\n",
"encoder.encoders.1.norm_ff.bias | [256] | 256\n",
"encoder.encoders.1.norm_mha.weight | [256] | 256\n",
"encoder.encoders.1.norm_mha.bias | [256] | 256\n",
"encoder.encoders.1.norm_ff_macaron.weight | [256] | 256\n",
"encoder.encoders.1.norm_ff_macaron.bias | [256] | 256\n",
"encoder.encoders.1.norm_conv.weight | [256] | 256\n",
"encoder.encoders.1.norm_conv.bias | [256] | 256\n",
"encoder.encoders.1.norm_final.weight | [256] | 256\n",
"encoder.encoders.1.norm_final.bias | [256] | 256\n",
"encoder.encoders.1.concat_linear.weight | [512, 256] | 131072\n",
"encoder.encoders.1.concat_linear.bias | [256] | 256\n",
"encoder.encoders.2.self_attn.pos_bias_u | [4, 64] | 256\n",
"encoder.encoders.2.self_attn.pos_bias_v | [4, 64] | 256\n",
"encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536\n",
"encoder.encoders.2.self_attn.linear_q.bias | [256] | 256\n",
"encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536\n",
"encoder.encoders.2.self_attn.linear_k.bias | [256] | 256\n",
"encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536\n",
"encoder.encoders.2.self_attn.linear_v.bias | [256] | 256\n",
"encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536\n",
"encoder.encoders.2.self_attn.linear_out.bias | [256] | 256\n",
"encoder.encoders.2.self_attn.linear_pos.weight | [256, 256] | 65536\n",
"encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048\n",
"encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.2.feed_forward.w_2.bias | [256] | 256\n",
"encoder.encoders.2.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.2.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
"encoder.encoders.2.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.2.feed_forward_macaron.w_2.bias | [256] | 256\n",
"encoder.encoders.2.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
"encoder.encoders.2.conv_module.pointwise_conv1.bias | [512] | 512\n",
"encoder.encoders.2.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
"encoder.encoders.2.conv_module.depthwise_conv.bias | [256] | 256\n",
"encoder.encoders.2.conv_module.norm.weight | [256] | 256\n",
"encoder.encoders.2.conv_module.norm.bias | [256] | 256\n",
"encoder.encoders.2.conv_module.norm._mean | [256] | 256\n",
"encoder.encoders.2.conv_module.norm._variance | [256] | 256\n",
"encoder.encoders.2.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
"encoder.encoders.2.conv_module.pointwise_conv2.bias | [256] | 256\n",
"encoder.encoders.2.norm_ff.weight | [256] | 256\n",
"encoder.encoders.2.norm_ff.bias | [256] | 256\n",
"encoder.encoders.2.norm_mha.weight | [256] | 256\n",
"encoder.encoders.2.norm_mha.bias | [256] | 256\n",
"encoder.encoders.2.norm_ff_macaron.weight | [256] | 256\n",
"encoder.encoders.2.norm_ff_macaron.bias | [256] | 256\n",
"encoder.encoders.2.norm_conv.weight | [256] | 256\n",
"encoder.encoders.2.norm_conv.bias | [256] | 256\n",
"encoder.encoders.2.norm_final.weight | [256] | 256\n",
"encoder.encoders.2.norm_final.bias | [256] | 256\n",
"encoder.encoders.2.concat_linear.weight | [512, 256] | 131072\n",
"encoder.encoders.2.concat_linear.bias | [256] | 256\n",
"encoder.encoders.3.self_attn.pos_bias_u | [4, 64] | 256\n",
"encoder.encoders.3.self_attn.pos_bias_v | [4, 64] | 256\n",
"encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536\n",
"encoder.encoders.3.self_attn.linear_q.bias | [256] | 256\n",
"encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536\n",
"encoder.encoders.3.self_attn.linear_k.bias | [256] | 256\n",
"encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536\n",
"encoder.encoders.3.self_attn.linear_v.bias | [256] | 256\n",
"encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536\n",
"encoder.encoders.3.self_attn.linear_out.bias | [256] | 256\n",
"encoder.encoders.3.self_attn.linear_pos.weight | [256, 256] | 65536\n",
"encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048\n",
"encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.3.feed_forward.w_2.bias | [256] | 256\n",
"encoder.encoders.3.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.3.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
"encoder.encoders.3.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.3.feed_forward_macaron.w_2.bias | [256] | 256\n",
"encoder.encoders.3.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
"encoder.encoders.3.conv_module.pointwise_conv1.bias | [512] | 512\n",
"encoder.encoders.3.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
"encoder.encoders.3.conv_module.depthwise_conv.bias | [256] | 256\n",
"encoder.encoders.3.conv_module.norm.weight | [256] | 256\n",
"encoder.encoders.3.conv_module.norm.bias | [256] | 256\n",
"encoder.encoders.3.conv_module.norm._mean | [256] | 256\n",
"encoder.encoders.3.conv_module.norm._variance | [256] | 256\n",
"encoder.encoders.3.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
"encoder.encoders.3.conv_module.pointwise_conv2.bias | [256] | 256\n",
"encoder.encoders.3.norm_ff.weight | [256] | 256\n",
"encoder.encoders.3.norm_ff.bias | [256] | 256\n",
"encoder.encoders.3.norm_mha.weight | [256] | 256\n",
"encoder.encoders.3.norm_mha.bias | [256] | 256\n",
"encoder.encoders.3.norm_ff_macaron.weight | [256] | 256\n",
"encoder.encoders.3.norm_ff_macaron.bias | [256] | 256\n",
"encoder.encoders.3.norm_conv.weight | [256] | 256\n",
"encoder.encoders.3.norm_conv.bias | [256] | 256\n",
"encoder.encoders.3.norm_final.weight | [256] | 256\n",
"encoder.encoders.3.norm_final.bias | [256] | 256\n",
"encoder.encoders.3.concat_linear.weight | [512, 256] | 131072\n",
"encoder.encoders.3.concat_linear.bias | [256] | 256\n",
"encoder.encoders.4.self_attn.pos_bias_u | [4, 64] | 256\n",
"encoder.encoders.4.self_attn.pos_bias_v | [4, 64] | 256\n",
"encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536\n",
"encoder.encoders.4.self_attn.linear_q.bias | [256] | 256\n",
"encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536\n",
"encoder.encoders.4.self_attn.linear_k.bias | [256] | 256\n",
"encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536\n",
"encoder.encoders.4.self_attn.linear_v.bias | [256] | 256\n",
"encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536\n",
"encoder.encoders.4.self_attn.linear_out.bias | [256] | 256\n",
"encoder.encoders.4.self_attn.linear_pos.weight | [256, 256] | 65536\n",
"encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048\n",
"encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.4.feed_forward.w_2.bias | [256] | 256\n",
"encoder.encoders.4.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.4.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
"encoder.encoders.4.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.4.feed_forward_macaron.w_2.bias | [256] | 256\n",
"encoder.encoders.4.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
"encoder.encoders.4.conv_module.pointwise_conv1.bias | [512] | 512\n",
"encoder.encoders.4.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
"encoder.encoders.4.conv_module.depthwise_conv.bias | [256] | 256\n",
"encoder.encoders.4.conv_module.norm.weight | [256] | 256\n",
"encoder.encoders.4.conv_module.norm.bias | [256] | 256\n",
"encoder.encoders.4.conv_module.norm._mean | [256] | 256\n",
"encoder.encoders.4.conv_module.norm._variance | [256] | 256\n",
"encoder.encoders.4.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
"encoder.encoders.4.conv_module.pointwise_conv2.bias | [256] | 256\n",
"encoder.encoders.4.norm_ff.weight | [256] | 256\n",
"encoder.encoders.4.norm_ff.bias | [256] | 256\n",
"encoder.encoders.4.norm_mha.weight | [256] | 256\n",
"encoder.encoders.4.norm_mha.bias | [256] | 256\n",
"encoder.encoders.4.norm_ff_macaron.weight | [256] | 256\n",
"encoder.encoders.4.norm_ff_macaron.bias | [256] | 256\n",
"encoder.encoders.4.norm_conv.weight | [256] | 256\n",
"encoder.encoders.4.norm_conv.bias | [256] | 256\n",
"encoder.encoders.4.norm_final.weight | [256] | 256\n",
"encoder.encoders.4.norm_final.bias | [256] | 256\n",
"encoder.encoders.4.concat_linear.weight | [512, 256] | 131072\n",
"encoder.encoders.4.concat_linear.bias | [256] | 256\n",
"encoder.encoders.5.self_attn.pos_bias_u | [4, 64] | 256\n",
"encoder.encoders.5.self_attn.pos_bias_v | [4, 64] | 256\n",
"encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536\n",
"encoder.encoders.5.self_attn.linear_q.bias | [256] | 256\n",
"encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536\n",
"encoder.encoders.5.self_attn.linear_k.bias | [256] | 256\n",
"encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536\n",
"encoder.encoders.5.self_attn.linear_v.bias | [256] | 256\n",
"encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536\n",
"encoder.encoders.5.self_attn.linear_out.bias | [256] | 256\n",
"encoder.encoders.5.self_attn.linear_pos.weight | [256, 256] | 65536\n",
"encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048\n",
"encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.5.feed_forward.w_2.bias | [256] | 256\n",
"encoder.encoders.5.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.5.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
"encoder.encoders.5.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.5.feed_forward_macaron.w_2.bias | [256] | 256\n",
"encoder.encoders.5.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
"encoder.encoders.5.conv_module.pointwise_conv1.bias | [512] | 512\n",
"encoder.encoders.5.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
"encoder.encoders.5.conv_module.depthwise_conv.bias | [256] | 256\n",
"encoder.encoders.5.conv_module.norm.weight | [256] | 256\n",
"encoder.encoders.5.conv_module.norm.bias | [256] | 256\n",
"encoder.encoders.5.conv_module.norm._mean | [256] | 256\n",
"encoder.encoders.5.conv_module.norm._variance | [256] | 256\n",
"encoder.encoders.5.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
"encoder.encoders.5.conv_module.pointwise_conv2.bias | [256] | 256\n",
"encoder.encoders.5.norm_ff.weight | [256] | 256\n",
"encoder.encoders.5.norm_ff.bias | [256] | 256\n",
"encoder.encoders.5.norm_mha.weight | [256] | 256\n",
"encoder.encoders.5.norm_mha.bias | [256] | 256\n",
"encoder.encoders.5.norm_ff_macaron.weight | [256] | 256\n",
"encoder.encoders.5.norm_ff_macaron.bias | [256] | 256\n",
"encoder.encoders.5.norm_conv.weight | [256] | 256\n",
"encoder.encoders.5.norm_conv.bias | [256] | 256\n",
"encoder.encoders.5.norm_final.weight | [256] | 256\n",
"encoder.encoders.5.norm_final.bias | [256] | 256\n",
"encoder.encoders.5.concat_linear.weight | [512, 256] | 131072\n",
"encoder.encoders.5.concat_linear.bias | [256] | 256\n",
"encoder.encoders.6.self_attn.pos_bias_u | [4, 64] | 256\n",
"encoder.encoders.6.self_attn.pos_bias_v | [4, 64] | 256\n",
"encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536\n",
"encoder.encoders.6.self_attn.linear_q.bias | [256] | 256\n",
"encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536\n",
"encoder.encoders.6.self_attn.linear_k.bias | [256] | 256\n",
"encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536\n",
"encoder.encoders.6.self_attn.linear_v.bias | [256] | 256\n",
"encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536\n",
"encoder.encoders.6.self_attn.linear_out.bias | [256] | 256\n",
"encoder.encoders.6.self_attn.linear_pos.weight | [256, 256] | 65536\n",
"encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048\n",
"encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.6.feed_forward.w_2.bias | [256] | 256\n",
"encoder.encoders.6.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.6.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
"encoder.encoders.6.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.6.feed_forward_macaron.w_2.bias | [256] | 256\n",
"encoder.encoders.6.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
"encoder.encoders.6.conv_module.pointwise_conv1.bias | [512] | 512\n",
"encoder.encoders.6.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
"encoder.encoders.6.conv_module.depthwise_conv.bias | [256] | 256\n",
"encoder.encoders.6.conv_module.norm.weight | [256] | 256\n",
"encoder.encoders.6.conv_module.norm.bias | [256] | 256\n",
"encoder.encoders.6.conv_module.norm._mean | [256] | 256\n",
"encoder.encoders.6.conv_module.norm._variance | [256] | 256\n",
"encoder.encoders.6.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
"encoder.encoders.6.conv_module.pointwise_conv2.bias | [256] | 256\n",
"encoder.encoders.6.norm_ff.weight | [256] | 256\n",
"encoder.encoders.6.norm_ff.bias | [256] | 256\n",
"encoder.encoders.6.norm_mha.weight | [256] | 256\n",
"encoder.encoders.6.norm_mha.bias | [256] | 256\n",
"encoder.encoders.6.norm_ff_macaron.weight | [256] | 256\n",
"encoder.encoders.6.norm_ff_macaron.bias | [256] | 256\n",
"encoder.encoders.6.norm_conv.weight | [256] | 256\n",
"encoder.encoders.6.norm_conv.bias | [256] | 256\n",
"encoder.encoders.6.norm_final.weight | [256] | 256\n",
"encoder.encoders.6.norm_final.bias | [256] | 256\n",
"encoder.encoders.6.concat_linear.weight | [512, 256] | 131072\n",
"encoder.encoders.6.concat_linear.bias | [256] | 256\n",
"encoder.encoders.7.self_attn.pos_bias_u | [4, 64] | 256\n",
"encoder.encoders.7.self_attn.pos_bias_v | [4, 64] | 256\n",
"encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536\n",
"encoder.encoders.7.self_attn.linear_q.bias | [256] | 256\n",
"encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536\n",
"encoder.encoders.7.self_attn.linear_k.bias | [256] | 256\n",
"encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536\n",
"encoder.encoders.7.self_attn.linear_v.bias | [256] | 256\n",
"encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536\n",
"encoder.encoders.7.self_attn.linear_out.bias | [256] | 256\n",
"encoder.encoders.7.self_attn.linear_pos.weight | [256, 256] | 65536\n",
"encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048\n",
"encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.7.feed_forward.w_2.bias | [256] | 256\n",
"encoder.encoders.7.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.7.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
"encoder.encoders.7.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.7.feed_forward_macaron.w_2.bias | [256] | 256\n",
"encoder.encoders.7.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
"encoder.encoders.7.conv_module.pointwise_conv1.bias | [512] | 512\n",
"encoder.encoders.7.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
"encoder.encoders.7.conv_module.depthwise_conv.bias | [256] | 256\n",
"encoder.encoders.7.conv_module.norm.weight | [256] | 256\n",
"encoder.encoders.7.conv_module.norm.bias | [256] | 256\n",
"encoder.encoders.7.conv_module.norm._mean | [256] | 256\n",
"encoder.encoders.7.conv_module.norm._variance | [256] | 256\n",
"encoder.encoders.7.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
"encoder.encoders.7.conv_module.pointwise_conv2.bias | [256] | 256\n",
"encoder.encoders.7.norm_ff.weight | [256] | 256\n",
"encoder.encoders.7.norm_ff.bias | [256] | 256\n",
"encoder.encoders.7.norm_mha.weight | [256] | 256\n",
"encoder.encoders.7.norm_mha.bias | [256] | 256\n",
"encoder.encoders.7.norm_ff_macaron.weight | [256] | 256\n",
"encoder.encoders.7.norm_ff_macaron.bias | [256] | 256\n",
"encoder.encoders.7.norm_conv.weight | [256] | 256\n",
"encoder.encoders.7.norm_conv.bias | [256] | 256\n",
"encoder.encoders.7.norm_final.weight | [256] | 256\n",
"encoder.encoders.7.norm_final.bias | [256] | 256\n",
"encoder.encoders.7.concat_linear.weight | [512, 256] | 131072\n",
"encoder.encoders.7.concat_linear.bias | [256] | 256\n",
"encoder.encoders.8.self_attn.pos_bias_u | [4, 64] | 256\n",
"encoder.encoders.8.self_attn.pos_bias_v | [4, 64] | 256\n",
"encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536\n",
"encoder.encoders.8.self_attn.linear_q.bias | [256] | 256\n",
"encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536\n",
"encoder.encoders.8.self_attn.linear_k.bias | [256] | 256\n",
"encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536\n",
"encoder.encoders.8.self_attn.linear_v.bias | [256] | 256\n",
"encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536\n",
"encoder.encoders.8.self_attn.linear_out.bias | [256] | 256\n",
"encoder.encoders.8.self_attn.linear_pos.weight | [256, 256] | 65536\n",
"encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048\n",
"encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.8.feed_forward.w_2.bias | [256] | 256\n",
"encoder.encoders.8.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.8.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
"encoder.encoders.8.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.8.feed_forward_macaron.w_2.bias | [256] | 256\n",
"encoder.encoders.8.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
"encoder.encoders.8.conv_module.pointwise_conv1.bias | [512] | 512\n",
"encoder.encoders.8.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
"encoder.encoders.8.conv_module.depthwise_conv.bias | [256] | 256\n",
"encoder.encoders.8.conv_module.norm.weight | [256] | 256\n",
"encoder.encoders.8.conv_module.norm.bias | [256] | 256\n",
"encoder.encoders.8.conv_module.norm._mean | [256] | 256\n",
"encoder.encoders.8.conv_module.norm._variance | [256] | 256\n",
"encoder.encoders.8.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
"encoder.encoders.8.conv_module.pointwise_conv2.bias | [256] | 256\n",
"encoder.encoders.8.norm_ff.weight | [256] | 256\n",
"encoder.encoders.8.norm_ff.bias | [256] | 256\n",
"encoder.encoders.8.norm_mha.weight | [256] | 256\n",
"encoder.encoders.8.norm_mha.bias | [256] | 256\n",
"encoder.encoders.8.norm_ff_macaron.weight | [256] | 256\n",
"encoder.encoders.8.norm_ff_macaron.bias | [256] | 256\n",
"encoder.encoders.8.norm_conv.weight | [256] | 256\n",
"encoder.encoders.8.norm_conv.bias | [256] | 256\n",
"encoder.encoders.8.norm_final.weight | [256] | 256\n",
"encoder.encoders.8.norm_final.bias | [256] | 256\n",
"encoder.encoders.8.concat_linear.weight | [512, 256] | 131072\n",
"encoder.encoders.8.concat_linear.bias | [256] | 256\n",
"encoder.encoders.9.self_attn.pos_bias_u | [4, 64] | 256\n",
"encoder.encoders.9.self_attn.pos_bias_v | [4, 64] | 256\n",
"encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536\n",
"encoder.encoders.9.self_attn.linear_q.bias | [256] | 256\n",
"encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536\n",
"encoder.encoders.9.self_attn.linear_k.bias | [256] | 256\n",
"encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536\n",
"encoder.encoders.9.self_attn.linear_v.bias | [256] | 256\n",
"encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536\n",
"encoder.encoders.9.self_attn.linear_out.bias | [256] | 256\n",
"encoder.encoders.9.self_attn.linear_pos.weight | [256, 256] | 65536\n",
"encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048\n",
"encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.9.feed_forward.w_2.bias | [256] | 256\n",
"encoder.encoders.9.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.9.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
"encoder.encoders.9.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.9.feed_forward_macaron.w_2.bias | [256] | 256\n",
"encoder.encoders.9.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
"encoder.encoders.9.conv_module.pointwise_conv1.bias | [512] | 512\n",
"encoder.encoders.9.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
"encoder.encoders.9.conv_module.depthwise_conv.bias | [256] | 256\n",
"encoder.encoders.9.conv_module.norm.weight | [256] | 256\n",
"encoder.encoders.9.conv_module.norm.bias | [256] | 256\n",
"encoder.encoders.9.conv_module.norm._mean | [256] | 256\n",
"encoder.encoders.9.conv_module.norm._variance | [256] | 256\n",
"encoder.encoders.9.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
"encoder.encoders.9.conv_module.pointwise_conv2.bias | [256] | 256\n",
"encoder.encoders.9.norm_ff.weight | [256] | 256\n",
"encoder.encoders.9.norm_ff.bias | [256] | 256\n",
"encoder.encoders.9.norm_mha.weight | [256] | 256\n",
"encoder.encoders.9.norm_mha.bias | [256] | 256\n",
"encoder.encoders.9.norm_ff_macaron.weight | [256] | 256\n",
"encoder.encoders.9.norm_ff_macaron.bias | [256] | 256\n",
"encoder.encoders.9.norm_conv.weight | [256] | 256\n",
"encoder.encoders.9.norm_conv.bias | [256] | 256\n",
"encoder.encoders.9.norm_final.weight | [256] | 256\n",
"encoder.encoders.9.norm_final.bias | [256] | 256\n",
"encoder.encoders.9.concat_linear.weight | [512, 256] | 131072\n",
"encoder.encoders.9.concat_linear.bias | [256] | 256\n",
"encoder.encoders.10.self_attn.pos_bias_u | [4, 64] | 256\n",
"encoder.encoders.10.self_attn.pos_bias_v | [4, 64] | 256\n",
"encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536\n",
"encoder.encoders.10.self_attn.linear_q.bias | [256] | 256\n",
"encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536\n",
"encoder.encoders.10.self_attn.linear_k.bias | [256] | 256\n",
"encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536\n",
"encoder.encoders.10.self_attn.linear_v.bias | [256] | 256\n",
"encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536\n",
"encoder.encoders.10.self_attn.linear_out.bias | [256] | 256\n",
"encoder.encoders.10.self_attn.linear_pos.weight | [256, 256] | 65536\n",
"encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048\n",
"encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.10.feed_forward.w_2.bias | [256] | 256\n",
"encoder.encoders.10.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.10.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
"encoder.encoders.10.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.10.feed_forward_macaron.w_2.bias | [256] | 256\n",
"encoder.encoders.10.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
"encoder.encoders.10.conv_module.pointwise_conv1.bias | [512] | 512\n",
"encoder.encoders.10.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
"encoder.encoders.10.conv_module.depthwise_conv.bias | [256] | 256\n",
"encoder.encoders.10.conv_module.norm.weight | [256] | 256\n",
"encoder.encoders.10.conv_module.norm.bias | [256] | 256\n",
"encoder.encoders.10.conv_module.norm._mean | [256] | 256\n",
"encoder.encoders.10.conv_module.norm._variance | [256] | 256\n",
"encoder.encoders.10.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
"encoder.encoders.10.conv_module.pointwise_conv2.bias | [256] | 256\n",
"encoder.encoders.10.norm_ff.weight | [256] | 256\n",
"encoder.encoders.10.norm_ff.bias | [256] | 256\n",
"encoder.encoders.10.norm_mha.weight | [256] | 256\n",
"encoder.encoders.10.norm_mha.bias | [256] | 256\n",
"encoder.encoders.10.norm_ff_macaron.weight | [256] | 256\n",
"encoder.encoders.10.norm_ff_macaron.bias | [256] | 256\n",
"encoder.encoders.10.norm_conv.weight | [256] | 256\n",
"encoder.encoders.10.norm_conv.bias | [256] | 256\n",
"encoder.encoders.10.norm_final.weight | [256] | 256\n",
"encoder.encoders.10.norm_final.bias | [256] | 256\n",
"encoder.encoders.10.concat_linear.weight | [512, 256] | 131072\n",
"encoder.encoders.10.concat_linear.bias | [256] | 256\n",
"encoder.encoders.11.self_attn.pos_bias_u | [4, 64] | 256\n",
"encoder.encoders.11.self_attn.pos_bias_v | [4, 64] | 256\n",
"encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536\n",
"encoder.encoders.11.self_attn.linear_q.bias | [256] | 256\n",
"encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536\n",
"encoder.encoders.11.self_attn.linear_k.bias | [256] | 256\n",
"encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536\n",
"encoder.encoders.11.self_attn.linear_v.bias | [256] | 256\n",
"encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536\n",
"encoder.encoders.11.self_attn.linear_out.bias | [256] | 256\n",
"encoder.encoders.11.self_attn.linear_pos.weight | [256, 256] | 65536\n",
"encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048\n",
"encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.11.feed_forward.w_2.bias | [256] | 256\n",
"encoder.encoders.11.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n",
"encoder.encoders.11.feed_forward_macaron.w_1.bias | [2048] | 2048\n",
"encoder.encoders.11.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n",
"encoder.encoders.11.feed_forward_macaron.w_2.bias | [256] | 256\n",
"encoder.encoders.11.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n",
"encoder.encoders.11.conv_module.pointwise_conv1.bias | [512] | 512\n",
"encoder.encoders.11.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n",
"encoder.encoders.11.conv_module.depthwise_conv.bias | [256] | 256\n",
"encoder.encoders.11.conv_module.norm.weight | [256] | 256\n",
"encoder.encoders.11.conv_module.norm.bias | [256] | 256\n",
"encoder.encoders.11.conv_module.norm._mean | [256] | 256\n",
"encoder.encoders.11.conv_module.norm._variance | [256] | 256\n",
"encoder.encoders.11.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n",
"encoder.encoders.11.conv_module.pointwise_conv2.bias | [256] | 256\n",
"encoder.encoders.11.norm_ff.weight | [256] | 256\n",
"encoder.encoders.11.norm_ff.bias | [256] | 256\n",
"encoder.encoders.11.norm_mha.weight | [256] | 256\n",
"encoder.encoders.11.norm_mha.bias | [256] | 256\n",
"encoder.encoders.11.norm_ff_macaron.weight | [256] | 256\n",
"encoder.encoders.11.norm_ff_macaron.bias | [256] | 256\n",
"encoder.encoders.11.norm_conv.weight | [256] | 256\n",
"encoder.encoders.11.norm_conv.bias | [256] | 256\n",
"encoder.encoders.11.norm_final.weight | [256] | 256\n",
"encoder.encoders.11.norm_final.bias | [256] | 256\n",
"encoder.encoders.11.concat_linear.weight | [512, 256] | 131072\n",
"encoder.encoders.11.concat_linear.bias | [256] | 256\n",
"decoder.embed.0.weight | [4233, 256] | 1083648\n",
"decoder.after_norm.weight | [256] | 256\n",
"decoder.after_norm.bias | [256] | 256\n",
"decoder.output_layer.weight | [256, 4233] | 1083648\n",
"decoder.output_layer.bias | [4233] | 4233\n",
"decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536\n",
"decoder.decoders.0.self_attn.linear_q.bias | [256] | 256\n",
"decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536\n",
"decoder.decoders.0.self_attn.linear_k.bias | [256] | 256\n",
"decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536\n",
"decoder.decoders.0.self_attn.linear_v.bias | [256] | 256\n",
"decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536\n",
"decoder.decoders.0.self_attn.linear_out.bias | [256] | 256\n",
"decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536\n",
"decoder.decoders.0.src_attn.linear_q.bias | [256] | 256\n",
"decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536\n",
"decoder.decoders.0.src_attn.linear_k.bias | [256] | 256\n",
"decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536\n",
"decoder.decoders.0.src_attn.linear_v.bias | [256] | 256\n",
"decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536\n",
"decoder.decoders.0.src_attn.linear_out.bias | [256] | 256\n",
"decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048\n",
"decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"decoder.decoders.0.feed_forward.w_2.bias | [256] | 256\n",
"decoder.decoders.0.norm1.weight | [256] | 256\n",
"decoder.decoders.0.norm1.bias | [256] | 256\n",
"decoder.decoders.0.norm2.weight | [256] | 256\n",
"decoder.decoders.0.norm2.bias | [256] | 256\n",
"decoder.decoders.0.norm3.weight | [256] | 256\n",
"decoder.decoders.0.norm3.bias | [256] | 256\n",
"decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072\n",
"decoder.decoders.0.concat_linear1.bias | [256] | 256\n",
"decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072\n",
"decoder.decoders.0.concat_linear2.bias | [256] | 256\n",
"decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536\n",
"decoder.decoders.1.self_attn.linear_q.bias | [256] | 256\n",
"decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536\n",
"decoder.decoders.1.self_attn.linear_k.bias | [256] | 256\n",
"decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536\n",
"decoder.decoders.1.self_attn.linear_v.bias | [256] | 256\n",
"decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536\n",
"decoder.decoders.1.self_attn.linear_out.bias | [256] | 256\n",
"decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536\n",
"decoder.decoders.1.src_attn.linear_q.bias | [256] | 256\n",
"decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536\n",
"decoder.decoders.1.src_attn.linear_k.bias | [256] | 256\n",
"decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536\n",
"decoder.decoders.1.src_attn.linear_v.bias | [256] | 256\n",
"decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536\n",
"decoder.decoders.1.src_attn.linear_out.bias | [256] | 256\n",
"decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048\n",
"decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"decoder.decoders.1.feed_forward.w_2.bias | [256] | 256\n",
"decoder.decoders.1.norm1.weight | [256] | 256\n",
"decoder.decoders.1.norm1.bias | [256] | 256\n",
"decoder.decoders.1.norm2.weight | [256] | 256\n",
"decoder.decoders.1.norm2.bias | [256] | 256\n",
"decoder.decoders.1.norm3.weight | [256] | 256\n",
"decoder.decoders.1.norm3.bias | [256] | 256\n",
"decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072\n",
"decoder.decoders.1.concat_linear1.bias | [256] | 256\n",
"decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072\n",
"decoder.decoders.1.concat_linear2.bias | [256] | 256\n",
"decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536\n",
"decoder.decoders.2.self_attn.linear_q.bias | [256] | 256\n",
"decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536\n",
"decoder.decoders.2.self_attn.linear_k.bias | [256] | 256\n",
"decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536\n",
"decoder.decoders.2.self_attn.linear_v.bias | [256] | 256\n",
"decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536\n",
"decoder.decoders.2.self_attn.linear_out.bias | [256] | 256\n",
"decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536\n",
"decoder.decoders.2.src_attn.linear_q.bias | [256] | 256\n",
"decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536\n",
"decoder.decoders.2.src_attn.linear_k.bias | [256] | 256\n",
"decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536\n",
"decoder.decoders.2.src_attn.linear_v.bias | [256] | 256\n",
"decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536\n",
"decoder.decoders.2.src_attn.linear_out.bias | [256] | 256\n",
"decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048\n",
"decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"decoder.decoders.2.feed_forward.w_2.bias | [256] | 256\n",
"decoder.decoders.2.norm1.weight | [256] | 256\n",
"decoder.decoders.2.norm1.bias | [256] | 256\n",
"decoder.decoders.2.norm2.weight | [256] | 256\n",
"decoder.decoders.2.norm2.bias | [256] | 256\n",
"decoder.decoders.2.norm3.weight | [256] | 256\n",
"decoder.decoders.2.norm3.bias | [256] | 256\n",
"decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072\n",
"decoder.decoders.2.concat_linear1.bias | [256] | 256\n",
"decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072\n",
"decoder.decoders.2.concat_linear2.bias | [256] | 256\n",
"decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536\n",
"decoder.decoders.3.self_attn.linear_q.bias | [256] | 256\n",
"decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536\n",
"decoder.decoders.3.self_attn.linear_k.bias | [256] | 256\n",
"decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536\n",
"decoder.decoders.3.self_attn.linear_v.bias | [256] | 256\n",
"decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536\n",
"decoder.decoders.3.self_attn.linear_out.bias | [256] | 256\n",
"decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536\n",
"decoder.decoders.3.src_attn.linear_q.bias | [256] | 256\n",
"decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536\n",
"decoder.decoders.3.src_attn.linear_k.bias | [256] | 256\n",
"decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536\n",
"decoder.decoders.3.src_attn.linear_v.bias | [256] | 256\n",
"decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536\n",
"decoder.decoders.3.src_attn.linear_out.bias | [256] | 256\n",
"decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048\n",
"decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"decoder.decoders.3.feed_forward.w_2.bias | [256] | 256\n",
"decoder.decoders.3.norm1.weight | [256] | 256\n",
"decoder.decoders.3.norm1.bias | [256] | 256\n",
"decoder.decoders.3.norm2.weight | [256] | 256\n",
"decoder.decoders.3.norm2.bias | [256] | 256\n",
"decoder.decoders.3.norm3.weight | [256] | 256\n",
"decoder.decoders.3.norm3.bias | [256] | 256\n",
"decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072\n",
"decoder.decoders.3.concat_linear1.bias | [256] | 256\n",
"decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072\n",
"decoder.decoders.3.concat_linear2.bias | [256] | 256\n",
"decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536\n",
"decoder.decoders.4.self_attn.linear_q.bias | [256] | 256\n",
"decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536\n",
"decoder.decoders.4.self_attn.linear_k.bias | [256] | 256\n",
"decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536\n",
"decoder.decoders.4.self_attn.linear_v.bias | [256] | 256\n",
"decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536\n",
"decoder.decoders.4.self_attn.linear_out.bias | [256] | 256\n",
"decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536\n",
"decoder.decoders.4.src_attn.linear_q.bias | [256] | 256\n",
"decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536\n",
"decoder.decoders.4.src_attn.linear_k.bias | [256] | 256\n",
"decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536\n",
"decoder.decoders.4.src_attn.linear_v.bias | [256] | 256\n",
"decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536\n",
"decoder.decoders.4.src_attn.linear_out.bias | [256] | 256\n",
"decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048\n",
"decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"decoder.decoders.4.feed_forward.w_2.bias | [256] | 256\n",
"decoder.decoders.4.norm1.weight | [256] | 256\n",
"decoder.decoders.4.norm1.bias | [256] | 256\n",
"decoder.decoders.4.norm2.weight | [256] | 256\n",
"decoder.decoders.4.norm2.bias | [256] | 256\n",
"decoder.decoders.4.norm3.weight | [256] | 256\n",
"decoder.decoders.4.norm3.bias | [256] | 256\n",
"decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072\n",
"decoder.decoders.4.concat_linear1.bias | [256] | 256\n",
"decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072\n",
"decoder.decoders.4.concat_linear2.bias | [256] | 256\n",
"decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536\n",
"decoder.decoders.5.self_attn.linear_q.bias | [256] | 256\n",
"decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536\n",
"decoder.decoders.5.self_attn.linear_k.bias | [256] | 256\n",
"decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536\n",
"decoder.decoders.5.self_attn.linear_v.bias | [256] | 256\n",
"decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536\n",
"decoder.decoders.5.self_attn.linear_out.bias | [256] | 256\n",
"decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536\n",
"decoder.decoders.5.src_attn.linear_q.bias | [256] | 256\n",
"decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536\n",
"decoder.decoders.5.src_attn.linear_k.bias | [256] | 256\n",
"decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536\n",
"decoder.decoders.5.src_attn.linear_v.bias | [256] | 256\n",
"decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536\n",
"decoder.decoders.5.src_attn.linear_out.bias | [256] | 256\n",
"decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288\n",
"decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048\n",
"decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288\n",
"decoder.decoders.5.feed_forward.w_2.bias | [256] | 256\n",
"decoder.decoders.5.norm1.weight | [256] | 256\n",
"decoder.decoders.5.norm1.bias | [256] | 256\n",
"decoder.decoders.5.norm2.weight | [256] | 256\n",
"decoder.decoders.5.norm2.bias | [256] | 256\n",
"decoder.decoders.5.norm3.weight | [256] | 256\n",
"decoder.decoders.5.norm3.bias | [256] | 256\n",
"decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072\n",
"decoder.decoders.5.concat_linear1.bias | [256] | 256\n",
"decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072\n",
"decoder.decoders.5.concat_linear2.bias | [256] | 256\n",
"ctc.ctc_lo.weight | [256, 4233] | 1083648\n",
"ctc.ctc_lo.bias | [4233] | 4233\n",
"Total parameters: 689, 49355442 elements.\n"
]
}
],
"source": [
"summary(model)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ruled-invitation",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"U2Model(\n",
" (encoder): ConformerEncoder(\n",
" (global_cmvn): GlobalCMVN()\n",
" (embed): Conv2dSubsampling4(\n",
" (pos_enc): RelPositionalEncoding(\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" )\n",
" (conv): Sequential(\n",
" (0): Conv2D(1, 256, kernel_size=[3, 3], stride=[2, 2], data_format=NCHW)\n",
" (1): ReLU()\n",
" (2): Conv2D(256, 256, kernel_size=[3, 3], stride=[2, 2], data_format=NCHW)\n",
" (3): ReLU()\n",
" )\n",
" (out): Sequential(\n",
" (0): Linear(in_features=4864, out_features=256, dtype=float32)\n",
" )\n",
" )\n",
" (after_norm): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (encoders): LayerList(\n",
" (0): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
" (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
" (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
" (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (1): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
" (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
" (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
" (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (2): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
" (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
" (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
" (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (3): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
" (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
" (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
" (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (4): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
" (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
" (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
" (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (5): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
" (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
" (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
" (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (6): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
" (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
" (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
" (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (7): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
" (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
" (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
" (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (8): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
" (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
" (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
" (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (9): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
" (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
" (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
" (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (10): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
" (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
" (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
" (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (11): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n",
" (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n",
" (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n",
" (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" )\n",
" )\n",
" (decoder): TransformerDecoder(\n",
" (embed): Sequential(\n",
" (0): Embedding(4233, 256, sparse=False)\n",
" (1): PositionalEncoding(\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" )\n",
" )\n",
" (after_norm): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (output_layer): Linear(in_features=256, out_features=4233, dtype=float32)\n",
" (decoders): LayerList(\n",
" (0): DecoderLayer(\n",
" (self_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" )\n",
" (src_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): ReLU()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n",
" (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (1): DecoderLayer(\n",
" (self_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" )\n",
" (src_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): ReLU()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n",
" (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (2): DecoderLayer(\n",
" (self_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" )\n",
" (src_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): ReLU()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n",
" (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (3): DecoderLayer(\n",
" (self_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" )\n",
" (src_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): ReLU()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n",
" (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (4): DecoderLayer(\n",
" (self_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" )\n",
" (src_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): ReLU()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n",
" (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" (5): DecoderLayer(\n",
" (self_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" )\n",
" (src_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n",
" (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n",
" (activation): ReLU()\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n",
" )\n",
" (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n",
" (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n",
" (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n",
" (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n",
" )\n",
" )\n",
" )\n",
" (ctc): CTCDecoder(\n",
" (ctc_lo): Linear(in_features=256, out_features=4233, dtype=float32)\n",
" (criterion): CTCLoss(\n",
" (loss): CTCLoss()\n",
" )\n",
" )\n",
" (criterion_att): LabelSmoothingLoss(\n",
" (criterion): KLDivLoss()\n",
" )\n",
")\n"
]
}
],
"source": [
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fossil-means",
"metadata": {},
"outputs": [],
"source": [
"# load feat"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "fleet-despite",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"compute_cmvn_loader_test.ipynb encoder.npz\r\n",
"dataloader.ipynb hack_api_test.ipynb\r\n",
"dataloader_with_tokens_tokenids.ipynb jit_infer.ipynb\r\n",
"data.npz layer_norm_test.ipynb\r\n",
"decoder.npz Linear_test.ipynb\r\n",
"enc_0_ff_out.npz mask_and_masked_fill_test.ipynb\r\n",
"enc_0_norm_ff.npz model.npz\r\n",
"enc_0.npz position_embeding_check.ipynb\r\n",
"enc_0_selattn_out.npz python_test.ipynb\r\n",
"enc_2.npz train_test.ipynb\r\n",
"enc_all.npz u2_model.ipynb\r\n",
"enc_embed.npz\r\n"
]
}
],
"source": [
"%ls .notebook"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "abroad-oracle",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['BAC009S0739W0246' 'BAC009S0727W0424' 'BAC009S0753W0412'\n",
" 'BAC009S0756W0206' 'BAC009S0740W0414' 'BAC009S0728W0426'\n",
" 'BAC009S0739W0214' 'BAC009S0753W0423' 'BAC009S0734W0201'\n",
" 'BAC009S0740W0427' 'BAC009S0730W0423' 'BAC009S0728W0367'\n",
" 'BAC009S0730W0418' 'BAC009S0727W0157' 'BAC009S0749W0409'\n",
" 'BAC009S0727W0418']\n",
"(16, 207, 80)\n",
"[[[ 8.994624 9.538309 9.191589 ... 10.507416 9.563305 8.256403 ]\n",
" [ 9.798841 10.405224 9.26511 ... 10.251211 9.543982 8.873768 ]\n",
" [10.6890745 10.395469 8.053548 ... 9.906749 10.064903 8.050915 ]\n",
" ...\n",
" [ 9.217986 9.65069 8.505259 ... 9.687183 8.742463 7.9865475]\n",
" [10.129122 9.935194 9.37982 ... 9.563894 9.825992 8.979543 ]\n",
" [ 9.095531 7.1338377 9.468001 ... 9.472748 9.021235 7.447914 ]]\n",
"\n",
" [[11.430976 10.671858 6.0841026 ... 9.382682 8.729745 7.5315614]\n",
" [ 9.731717 7.8104815 7.5714607 ... 10.043035 9.243595 7.3540792]\n",
" [10.65017 10.600604 8.467784 ... 9.281448 9.186885 8.070343 ]\n",
" ...\n",
" [ 9.096987 9.2637 8.075275 ... 8.431845 8.370505 8.002926 ]\n",
" [10.461651 10.147784 6.7693496 ... 9.779426 9.577453 8.080652 ]\n",
" [ 7.794432 5.621059 7.9750648 ... 9.997245 9.849678 8.031287 ]]\n",
"\n",
" [[ 7.3455667 7.896357 7.5795946 ... 11.631024 10.451254 9.123633 ]\n",
" [ 8.628678 8.4630575 7.499242 ... 12.415986 10.975749 8.9425745]\n",
" [ 9.831394 10.2812805 8.97241 ... 12.1386795 10.40175 9.005517 ]\n",
" ...\n",
" [ 7.089641 7.405548 6.8142557 ... 9.325196 9.273162 8.353427 ]\n",
" [ 0. 0. 0. ... 0. 0. 0. ]\n",
" [ 0. 0. 0. ... 0. 0. 0. ]]\n",
"\n",
" ...\n",
"\n",
" [[10.933237 10.464394 7.7202725 ... 10.348816 9.302338 7.1553144]\n",
" [10.449866 9.907033 9.029272 ... 9.952465 9.414051 7.559279 ]\n",
" [10.487655 9.81259 9.895244 ... 9.58662 9.341254 7.7849016]\n",
" ...\n",
" [ 0. 0. 0. ... 0. 0. 0. ]\n",
" [ 0. 0. 0. ... 0. 0. 0. ]\n",
" [ 0. 0. 0. ... 0. 0. 0. ]]\n",
"\n",
" [[ 9.944384 9.585867 8.220328 ... 11.588647 11.045029 8.817075 ]\n",
" [ 7.678356 8.322397 7.533047 ... 11.055085 10.535685 9.27465 ]\n",
" [ 8.626197 9.675917 9.841045 ... 11.378827 10.922112 8.991444 ]\n",
" ...\n",
" [ 0. 0. 0. ... 0. 0. 0. ]\n",
" [ 0. 0. 0. ... 0. 0. 0. ]\n",
" [ 0. 0. 0. ... 0. 0. 0. ]]\n",
"\n",
" [[ 8.107938 7.759043 6.710301 ... 12.650573 11.466156 11.061517 ]\n",
" [11.380332 11.222007 8.658889 ... 12.810616 12.222216 11.689288 ]\n",
" [10.677676 9.920579 8.046089 ... 13.572894 12.5624075 11.155033 ]\n",
" ...\n",
" [ 0. 0. 0. ... 0. 0. 0. ]\n",
" [ 0. 0. 0. ... 0. 0. 0. ]\n",
" [ 0. 0. 0. ... 0. 0. 0. ]]]\n",
"[207 207 205 205 203 203 198 197 195 188 186 186 185 180 166 163]\n",
"[[2995 3116 1209 565 -1 -1]\n",
" [ 236 1176 331 66 3925 4077]\n",
" [2693 524 234 1145 366 -1]\n",
" [3875 4211 3062 700 -1 -1]\n",
" [ 272 987 1134 494 2959 -1]\n",
" [1936 3715 120 2553 2695 2710]\n",
" [ 25 1149 3930 -1 -1 -1]\n",
" [1753 1778 1237 482 3925 110]\n",
" [3703 2 565 3827 -1 -1]\n",
" [1150 2734 10 2478 3490 -1]\n",
" [ 426 811 95 489 144 -1]\n",
" [2313 2006 489 975 -1 -1]\n",
" [3702 3414 205 1488 2966 1347]\n",
" [ 70 1741 702 1666 -1 -1]\n",
" [ 703 1778 1030 849 -1 -1]\n",
" [ 814 1674 115 3827 -1 -1]]\n",
"[4 6 5 4 5 6 3 6 4 5 5 4 6 4 4 4]\n"
]
}
],
"source": [
"data = np.load('.notebook/data.npz', allow_pickle=True)\n",
"keys=data['keys']\n",
"feat=data['feat']\n",
"feat_len=data['feat_len']\n",
"text=data['text']\n",
"text_len=data['text_len']\n",
"print(keys)\n",
"print(feat.shape)\n",
"print(feat)\n",
"print(feat_len)\n",
"print(text)\n",
"print(text_len)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "false-instrument",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 9,
"id": "arctic-proxy",
"metadata": {},
"outputs": [],
"source": [
"# ['BAC009S0739W0246', 'BAC009S0727W0424', 'BAC009S0753W0412', 'BAC009S0756W0206', 'BAC009S0740W0414', 'BAC009S0728W0426', 'BAC009S0739W0214', 'BAC009S0753W0423', 'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']\n",
"# torch.Size([16, 207, 80])\n",
"# tensor([[[ 8.9946, 9.5383, 9.1916, ..., 10.5074, 9.5633, 8.2564],\n",
"# [ 9.7988, 10.4052, 9.2651, ..., 10.2512, 9.5440, 8.8738],\n",
"# [10.6891, 10.3955, 8.0535, ..., 9.9067, 10.0649, 8.0509],\n",
"# ...,\n",
"# [ 9.2180, 9.6507, 8.5053, ..., 9.6872, 8.7425, 7.9865],\n",
"# [10.1291, 9.9352, 9.3798, ..., 9.5639, 9.8260, 8.9795],\n",
"# [ 9.0955, 7.1338, 9.4680, ..., 9.4727, 9.0212, 7.4479]],\n",
"\n",
"# [[11.4310, 10.6719, 6.0841, ..., 9.3827, 8.7297, 7.5316],\n",
"# [ 9.7317, 7.8105, 7.5715, ..., 10.0430, 9.2436, 7.3541],\n",
"# [10.6502, 10.6006, 8.4678, ..., 9.2814, 9.1869, 8.0703],\n",
"# ...,\n",
"# [ 9.0970, 9.2637, 8.0753, ..., 8.4318, 8.3705, 8.0029],\n",
"# [10.4617, 10.1478, 6.7693, ..., 9.7794, 9.5775, 8.0807],\n",
"# [ 7.7944, 5.6211, 7.9751, ..., 9.9972, 9.8497, 8.0313]],\n",
"\n",
"# [[ 7.3456, 7.8964, 7.5796, ..., 11.6310, 10.4513, 9.1236],\n",
"# [ 8.6287, 8.4631, 7.4992, ..., 12.4160, 10.9757, 8.9426],\n",
"# [ 9.8314, 10.2813, 8.9724, ..., 12.1387, 10.4017, 9.0055],\n",
"# ...,\n",
"# [ 7.0896, 7.4055, 6.8143, ..., 9.3252, 9.2732, 8.3534],\n",
"# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
"# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n",
"\n",
"# ...,\n",
"\n",
"# [[10.9332, 10.4644, 7.7203, ..., 10.3488, 9.3023, 7.1553],\n",
"# [10.4499, 9.9070, 9.0293, ..., 9.9525, 9.4141, 7.5593],\n",
"# [10.4877, 9.8126, 9.8952, ..., 9.5866, 9.3413, 7.7849],\n",
"# ...,\n",
"# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
"# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
"# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n",
"\n",
"# [[ 9.9444, 9.5859, 8.2203, ..., 11.5886, 11.0450, 8.8171],\n",
"# [ 7.6784, 8.3224, 7.5330, ..., 11.0551, 10.5357, 9.2746],\n",
"# [ 8.6262, 9.6759, 9.8410, ..., 11.3788, 10.9221, 8.9914],\n",
"# ...,\n",
"# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
"# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
"# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n",
"\n",
"# [[ 8.1079, 7.7590, 6.7103, ..., 12.6506, 11.4662, 11.0615],\n",
"# [11.3803, 11.2220, 8.6589, ..., 12.8106, 12.2222, 11.6893],\n",
"# [10.6777, 9.9206, 8.0461, ..., 13.5729, 12.5624, 11.1550],\n",
"# ...,\n",
"# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
"# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
"# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]])\n",
"# tensor([207, 207, 205, 205, 203, 203, 198, 197, 195, 188, 186, 186, 185, 180,\n",
"# 166, 163], dtype=torch.int32)\n",
"# tensor([[2995, 3116, 1209, 565, -1, -1],\n",
"# [ 236, 1176, 331, 66, 3925, 4077],\n",
"# [2693, 524, 234, 1145, 366, -1],\n",
"# [3875, 4211, 3062, 700, -1, -1],\n",
"# [ 272, 987, 1134, 494, 2959, -1],\n",
"# [1936, 3715, 120, 2553, 2695, 2710],\n",
"# [ 25, 1149, 3930, -1, -1, -1],\n",
"# [1753, 1778, 1237, 482, 3925, 110],\n",
"# [3703, 2, 565, 3827, -1, -1],\n",
"# [1150, 2734, 10, 2478, 3490, -1],\n",
"# [ 426, 811, 95, 489, 144, -1],\n",
"# [2313, 2006, 489, 975, -1, -1],\n",
"# [3702, 3414, 205, 1488, 2966, 1347],\n",
"# [ 70, 1741, 702, 1666, -1, -1],\n",
"# [ 703, 1778, 1030, 849, -1, -1],\n",
"# [ 814, 1674, 115, 3827, -1, -1]], dtype=torch.int32)\n",
"# tensor([4, 6, 5, 4, 5, 6, 3, 6, 4, 5, 5, 4, 6, 4, 4, 4], dtype=torch.int32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "seasonal-switch",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 10,
"id": "defined-brooks",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"compute_cmvn_loader_test.ipynb\t encoder.npz\r\n",
"dataloader.ipynb\t\t hack_api_test.ipynb\r\n",
"dataloader_with_tokens_tokenids.ipynb jit_infer.ipynb\r\n",
"data.npz\t\t\t layer_norm_test.ipynb\r\n",
"decoder.npz\t\t\t Linear_test.ipynb\r\n",
"enc_0_ff_out.npz\t\t mask_and_masked_fill_test.ipynb\r\n",
"enc_0_norm_ff.npz\t\t model.npz\r\n",
"enc_0.npz\t\t\t position_embeding_check.ipynb\r\n",
"enc_0_selattn_out.npz\t\t python_test.ipynb\r\n",
"enc_2.npz\t\t\t train_test.ipynb\r\n",
"enc_all.npz\t\t\t u2_model.ipynb\r\n",
"enc_embed.npz\r\n"
]
}
],
"source": [
"# load model param\n",
"!ls .notebook\n",
"data = np.load('.notebook/model.npz', allow_pickle=True)\n",
"state_dict = data['state'].item()\n",
"\n",
"for key, _ in model.state_dict().items():\n",
" if key not in state_dict:\n",
" print(f\"{key} not find.\")\n",
"\n",
"model.set_state_dict(state_dict)\n",
"\n",
"now_state_dict = model.state_dict()\n",
"for key, value in now_state_dict.items():\n",
" if not np.allclose(value.numpy(), state_dict[key]):\n",
" print(key)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "exempt-viewer",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 11,
"id": "confident-piano",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/framework.py:687: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" elif dtype == np.bool:\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [142.48880005]) Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [41.84146118]) Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [377.33258057])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dygraph/math_op_patch.py:238: UserWarning: The dtype of left and right variables are not the same, left dtype is VarType.FP32, but right dtype is VarType.INT32, the right dtype will convert to VarType.FP32\n",
" format(lhs_dtype, rhs_dtype, lhs_dtype))\n"
]
}
],
"source": [
"# compute loss\n",
"import paddle\n",
"feat=paddle.to_tensor(feat)\n",
"feat_len=paddle.to_tensor(feat_len, dtype='int64')\n",
"text=paddle.to_tensor(text, dtype='int64')\n",
"text_len=paddle.to_tensor(text_len, dtype='int64')\n",
"\n",
"model.eval()\n",
"total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n",
" text, text_len)\n",
"print(total_loss, attention_loss, ctc_loss )"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "better-senator",
"metadata": {},
"outputs": [],
"source": [
"# tensor(142.4888, device='cuda:0', grad_fn=<AddBackward0>) \n",
"# tensor(41.8415, device='cuda:0', grad_fn=<DivBackward0>) \n",
"# tensor(377.3326, device='cuda:0', grad_fn=<DivBackward0>)\n",
"# 142.4888 41.84146 377.33258"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "related-banking",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 13,
"id": "olympic-problem",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[16, 51, 256]\n",
"[16, 1, 51]\n",
"Tensor(shape=[51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [[-0.70194179, 0.56254166, 0.68803459, ..., 1.12373221, 0.78039235, 1.13693869],\n",
" [-0.77877808, 0.39126658, 0.71887815, ..., 1.25188220, 0.88616788, 1.31734526],\n",
" [-0.95908946, 0.63460249, 0.87671334, ..., 0.98183727, 0.74401081, 1.29032660],\n",
" ...,\n",
" [-1.07322502, 0.67236906, 0.92303109, ..., 0.90754563, 0.81767166, 1.32396567],\n",
" [-1.16541159, 0.68199694, 0.69394493, ..., 1.22383487, 0.80282891, 1.45065081],\n",
" [-1.27320945, 0.71458030, 0.75819558, ..., 0.94154912, 0.87748396, 1.26230514]])\n"
]
}
],
"source": [
"# encoder\n",
"encoder_out, encoder_mask = model.encoder(feat, feat_len)\n",
"print(encoder_out.shape)\n",
"print(encoder_mask.shape)\n",
"print(encoder_out[0])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "shaped-alaska",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"deepspeech examples README_cn.md\tsetup.sh tools\r\n",
"docs\t LICENSE README.md\t\ttests\t utils\r\n",
"env.sh\t log requirements.txt\tthird_party\r\n"
]
}
],
"source": [
"!ls\n",
"data = np.load('.notebook/encoder.npz', allow_pickle=True)\n",
"torch_mask = data['mask']\n",
"torch_encoder_out = data['out']"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "federal-rover",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"None\n"
]
}
],
"source": [
"print(np.testing.assert_equal(torch_mask, encoder_mask.numpy()))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "regulated-interstate",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"False\n",
"[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n",
" 1.1369387 ]\n",
" [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n",
" 1.3173454 ]\n",
" [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n",
" 1.2903274 ]\n",
" ...\n",
" [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n",
" 1.3239657 ]\n",
" [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n",
" 1.4506509 ]\n",
" [-1.2732087 0.71458083 0.7581961 ... 0.9415482 0.877484\n",
" 1.2623053 ]]\n",
"----\n",
"[[-0.7019418 0.56254166 0.6880346 ... 1.1237322 0.78039235\n",
" 1.1369387 ]\n",
" [-0.7787781 0.39126658 0.71887815 ... 1.2518822 0.8861679\n",
" 1.3173453 ]\n",
" [-0.95908946 0.6346025 0.87671334 ... 0.9818373 0.7440108\n",
" 1.2903266 ]\n",
" ...\n",
" [-1.073225 0.67236906 0.9230311 ... 0.9075456 0.81767166\n",
" 1.3239657 ]\n",
" [-1.1654116 0.68199694 0.69394493 ... 1.2238349 0.8028289\n",
" 1.4506508 ]\n",
" [-1.2732095 0.7145803 0.7581956 ... 0.9415491 0.87748396\n",
" 1.2623051 ]]\n",
"True\n",
"False\n"
]
}
],
"source": [
"print(np.allclose(torch_encoder_out, encoder_out.numpy()))\n",
"print(torch_encoder_out[0])\n",
"print(\"----\")\n",
"print(encoder_out.numpy()[0])\n",
"print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-5, rtol=1e-6))\n",
"print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-6, rtol=1e-6))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "proof-scheduling",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [377.33258057])\n",
"[1.]\n",
"[[ 3.16902876e+00 -1.51763987e-02 4.91095744e-02 ... -2.47971853e-03\n",
" -5.93360700e-03 -7.26609165e-03]\n",
" [-1.74184477e+00 7.75874173e-03 -4.49434854e-02 ... 9.92412097e-04\n",
" 2.46337592e-03 2.31892057e-03]\n",
" [-2.33343339e+00 1.30475955e-02 -2.66557075e-02 ... 2.27532350e-03\n",
" 5.76924905e-03 7.48788286e-03]\n",
" ...\n",
" [-4.30358458e+00 2.46054661e-02 -9.00950655e-02 ... 4.43156436e-03\n",
" 1.16122244e-02 1.44715561e-02]\n",
" [-3.36921120e+00 1.73153952e-02 -6.36872873e-02 ... 3.28363618e-03\n",
" 8.58010259e-03 1.07794888e-02]\n",
" [-6.62045336e+00 3.49955931e-02 -1.23962618e-01 ... 6.36671018e-03\n",
" 1.60814095e-02 2.03891303e-02]]\n",
"[-4.3777819e+00 2.3245810e-02 -9.3339294e-02 ... 4.2569344e-03\n",
" 1.0919910e-02 1.3787797e-02]\n"
]
}
],
"source": [
"from paddle.nn import functional as F\n",
"def ctc_loss(logits,\n",
" labels,\n",
" input_lengths,\n",
" label_lengths,\n",
" blank=0,\n",
" reduction='mean',\n",
" norm_by_times=False):\n",
" loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,\n",
" input_lengths, label_lengths)\n",
" loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])\n",
" assert reduction in ['mean', 'sum', 'none']\n",
" if reduction == 'mean':\n",
" loss_out = paddle.mean(loss_out / label_lengths)\n",
" elif reduction == 'sum':\n",
" loss_out = paddle.sum(loss_out)\n",
" return loss_out\n",
"\n",
"F.ctc_loss = ctc_loss\n",
"\n",
"torch_mask_t = paddle.to_tensor(torch_mask, dtype='int64')\n",
"encoder_out_lens = torch_mask_t.squeeze(1).sum(1)\n",
"loss_ctc = model.ctc(paddle.to_tensor(torch_encoder_out), encoder_out_lens, text, text_len)\n",
"print(loss_ctc)\n",
"loss_ctc.backward()\n",
"print(loss_ctc.grad)\n",
"print(model.ctc.ctc_lo.weight.grad)\n",
"print(model.ctc.ctc_lo.bias.grad)\n",
"\n",
"\n",
"# tensor(377.3326, device='cuda:0', grad_fn=<DivBackward0>)\n",
"# None\n",
"# [[ 3.16902351e+00 -1.51765049e-02 4.91097234e-02 ... -2.47973716e-03\n",
"# -5.93366381e-03 -7.26613170e-03]\n",
"# [-1.74185038e+00 7.75875803e-03 -4.49435972e-02 ... 9.92415240e-04\n",
"# 2.46338220e-03 2.31891591e-03]\n",
"# [-2.33343077e+00 1.30476682e-02 -2.66557615e-02 ... 2.27533933e-03\n",
"# 5.76929189e-03 7.48792710e-03]\n",
"# ...\n",
"# [-4.30356789e+00 2.46056803e-02 -9.00955945e-02 ... 4.43160534e-03\n",
"# 1.16123557e-02 1.44716976e-02]\n",
"# [-3.36919212e+00 1.73155665e-02 -6.36875406e-02 ... 3.28367390e-03\n",
"# 8.58021621e-03 1.07796099e-02]\n",
"# [-6.62039661e+00 3.49958315e-02 -1.23963736e-01 ... 6.36674836e-03\n",
"# 1.60815325e-02 2.03892551e-02]]\n",
"# [-4.3777566e+00 2.3245990e-02 -9.3339972e-02 ... 4.2569702e-03\n",
"# 1.0920014e-02 1.3787906e-02]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "enclosed-consolidation",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 18,
"id": "synthetic-hungarian",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [41.84146118]) 0.0\n"
]
}
],
"source": [
"loss_att, acc_att = model._calc_att_loss(paddle.to_tensor(torch_encoder_out), paddle.to_tensor(torch_mask),\n",
" text, text_len)\n",
"print(loss_att, acc_att)\n",
"#tensor(41.8416, device='cuda:0', grad_fn=<DivBackward0>) 0.0"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "indian-sweden",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 202,
"id": "marine-cuisine",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-3.7638968e-01 -8.2272053e-01 7.4276292e-01 ... 3.4200522e-01\n",
" 1.5034772e-02 4.0337229e-01]\n",
" [-8.7386459e-01 -3.1389427e-01 4.1987866e-01 ... 3.7723729e-01\n",
" -1.4352810e-01 -1.0023664e+00]\n",
" [-4.3505096e-01 3.4504786e-02 -2.8710306e-01 ... 7.7274129e-02\n",
" -1.1672243e+00 -2.6848501e-01]\n",
" ...\n",
" [ 4.2471480e-01 5.8885634e-01 2.0203922e-02 ... 3.7405500e-01\n",
" 4.5470044e-02 -3.7139410e-01]\n",
" [-3.7978446e-01 -8.1084180e-01 7.5725085e-01 ... 2.6038891e-01\n",
" -7.9347193e-04 4.2537671e-01]\n",
" [-3.8279903e-01 -8.1206715e-01 7.4943429e-01 ... 2.6173013e-01\n",
" -1.0499060e-03 4.2678756e-01]]\n"
]
}
],
"source": [
"data = np.load(\".notebook/decoder.npz\", allow_pickle=True)\n",
"torch_decoder_out = data['decoder_out']\n",
"print(torch_decoder_out[0])"
]
},
{
"cell_type": "code",
"execution_count": 180,
"id": "several-result",
"metadata": {},
"outputs": [],
"source": [
"def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,\n",
" ignore_id: int):\n",
" \"\"\"Add <sos> and <eos> labels.\n",
" Args:\n",
" ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)\n",
" sos (int): index of <sos>\n",
" eos (int): index of <eos>\n",
" ignore_id (int): index of padding\n",
" Returns:\n",
" ys_in (paddle.Tensor) : (B, Lmax + 1)\n",
" ys_out (paddle.Tensor) : (B, Lmax + 1)\n",
" Examples:\n",
" >>> sos_id = 10\n",
" >>> eos_id = 11\n",
" >>> ignore_id = -1\n",
" >>> ys_pad\n",
" tensor([[ 1, 2, 3, 4, 5],\n",
" [ 4, 5, 6, -1, -1],\n",
" [ 7, 8, 9, -1, -1]], dtype=paddle.int32)\n",
" >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)\n",
" >>> ys_in\n",
" tensor([[10, 1, 2, 3, 4, 5],\n",
" [10, 4, 5, 6, 11, 11],\n",
" [10, 7, 8, 9, 11, 11]])\n",
" >>> ys_out\n",
" tensor([[ 1, 2, 3, 4, 5, 11],\n",
" [ 4, 5, 6, 11, -1, -1],\n",
" [ 7, 8, 9, 11, -1, -1]])\n",
" \"\"\"\n",
" # TODO(Hui Zhang): use the commented-out implementation below instead.\n",
" #_sos = paddle.to_tensor(\n",
" # [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)\n",
" #_eos = paddle.to_tensor(\n",
" # [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)\n",
" #ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys\n",
" #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]\n",
" #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]\n",
" #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)\n",
" B = ys_pad.size(0)\n",
" _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos\n",
" _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos\n",
" ys_in = paddle.cat([_sos, ys_pad], dim=1)\n",
" mask_pad = (ys_in == ignore_id)\n",
" ys_in = ys_in.masked_fill(mask_pad, eos)\n",
" \n",
"\n",
" ys_out = paddle.cat([ys_pad, _eos], dim=1)\n",
" ys_out = ys_out.masked_fill(mask_pad, eos)\n",
" mask_eos = (ys_out == ignore_id)\n",
" ys_out = ys_out.masked_fill(mask_eos, eos)\n",
" ys_out = ys_out.masked_fill(mask_pad, ignore_id)\n",
" return ys_in, ys_out"
]
},
{
"cell_type": "code",
"execution_count": 181,
"id": "possible-bulgaria",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[16, 7], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [[4232, 2995, 3116, 1209, 565 , 4232, 4232],\n",
" [4232, 236 , 1176, 331 , 66 , 3925, 4077],\n",
" [4232, 2693, 524 , 234 , 1145, 366 , 4232],\n",
" [4232, 3875, 4211, 3062, 700 , 4232, 4232],\n",
" [4232, 272 , 987 , 1134, 494 , 2959, 4232],\n",
" [4232, 1936, 3715, 120 , 2553, 2695, 2710],\n",
" [4232, 25 , 1149, 3930, 4232, 4232, 4232],\n",
" [4232, 1753, 1778, 1237, 482 , 3925, 110 ],\n",
" [4232, 3703, 2 , 565 , 3827, 4232, 4232],\n",
" [4232, 1150, 2734, 10 , 2478, 3490, 4232],\n",
" [4232, 426 , 811 , 95 , 489 , 144 , 4232],\n",
" [4232, 2313, 2006, 489 , 975 , 4232, 4232],\n",
" [4232, 3702, 3414, 205 , 1488, 2966, 1347],\n",
" [4232, 70 , 1741, 702 , 1666, 4232, 4232],\n",
" [4232, 703 , 1778, 1030, 849 , 4232, 4232],\n",
" [4232, 814 , 1674, 115 , 3827, 4232, 4232]])\n",
"Tensor(shape=[16, 7], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [[2995, 3116, 1209, 565, 4232, -1 , -1 ],\n",
" [ 236, 1176, 331, 66 , 3925, 4077, 4232],\n",
" [2693, 524, 234, 1145, 366, 4232, -1 ],\n",
" [3875, 4211, 3062, 700, 4232, -1 , -1 ],\n",
" [ 272, 987, 1134, 494, 2959, 4232, -1 ],\n",
" [1936, 3715, 120, 2553, 2695, 2710, 4232],\n",
" [ 25 , 1149, 3930, 4232, -1 , -1 , -1 ],\n",
" [1753, 1778, 1237, 482, 3925, 110, 4232],\n",
" [3703, 2 , 565, 3827, 4232, -1 , -1 ],\n",
" [1150, 2734, 10 , 2478, 3490, 4232, -1 ],\n",
" [ 426, 811, 95 , 489, 144, 4232, -1 ],\n",
" [2313, 2006, 489, 975, 4232, -1 , -1 ],\n",
" [3702, 3414, 205, 1488, 2966, 1347, 4232],\n",
" [ 70 , 1741, 702, 1666, 4232, -1 , -1 ],\n",
" [ 703, 1778, 1030, 849, 4232, -1 , -1 ],\n",
" [ 814, 1674, 115, 3827, 4232, -1 , -1 ]])\n"
]
}
],
"source": [
"ys_pad = text\n",
"ys_pad_lens = text_len\n",
"ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,\n",
" model.ignore_id)\n",
"ys_in_lens = ys_pad_lens + 1\n",
"print(ys_in_pad)\n",
"print(ys_out_pad)"
]
},
{
"cell_type": "code",
"execution_count": 285,
"id": "north-walter",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"False\n",
"True\n",
"False\n",
"[[-3.76389682e-01 -8.22720408e-01 7.42762923e-01 ... 3.42005253e-01\n",
" 1.50350705e-02 4.03372347e-01]\n",
" [-8.73864174e-01 -3.13894272e-01 4.19878662e-01 ... 3.77237231e-01\n",
" -1.43528014e-01 -1.00236630e+00]\n",
" [-4.35050905e-01 3.45046446e-02 -2.87102997e-01 ... 7.72742853e-02\n",
" -1.16722476e+00 -2.68485069e-01]\n",
" ...\n",
" [ 4.24714804e-01 5.88856399e-01 2.02039629e-02 ... 3.74054879e-01\n",
" 4.54700664e-02 -3.71394157e-01]\n",
" [-3.79784584e-01 -8.10841978e-01 7.57250786e-01 ... 2.60389000e-01\n",
" -7.93404877e-04 4.25376773e-01]\n",
" [-3.82798851e-01 -8.12067091e-01 7.49434292e-01 ... 2.61730075e-01\n",
" -1.04988366e-03 4.26787734e-01]]\n",
"---\n",
"[[-3.7638968e-01 -8.2272053e-01 7.4276292e-01 ... 3.4200522e-01\n",
" 1.5034772e-02 4.0337229e-01]\n",
" [-8.7386459e-01 -3.1389427e-01 4.1987866e-01 ... 3.7723729e-01\n",
" -1.4352810e-01 -1.0023664e+00]\n",
" [-4.3505096e-01 3.4504786e-02 -2.8710306e-01 ... 7.7274129e-02\n",
" -1.1672243e+00 -2.6848501e-01]\n",
" ...\n",
" [ 4.2471480e-01 5.8885634e-01 2.0203922e-02 ... 3.7405500e-01\n",
" 4.5470044e-02 -3.7139410e-01]\n",
" [-3.7978446e-01 -8.1084180e-01 7.5725085e-01 ... 2.6038891e-01\n",
" -7.9347193e-04 4.2537671e-01]\n",
" [-3.8279903e-01 -8.1206715e-01 7.4943429e-01 ... 2.6173013e-01\n",
" -1.0499060e-03 4.2678756e-01]]\n"
]
}
],
"source": [
"# Re-run the Paddle attention decoder and compare it with torch_decoder_out\n",
"# (presumably the PyTorch reference output loaded earlier -- TODO confirm).\n",
"decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n",
" ys_in_lens)\n",
"\n",
"# Tolerance sweep: numpy's default atol, then atol=1e-6, then atol=1e-7\n",
"# (rtol is left at numpy's default for all three checks).\n",
"print(np.allclose(decoder_out.numpy(), torch_decoder_out))\n",
"print(np.allclose(decoder_out.numpy(), torch_decoder_out, atol=1e-6))\n",
"print(np.allclose(decoder_out.numpy(), torch_decoder_out, atol=1e-7))\n",
"# Eyeball the first batch element from both implementations.\n",
"print(decoder_out.numpy()[0])\n",
"print('---')\n",
"print(torch_decoder_out[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "armed-cowboy",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "fifty-earth",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "proud-commonwealth",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 183,
"id": "assisted-fortune",
"metadata": {},
"outputs": [],
"source": [
"from paddle import nn\n",
"import paddle\n",
"from paddle.nn import functional as F\n",
"\n",
"class LabelSmoothingLoss(nn.Layer):\n",
"    \"\"\"Label-smoothed KL-divergence loss for attention-decoder outputs.\n",
"\n",
"    The target distribution puts ``1 - smoothing`` on the gold token and\n",
"    spreads ``smoothing`` uniformly over the remaining ``size - 1`` classes.\n",
"    Positions whose target equals ``padding_idx`` are excluded from the loss.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 size: int,\n",
"                 padding_idx: int,\n",
"                 smoothing: float,\n",
"                 normalize_length: bool=False):\n",
"        \"\"\"\n",
"        Args:\n",
"            size (int): number of classes; must equal the last dim of x.\n",
"            padding_idx (int): target value marking padded positions.\n",
"            smoothing (float): probability mass moved off the gold token.\n",
"            normalize_length (bool): if True, divide the summed loss by the\n",
"                number of non-padded tokens; otherwise by the batch size.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.size = size\n",
"        self.padding_idx = padding_idx\n",
"        self.smoothing = smoothing\n",
"        self.confidence = 1.0 - smoothing\n",
"        self.normalize_length = normalize_length\n",
"        self.criterion = nn.KLDivLoss(reduction=\"none\")\n",
"\n",
"    def forward(self, x: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor:\n",
"        \"\"\"Compute loss between x and target.\n",
"\n",
"        The model outputs and data labels tensors are flattened to\n",
"        (batch*seqlen, class) shape and a mask is applied to the\n",
"        padding part which should not be calculated for loss.\n",
"\n",
"        Args:\n",
"            x (paddle.Tensor): prediction (batch, seqlen, class)\n",
"            target (paddle.Tensor):\n",
"                target signal masked with self.padding_idx (batch, seqlen)\n",
"        Returns:\n",
"            loss (paddle.Tensor) : The KL loss, scalar float value\n",
"        \"\"\"\n",
"        B, T, D = paddle.shape(x)\n",
"        assert D == self.size\n",
"        x = x.reshape((-1, self.size))\n",
"        target = target.reshape([-1])\n",
"\n",
"        # Build the smoothed target distribution with plain tensor ops\n",
"        # (full_like) instead of relying on a no-grad context.\n",
"        true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))\n",
"        ignore = target == self.padding_idx  # (B*T,)\n",
"\n",
"        # Padded positions may hold an out-of-range index such as -1;\n",
"        # point them at class 0 so one_hot is valid. They are masked out below.\n",
"        target = target.masked_fill(ignore, 0)\n",
"\n",
"        # Gold class gets `confidence`; every other class keeps the\n",
"        # uniform smoothing mass.\n",
"        target_mask = F.one_hot(target, self.size)\n",
"        true_dist *= (1 - target_mask)\n",
"        true_dist += target_mask * self.confidence\n",
"\n",
"        kl = self.criterion(F.log_softmax(x, axis=1), true_dist)\n",
"\n",
"        # TODO(Hui Zhang): sum does not support bool type, so cast first.\n",
"        total = len(target) - int(ignore.type_as(target).sum())\n",
"        denom = total if self.normalize_length else B\n",
"\n",
"        # Zero the per-class KL terms of padded rows before summing.\n",
"        numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()\n",
"        return numer / denom\n"
]
},
{
"cell_type": "code",
"execution_count": 184,
"id": "weighted-delight",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.3629489603024576e-05\n",
"Tensor(shape=[112, 4233], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n",
" [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n",
" [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n",
" ...,\n",
" [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n",
" [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n",
" [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363]])\n",
"Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [41.84146118])\n",
"VarType.INT64\n"
]
}
],
"source": [
"# Run the notebook-local LabelSmoothingLoss on torch_decoder_out\n",
"# (size=4233, padding_idx=-1, smoothing=0.1, normalize_length=False).\n",
"criteron = LabelSmoothingLoss(4233, -1, 0.1, False)\n",
"loss_att = criteron(paddle.to_tensor(torch_decoder_out), ys_out_pad.astype('int64'))\n",
"print(loss_att)\n",
"print(ys_out_pad.dtype)\n",
"# Reference loss reported by the PyTorch implementation for the same batch:\n",
"# tensor(41.8416, device='cuda:0', grad_fn=<DivBackward0>)"
]
},
{
"cell_type": "code",
"execution_count": 286,
"id": "dress-shelter",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [41.84146118])\n",
"Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [41.84146118])\n",
"4233\n",
"-1\n",
"0.1\n",
"False\n"
]
}
],
"source": [
"# Evaluate the model's own attention criterion on (a) torch_decoder_out and\n",
"# (b) the Paddle decoder output, so the two loss values can be compared.\n",
"decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n",
" ys_in_lens)\n",
"\n",
"loss_att = model.criterion_att(paddle.to_tensor(torch_decoder_out), ys_out_pad)\n",
"print(loss_att)\n",
"\n",
"loss_att = model.criterion_att(decoder_out, ys_out_pad)\n",
"print(loss_att)\n",
"\n",
"# Print the criterion's hyper-parameters to confirm they match the values\n",
"# passed to the standalone LabelSmoothingLoss check.\n",
"print(model.criterion_att.size)\n",
"print(model.criterion_att.padding_idx)\n",
"print(model.criterion_att.smoothing)\n",
"print(model.criterion_att.normalize_length)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "growing-tooth",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "going-hungary",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "naughty-citizenship",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "experimental-emerald",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "adverse-saskatchewan",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 27,
"id": "speaking-shelf",
"metadata": {},
"outputs": [],
"source": [
"from typing import List\n",
"from typing import Optional\n",
"from typing import Tuple\n",
"\n",
"import paddle\n",
"from paddle import nn\n",
"from typeguard import check_argument_types\n",
"\n",
"from deepspeech.modules.activation import get_activation\n",
"from deepspeech.modules.attention import MultiHeadedAttention\n",
"from deepspeech.modules.attention import RelPositionMultiHeadedAttention\n",
"from deepspeech.modules.conformer_convolution import ConvolutionModule\n",
"from deepspeech.modules.embedding import PositionalEncoding\n",
"from deepspeech.modules.embedding import RelPositionalEncoding\n",
"from deepspeech.modules.encoder_layer import ConformerEncoderLayer\n",
"from deepspeech.modules.encoder_layer import TransformerEncoderLayer\n",
"from deepspeech.modules.mask import add_optional_chunk_mask\n",
"from deepspeech.modules.mask import make_non_pad_mask\n",
"from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward\n",
"from deepspeech.modules.subsampling import Conv2dSubsampling4\n",
"from deepspeech.modules.subsampling import Conv2dSubsampling6\n",
"from deepspeech.modules.subsampling import Conv2dSubsampling8\n",
"from deepspeech.modules.subsampling import LinearNoSubsampling\n",
"\n",
"class BaseEncoder(nn.Layer):\n",
"    \"\"\"Shared front-end and forward pass for the encoder family.\n",
"\n",
"    Subclasses must assign ``self.encoders`` (an iterable of encoder layers)\n",
"    after calling ``super().__init__``; ``forward`` iterates over it.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"            self,\n",
"            input_size: int,\n",
"            output_size: int=256,\n",
"            attention_heads: int=4,\n",
"            linear_units: int=2048,\n",
"            num_blocks: int=6,\n",
"            dropout_rate: float=0.1,\n",
"            positional_dropout_rate: float=0.1,\n",
"            attention_dropout_rate: float=0.0,\n",
"            input_layer: str=\"conv2d\",\n",
"            pos_enc_layer_type: str=\"abs_pos\",\n",
"            normalize_before: bool=True,\n",
"            concat_after: bool=False,\n",
"            static_chunk_size: int=0,\n",
"            use_dynamic_chunk: bool=False,\n",
"            global_cmvn: Optional[paddle.nn.Layer]=None,\n",
"            use_dynamic_left_chunk: bool=False, ):\n",
"        \"\"\"\n",
"        Args:\n",
"            input_size (int): input dim, d_feature\n",
"            output_size (int): dimension of attention, d_model\n",
"            attention_heads (int): the number of heads of multi head attention\n",
"            linear_units (int): the hidden units number of position-wise feed\n",
"                forward\n",
"            num_blocks (int): the number of encoder blocks\n",
"            dropout_rate (float): dropout rate\n",
"            attention_dropout_rate (float): dropout rate in attention\n",
"            positional_dropout_rate (float): dropout rate after adding\n",
"                positional encoding\n",
"            input_layer (str): input layer type.\n",
"                optional [linear, conv2d, conv2d6, conv2d8]\n",
"            pos_enc_layer_type (str): Encoder positional encoding layer type.\n",
"                optional [abs_pos, rel_pos]; any other value raises ValueError.\n",
"            normalize_before (bool):\n",
"                True: use layer_norm before each sub-block of a layer.\n",
"                False: use layer_norm after each sub-block of a layer.\n",
"            concat_after (bool): whether to concat attention layer's input\n",
"                and output.\n",
"                True: x -> x + linear(concat(x, att(x)))\n",
"                False: x -> x + att(x)\n",
"            static_chunk_size (int): chunk size for static chunk training and\n",
"                decoding\n",
"            use_dynamic_chunk (bool): whether use dynamic chunk size for\n",
"                training or not, You can only use fixed chunk(chunk_size > 0)\n",
"                or dynamic chunk size(use_dynamic_chunk = True)\n",
"            global_cmvn (Optional[paddle.nn.Layer]): Optional GlobalCMVN layer\n",
"            use_dynamic_left_chunk (bool): whether use dynamic left chunk in\n",
"                dynamic chunk training\n",
"\n",
"        Note:\n",
"            attention_heads, linear_units, num_blocks, attention_dropout_rate\n",
"            and concat_after are not read here; they are accepted so that\n",
"            subclasses share one signature and can build self.encoders.\n",
"        \"\"\"\n",
"        assert check_argument_types()\n",
"        super().__init__()\n",
"        self._output_size = output_size\n",
"\n",
"        if pos_enc_layer_type == \"abs_pos\":\n",
"            pos_enc_class = PositionalEncoding\n",
"        elif pos_enc_layer_type == \"rel_pos\":\n",
"            pos_enc_class = RelPositionalEncoding\n",
"        else:\n",
"            raise ValueError(\"unknown pos_enc_layer: \" + pos_enc_layer_type)\n",
"\n",
"        if input_layer == \"linear\":\n",
"            subsampling_class = LinearNoSubsampling\n",
"        elif input_layer == \"conv2d\":\n",
"            subsampling_class = Conv2dSubsampling4\n",
"        elif input_layer == \"conv2d6\":\n",
"            subsampling_class = Conv2dSubsampling6\n",
"        elif input_layer == \"conv2d8\":\n",
"            subsampling_class = Conv2dSubsampling8\n",
"        else:\n",
"            raise ValueError(\"unknown input_layer: \" + input_layer)\n",
"\n",
"        self.global_cmvn = global_cmvn\n",
"        self.embed = subsampling_class(\n",
"            idim=input_size,\n",
"            odim=output_size,\n",
"            dropout_rate=dropout_rate,\n",
"            pos_enc_class=pos_enc_class(\n",
"                d_model=output_size, dropout_rate=positional_dropout_rate), )\n",
"\n",
"        self.normalize_before = normalize_before\n",
"        self.after_norm = nn.LayerNorm(output_size, epsilon=1e-12)\n",
"        self.static_chunk_size = static_chunk_size\n",
"        self.use_dynamic_chunk = use_dynamic_chunk\n",
"        self.use_dynamic_left_chunk = use_dynamic_left_chunk\n",
"\n",
"    def output_size(self) -> int:\n",
"        \"\"\"Return the output_size given at construction (d_model).\"\"\"\n",
"        return self._output_size\n",
"\n",
"    def forward(\n",
"            self,\n",
"            xs: paddle.Tensor,\n",
"            xs_lens: paddle.Tensor,\n",
"            decoding_chunk_size: int=0,\n",
"            num_decoding_left_chunks: int=-1,\n",
"    ) -> Tuple[paddle.Tensor, paddle.Tensor]:\n",
"        \"\"\"Embed positions in tensor.\n",
"        Args:\n",
"            xs: padded input tensor (B, L, D)\n",
"            xs_lens: input length (B)\n",
"            decoding_chunk_size: decoding chunk size for dynamic chunk\n",
"                0: default for training, use random dynamic chunk.\n",
"                <0: for decoding, use full chunk.\n",
"                >0: for decoding, use fixed chunk size as set.\n",
"            num_decoding_left_chunks: number of left chunks, this is for decoding,\n",
"                the chunk size is decoding_chunk_size.\n",
"                >=0: use num_decoding_left_chunks\n",
"                <0: use all left chunks\n",
"        Returns:\n",
"            Tuple of (xs, masks): the encoder output and the bool padding\n",
"            mask (B, 1, L'), where L' is the time length after self.embed.\n",
"        \"\"\"\n",
"        masks = make_non_pad_mask(xs_lens).unsqueeze(1)  # (B, 1, L)\n",
"\n",
"        if self.global_cmvn is not None:\n",
"            xs = self.global_cmvn(xs)\n",
"        # TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor\n",
"        xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)\n",
"        # TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor\n",
"        masks = masks.astype(paddle.bool)\n",
"        # TODO(Hui Zhang): mask_pad = ~masks\n",
"        mask_pad = masks.logical_not()\n",
"        chunk_masks = add_optional_chunk_mask(\n",
"            xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,\n",
"            decoding_chunk_size, self.static_chunk_size,\n",
"            num_decoding_left_chunks)\n",
"        # self.encoders is provided by the subclass.\n",
"        for layer in self.encoders:\n",
"            xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n",
"        if self.normalize_before:\n",
"            xs = self.after_norm(xs)\n",
"        # Here we assume the mask is not changed in encoder layers, so just\n",
"        # return the masks before encoder layers, and the masks will be used\n",
"        # for cross attention with decoder later\n",
"        return xs, masks"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "sharp-municipality",
"metadata": {},
"outputs": [],
"source": [
"\n",
"class ConformerEncoder(BaseEncoder):\n",
"    \"\"\"Conformer encoder module.\"\"\"\n",
"\n",
"    def __init__(\n",
"            self,\n",
"            input_size: int,\n",
"            output_size: int=256,\n",
"            attention_heads: int=4,\n",
"            linear_units: int=2048,\n",
"            num_blocks: int=6,\n",
"            dropout_rate: float=0.1,\n",
"            positional_dropout_rate: float=0.1,\n",
"            attention_dropout_rate: float=0.0,\n",
"            input_layer: str=\"conv2d\",\n",
"            pos_enc_layer_type: str=\"rel_pos\",\n",
"            normalize_before: bool=True,\n",
"            concat_after: bool=False,\n",
"            static_chunk_size: int=0,\n",
"            use_dynamic_chunk: bool=False,\n",
"            global_cmvn: Optional[nn.Layer]=None,\n",
"            use_dynamic_left_chunk: bool=False,\n",
"            positionwise_conv_kernel_size: int=1,\n",
"            macaron_style: bool=True,\n",
"            selfattention_layer_type: str=\"rel_selfattn\",\n",
"            activation_type: str=\"swish\",\n",
"            use_cnn_module: bool=True,\n",
"            cnn_module_kernel: int=15,\n",
"            causal: bool=False,\n",
"            cnn_module_norm: str=\"batch_norm\", ):\n",
"        \"\"\"Construct ConformerEncoder\n",
"        Args:\n",
"            input_size to use_dynamic_left_chunk: see BaseEncoder.\n",
"            positionwise_conv_kernel_size (int): Kernel size of positionwise\n",
"                conv1d layer. Currently unused: the feed-forward layers are\n",
"                always PositionwiseFeedForward.\n",
"            macaron_style (bool): Whether to use macaron style for\n",
"                positionwise layer.\n",
"            selfattention_layer_type (str): Encoder attention layer type,\n",
"                the parameter has no effect now, it's just for configure\n",
"                compatibility.\n",
"            activation_type (str): Encoder activation function type.\n",
"            use_cnn_module (bool): Whether to use convolution module.\n",
"            cnn_module_kernel (int): Kernel size of convolution module.\n",
"            causal (bool): whether to use causal convolution or not.\n",
"            cnn_module_norm (str): cnn conv norm type, Optional['batch_norm','layer_norm']\n",
"        \"\"\"\n",
"        assert check_argument_types()\n",
"        super().__init__(input_size, output_size, attention_heads, linear_units,\n",
"                         num_blocks, dropout_rate, positional_dropout_rate,\n",
"                         attention_dropout_rate, input_layer,\n",
"                         pos_enc_layer_type, normalize_before, concat_after,\n",
"                         static_chunk_size, use_dynamic_chunk, global_cmvn,\n",
"                         use_dynamic_left_chunk)\n",
"        activation = get_activation(activation_type)\n",
"\n",
"        # self-attention module definition: relative-position attention is\n",
"        # used regardless of pos_enc_layer_type / selfattention_layer_type.\n",
"        encoder_selfattn_layer = RelPositionMultiHeadedAttention\n",
"        encoder_selfattn_layer_args = (attention_heads, output_size,\n",
"                                       attention_dropout_rate)\n",
"        # feed-forward module definition\n",
"        positionwise_layer = PositionwiseFeedForward\n",
"        positionwise_layer_args = (output_size, linear_units, dropout_rate,\n",
"                                   activation)\n",
"        # convolution module definition\n",
"        convolution_layer = ConvolutionModule\n",
"        convolution_layer_args = (output_size, cnn_module_kernel, activation,\n",
"                                  cnn_module_norm, causal)\n",
"\n",
"        # Use paddle's native LayerList container so every block is\n",
"        # registered as a sublayer of this encoder.\n",
"        self.encoders = nn.LayerList([\n",
"            ConformerEncoderLayer(\n",
"                size=output_size,\n",
"                self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),\n",
"                feed_forward=positionwise_layer(*positionwise_layer_args),\n",
"                feed_forward_macaron=positionwise_layer(\n",
"                    *positionwise_layer_args) if macaron_style else None,\n",
"                conv_module=convolution_layer(*convolution_layer_args)\n",
"                if use_cnn_module else None,\n",
"                dropout_rate=dropout_rate,\n",
"                normalize_before=normalize_before,\n",
"                concat_after=concat_after) for _ in range(num_blocks)\n",
"        ])\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "tutorial-syndication",
"metadata": {},
"outputs": [],
"source": [
"from deepspeech.frontend.utility import load_cmvn\n",
"from deepspeech.modules.cmvn import GlobalCMVN\n",
"\n",
"# Rebuild the encoder from the model config so it can be compared with\n",
"# model.encoder. NOTE: `cfg` must already be defined by an earlier cell.\n",
"configs=cfg.model\n",
"# GlobalCMVN layer built from the mean/istd stats read from cmvn_file.\n",
"mean, istd = load_cmvn(configs['cmvn_file'],\n",
" configs['cmvn_file_type'])\n",
"global_cmvn = GlobalCMVN(\n",
" paddle.to_tensor(mean, dtype=paddle.float),\n",
" paddle.to_tensor(istd, dtype=paddle.float))\n",
"\n",
"\n",
"# NOTE(review): vocab_size and encoder_type are not used in this cell;\n",
"# a ConformerEncoder is built whatever encoder_type says.\n",
"input_dim = configs['input_dim']\n",
"vocab_size = configs['output_dim']\n",
"encoder_type = configs.get('encoder', 'transformer')\n",
" \n",
"encoder = ConformerEncoder(\n",
" input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "fuzzy-register",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"# Sanity check: the standalone GlobalCMVN should match the model's CMVN\n",
"# on the same features (`feat` comes from an earlier cell).\n",
"o = global_cmvn(feat)\n",
"o2 = model.encoder.global_cmvn(feat)\n",
"print(np.allclose(o.numpy(), o2.numpy()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "explicit-triumph",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "humanitarian-belgium",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "dying-proposal",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "honest-quick",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "bound-cholesterol",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "viral-packaging",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 203,
"id": "balanced-locator",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[16, 1, 207], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
" [[[True , True , True , ..., True , True , True ]],\n",
"\n",
" [[True , True , True , ..., True , True , True ]],\n",
"\n",
" [[True , True , True , ..., True , False, False]],\n",
"\n",
" ...,\n",
"\n",
" [[True , True , True , ..., False, False, False]],\n",
"\n",
" [[True , True , True , ..., False, False, False]],\n",
"\n",
" [[True , True , True , ..., False, False, False]]])\n"
]
}
],
"source": [
"from deepspeech.modules.mask import make_non_pad_mask\n",
"from deepspeech.modules.mask import make_pad_mask\n",
"# Same padding mask BaseEncoder.forward builds first: bool, (B, 1, T),\n",
"# True for valid frames. (make_pad_mask is imported but unused here.)\n",
"masks = make_non_pad_mask(feat_len).unsqueeze(1)\n",
"print(masks)"
]
},
{
"cell_type": "code",
"execution_count": 204,
"id": "induced-proposition",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[16, 207, 80], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[[-0.53697914, -0.19910523, -0.34997201, ..., -0.82427669, -1.02650309, -0.96300691],\n",
" [-0.04464225, 0.23176001, -0.32538742, ..., -0.90158713, -1.03248465, -0.75986791],\n",
" [ 0.50035292, 0.22691160, -0.73052198, ..., -1.00552964, -0.87123060, -1.03062117],\n",
" ...,\n",
" [-0.40023831, -0.14325078, -0.57947433, ..., -1.07178426, -1.28059900, -1.05180073],\n",
" [ 0.15755332, -0.00184949, -0.28702953, ..., -1.10898709, -0.94518697, -0.72506356],\n",
" [-0.47520429, -1.39415145, -0.25754252, ..., -1.13649082, -1.19430351, -1.22903371]],\n",
"\n",
" [[ 0.95454037, 0.36427975, -1.38908529, ..., -1.16366839, -1.28453600, -1.20151031],\n",
" [-0.08573537, -1.05785275, -0.89172721, ..., -0.96440506, -1.12547100, -1.25990939],\n",
" [ 0.47653601, 0.32886592, -0.59200549, ..., -1.19421589, -1.14302588, -1.02422845],\n",
" ...,\n",
" [-0.47431335, -0.33558893, -0.72325647, ..., -1.45058632, -1.39574063, -1.04641151],\n",
" [ 0.36112556, 0.10380996, -1.15994537, ..., -1.04394984, -1.02212358, -1.02083635],\n",
" [-1.27172923, -2.14601755, -0.75676596, ..., -0.97822225, -0.93785471, -1.03707945]],\n",
"\n",
" [[-1.54652190, -1.01517177, -0.88900733, ..., -0.48522446, -0.75163364, -0.67765164],\n",
" [-0.76100892, -0.73351598, -0.91587651, ..., -0.24835993, -0.58927339, -0.73722762],\n",
" [-0.02471367, 0.17015894, -0.42326337, ..., -0.33203802, -0.76695800, -0.71651691],\n",
" ...,\n",
" [-1.70319796, -1.25910866, -1.14492917, ..., -1.18101490, -1.11631835, -0.93108195],\n",
" [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n",
" [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n",
"\n",
" ...,\n",
"\n",
" [[ 0.64982772, 0.26116797, -0.84196597, ..., -0.87213463, -1.10728693, -1.32531130],\n",
" [ 0.35391113, -0.01584581, -0.40424931, ..., -0.99173468, -1.07270539, -1.19239008],\n",
" [ 0.37704495, -0.06278508, -0.11467686, ..., -1.10212946, -1.09524000, -1.11815071],\n",
" ...,\n",
" [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n",
" [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n",
" [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n",
"\n",
" [[ 0.04445776, -0.17546852, -0.67475224, ..., -0.49801198, -0.56782746, -0.77852231],\n",
" [-1.34279025, -0.80342549, -0.90457231, ..., -0.65901577, -0.72549772, -0.62796098],\n",
" [-0.76252806, -0.13071291, -0.13280024, ..., -0.56132573, -0.60587686, -0.72114766],\n",
" ...,\n",
" [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n",
" [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n",
" [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n",
"\n",
" [[-1.07980299, -1.08341801, -1.17969072, ..., -0.17757270, -0.43746525, -0.04000654],\n",
" [ 0.92353648, 0.63770926, -0.52810186, ..., -0.12927933, -0.20342292, 0.16655664],\n",
" [ 0.49337494, -0.00911332, -0.73301607, ..., 0.10074048, -0.09811471, -0.00923573],\n",
" ...,\n",
" [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n",
" [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n",
" [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]]])\n"
]
}
],
"source": [
"# Apply the model's CMVN to the raw features, as BaseEncoder.forward does\n",
"# before subsampling.\n",
"xs = model.encoder.global_cmvn(feat)\n",
"print(xs)"
]
},
{
"cell_type": "code",
"execution_count": 205,
"id": "cutting-julian",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[16, 256, 51, 19], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [[[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0. , 0. , 0. , ..., 0. , 0. , 0.00209083],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0.01194306, 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0.04610471, 0. ],\n",
" [0. , 0. , 0. , ..., 0.00967231, 0.04613467, 0. ]],\n",
"\n",
" [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" ...,\n",
"\n",
" [[0.22816099, 0.24614786, 0.25304127, ..., 0.20401822, 0.23248228, 0.31190544],\n",
" [0.13587360, 0.28877240, 0.27991283, ..., 0.19210319, 0.20346391, 0.19934426],\n",
" [0.25739068, 0.39348233, 0.27877361, ..., 0.27482539, 0.19302306, 0.23810163],\n",
" ...,\n",
" [0.11939213, 0.28473237, 0.33082074, ..., 0.23838061, 0.22104350, 0.23905794],\n",
" [0.17387670, 0.20402060, 0.40263173, ..., 0.24782266, 0.26742202, 0.15426503],\n",
" [0. , 0.29080707, 0.27725950, ..., 0.17539823, 0.18478745, 0.22483408]],\n",
"\n",
" [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0.35446781, 0.38861471, 0.39724261, ..., 0.38680089, 0.33568040, 0.34552398],\n",
" [0.41739127, 0.51038563, 0.41729912, ..., 0.33992639, 0.37081629, 0.35109508],\n",
" [0.36116859, 0.40744874, 0.48490953, ..., 0.34848654, 0.32321057, 0.35188958],\n",
" ...,\n",
" [0.23143977, 0.38021481, 0.51526314, ..., 0.36499465, 0.37411752, 0.39986172],\n",
" [0.34678638, 0.40238205, 0.50076538, ..., 0.36184520, 0.31596646, 0.36334658],\n",
" [0.36498138, 0.37943166, 0.51718897, ..., 0.31798238, 0.33656698, 0.34130475]]],\n",
"\n",
"\n",
" [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0.01456045, 0.09447514, 0. , ..., 0. , 0. , 0. ],\n",
" [0.01500242, 0.02963220, 0. , ..., 0. , 0. , 0. ],\n",
" [0.03295187, 0. , 0. , ..., 0.04584959, 0.02043908, 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0.04425837],\n",
" [0. , 0. , 0.02556529, ..., 0. , 0.00900441, 0.04908358]],\n",
"\n",
" [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0.11141267, 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" ...,\n",
"\n",
" [[0.33696529, 0.38526866, 0.32900479, ..., 0.28703830, 0.23351061, 0.19004467],\n",
" [0.13575366, 0.35783342, 0.33573425, ..., 0.22081660, 0.15854910, 0.13587447],\n",
" [0.21928655, 0.28900093, 0.28255141, ..., 0.20602837, 0.23927397, 0.21909429],\n",
" ...,\n",
" [0.23291890, 0.39096734, 0.36399242, ..., 0.20598020, 0.25373828, 0.23137446],\n",
" [0.18739152, 0.30793777, 0.30296701, ..., 0.27250600, 0.25191751, 0.20836820],\n",
" [0.22454213, 0.41402060, 0.54082996, ..., 0.31874508, 0.25079906, 0.25938687]],\n",
"\n",
" [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0.26456982, 0.49519050, 0.56702250, ..., 0.30954638, 0.35292268, 0.32668519],\n",
" [0.21576807, 0.51833367, 0.49183372, ..., 0.36043224, 0.38523889, 0.36154741],\n",
" [0.20067888, 0.42784205, 0.52817714, ..., 0.31871423, 0.32452232, 0.31036487],\n",
" ...,\n",
" [0.49855131, 0.51001430, 0.52278662, ..., 0.36450142, 0.34338164, 0.33602941],\n",
" [0.41233343, 0.55517823, 0.52827710, ..., 0.40675971, 0.33873138, 0.36724189],\n",
" [0.40820011, 0.46187383, 0.47338152, ..., 0.38690975, 0.36039269, 0.38022059]]],\n",
"\n",
"\n",
" [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0. , 0.00578516, 0. , ..., 0.00748384, 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0.03035110, 0. , 0.00026720],\n",
" [0.00094807, 0. , 0. , ..., 0.00795512, 0. , 0. ],\n",
" ...,\n",
" [0.02032628, 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0.01080076, 0. ],\n",
" [0.18470290, 0. , 0. , ..., 0.05058352, 0.09475817, 0.05914564]],\n",
"\n",
" [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" ...,\n",
"\n",
" [[0.38708323, 0.28021947, 0.35892880, ..., 0.16595127, 0.16031364, 0.21136315],\n",
" [0.15595171, 0.30544323, 0.24666184, ..., 0.22675267, 0.25765014, 0.19682154],\n",
" [0.29517862, 0.41209796, 0.20063159, ..., 0.17595036, 0.22536841, 0.22214051],\n",
" ...,\n",
" [0.24744980, 0.26258564, 0.38654143, ..., 0.23620218, 0.23157144, 0.18514194],\n",
" [0.25714791, 0.29592845, 0.47744542, ..., 0.23545510, 0.25072727, 0.20976165],\n",
" [1.20154655, 0.84644288, 0.73385584, ..., 1.02517247, 0.95309550, 1.00134516]],\n",
"\n",
" [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0.45013186, 0.47484034, 0.40540054, ..., 0.19346163, 0.17825794, 0.14776605],\n",
" [0.47545874, 0.48186573, 0.36760187, ..., 0.27809089, 0.32997063, 0.32337096],\n",
" [0.46160024, 0.40050328, 0.39060861, ..., 0.36612910, 0.35242686, 0.29738861],\n",
" ...,\n",
" [0.55148494, 0.51017821, 0.40132499, ..., 0.38948193, 0.35737294, 0.33088297],\n",
" [0.41972569, 0.45475486, 0.45320493, ..., 0.38343129, 0.40125814, 0.36180776],\n",
" [0.34279808, 0.31606171, 0.44701228, ..., 0.21665487, 0.23984617, 0.23903391]]],\n",
"\n",
"\n",
" ...,\n",
"\n",
"\n",
" [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0.04178291, 0. , 0.01580476, ..., 0. , 0.02250817, 0. ],\n",
" [0.04323414, 0.07786420, 0. , ..., 0.01634724, 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0.03209178, 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0.13563479, 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" ...,\n",
"\n",
" [[0. , 0.25187218, 0.24979387, ..., 0.24774717, 0.22354351, 0.19149347],\n",
" [0.16540922, 0.19585510, 0.19812922, ..., 0.27344131, 0.20928150, 0.26150429],\n",
" [0.10494646, 0.06329897, 0.33843631, ..., 0.25138417, 0.12470355, 0.23926635],\n",
" ...,\n",
" [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n",
" [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n",
" [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n",
"\n",
" [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0.11428106, 0.45667490, 0.46820879, ..., 0.32057840, 0.33578536, 0.39012644],\n",
" [0.10441341, 0.45739070, 0.46107352, ..., 0.38467997, 0.38291249, 0.36685589],\n",
" [0.19867736, 0.35519636, 0.44313061, ..., 0.40679252, 0.38067645, 0.30645671],\n",
" ...,\n",
" [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n",
" [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n",
" [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]],\n",
"\n",
"\n",
" [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0.02465414, 0. , 0. , ..., 0. , 0. , 0.03390232],\n",
" [0. , 0. , 0.01830704, ..., 0.05166877, 0.00948385, 0.07453502],\n",
" [0.09921519, 0. , 0.01587192, ..., 0.01620276, 0.05140074, 0.00192392],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" ...,\n",
"\n",
" [[0.40034360, 0.25306445, 0.20217699, ..., 0.09816189, 0.07064310, 0.04974059],\n",
" [0.12567598, 0.21030979, 0.11181555, ..., 0.04278110, 0.11968569, 0.12005232],\n",
" [0.28786880, 0.24030517, 0.22565845, ..., 0. , 0.06418110, 0.05872961],\n",
" ...,\n",
" [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n",
" [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n",
" [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n",
"\n",
" [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0.38404641, 0.30990323, 0.37156230, ..., 0.18125033, 0.15050662, 0.19619957],\n",
" [0.47285745, 0.40528792, 0.39718056, ..., 0.24709940, 0.04565683, 0.11500744],\n",
" [0.32620737, 0.30072594, 0.30477354, ..., 0.23529193, 0.21356541, 0.16985542],\n",
" ...,\n",
" [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n",
" [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n",
" [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]],\n",
"\n",
"\n",
" [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0.03343770, 0.00123780, 0.05297198, ..., 0.07271163, 0.08656286, 0.14493589],\n",
" [0.11043239, 0.06143146, 0.06362963, ..., 0.08127750, 0.06259022, 0.08315435],\n",
" [0.01767678, 0.00201111, 0.07875030, ..., 0.06963293, 0.08979890, 0.05326346],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0.10033827, 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0.15627117, 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0.05144687, 0. , 0. , ..., 0. , 0. , 0.00436414],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" ...,\n",
"\n",
" [[0.25142455, 0.45964020, 0.37346074, ..., 0.04763087, 0. , 0. ],\n",
" [0.19760093, 0.26626948, 0.11190540, ..., 0.03044968, 0. , 0. ],\n",
" [0.16340607, 0.32938001, 0.25689697, ..., 0.05569421, 0. , 0. ],\n",
" ...,\n",
" [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n",
" [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n",
" [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n",
"\n",
" [[0. , 0. , 0. , ..., 0. , 0.02218930, 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0.02848953],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[0.25810039, 0.63016868, 0.37037861, ..., 0.18704373, 0.08269356, 0.09912672],\n",
" [0.17292863, 0.50678611, 0.40738991, ..., 0.16006103, 0.11725381, 0.09940521],\n",
" [0.24175072, 0.41616210, 0.41256818, ..., 0.13519743, 0.07912572, 0.12846369],\n",
" ...,\n",
" [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n",
" [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n",
" [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]]])\n"
]
}
],
"source": [
"# Step through the encoder front-end by hand: apply global CMVN, then run\n",
"# only the convolutional part (`embed.conv`) of the embed layer.\n",
"xs = model.encoder.global_cmvn(feat)\n",
"\n",
"# unsqueeze(1) inserts a channel axis at dim 1 before the conv stack\n",
"x = model.encoder.embed.conv(xs.unsqueeze(1))\n",
"print(x)"
]
},
{
"cell_type": "code",
"execution_count": 206,
"id": "friendly-nightlife",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[16, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [[[-0.03426375, 0.14291267, -0.06718873, ..., 0.09064753, 0.01809387, -0.04340880],\n",
" [-0.05007839, 0.11054724, -0.10399298, ..., 0.11457238, 0.04244684, -0.01249714],\n",
" [-0.10695291, 0.16910909, -0.08352133, ..., 0.07710276, 0.01168563, -0.03584499],\n",
" ...,\n",
" [-0.06060536, 0.14455931, -0.05470302, ..., 0.05364908, 0.03033342, -0.02610814],\n",
" [-0.08505894, 0.13611752, -0.11132983, ..., 0.13079923, 0.01580139, -0.02281028],\n",
" [-0.10604677, 0.14714901, -0.10885533, ..., 0.08543444, 0.03719445, -0.04634233]],\n",
"\n",
" [[-0.12392755, 0.14486063, -0.05674079, ..., 0.02573164, 0.03128851, 0.00545091],\n",
" [-0.04775286, 0.08473608, -0.08507854, ..., 0.04573154, 0.04240163, 0.01053247],\n",
" [-0.05940291, 0.10023535, -0.08143730, ..., 0.03596500, 0.01673085, 0.02089563],\n",
" ...,\n",
" [-0.09222981, 0.15823206, -0.07700447, ..., 0.08122957, 0.03136991, -0.00646474],\n",
" [-0.07331756, 0.14482647, -0.07838815, ..., 0.10869440, 0.01356864, -0.02777974],\n",
" [-0.07937264, 0.20143102, -0.05544947, ..., 0.10287814, 0.00608235, -0.04799180]],\n",
"\n",
" [[-0.03670349, 0.08931590, -0.08718812, ..., 0.01314050, 0.00642052, 0.00573716],\n",
" [ 0.01089254, 0.11146393, -0.10263617, ..., 0.05070438, 0.01960694, 0.03521532],\n",
" [-0.02182280, 0.11443964, -0.06678198, ..., 0.04327708, 0.00861394, 0.02871092],\n",
" ...,\n",
" [-0.06792898, 0.14376275, -0.07899005, ..., 0.11248926, 0.03208683, -0.03264240],\n",
" [-0.07884051, 0.17024788, -0.08583611, ..., 0.09028331, 0.03588808, -0.02075090],\n",
" [-0.13792302, 0.27163863, -0.23930418, ..., 0.13391261, 0.07521040, -0.08621951]],\n",
"\n",
" ...,\n",
"\n",
" [[-0.02446348, 0.11595841, -0.03591986, ..., 0.06288970, 0.02895011, -0.06532725],\n",
" [-0.05378424, 0.12607370, -0.09023033, ..., 0.09078894, 0.01035743, 0.03701983],\n",
" [-0.04566649, 0.14275314, -0.06686870, ..., 0.09890588, -0.00612222, 0.03439377],\n",
" ...,\n",
" [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n",
" [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n",
" [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]],\n",
"\n",
" [[-0.01012144, 0.03909408, -0.07077143, ..., 0.00452683, -0.01377654, 0.02897627],\n",
" [-0.00519154, 0.03594019, -0.06831125, ..., 0.05693541, -0.00406374, 0.04561640],\n",
" [-0.01762631, 0.00500899, -0.05886075, ..., 0.02112178, -0.00729015, 0.02782153],\n",
" ...,\n",
" [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n",
" [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n",
" [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]],\n",
"\n",
" [[-0.03411558, -0.04318277, -0.08497842, ..., -0.04886402, 0.04296734, 0.06151697],\n",
" [ 0.00263296, -0.06913657, -0.08993219, ..., -0.00149064, 0.05696633, 0.03304394],\n",
" [-0.01818341, -0.01178640, -0.09679577, ..., -0.00870231, 0.00362198, 0.01916483],\n",
" ...,\n",
" [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n",
" [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n",
" [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]]])\n",
"Tensor(shape=[16, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [[[-0.54821998, 2.28660274, -1.07501972, ..., 1.45036042, 0.28950194, -0.69454080],\n",
" [-0.80125421, 1.76875579, -1.66388774, ..., 1.83315802, 0.67914939, -0.19995420],\n",
" [-1.71124649, 2.70574546, -1.33634126, ..., 1.23364413, 0.18697014, -0.57351983],\n",
" ...,\n",
" [-0.96968573, 2.31294894, -0.87524825, ..., 0.85838526, 0.48533469, -0.41773027],\n",
" [-1.36094308, 2.17788029, -1.78127730, ..., 2.09278774, 0.25282228, -0.36496443],\n",
" [-1.69674826, 2.35438418, -1.74168527, ..., 1.36695099, 0.59511113, -0.74147725]],\n",
"\n",
" [[-1.98284078, 2.31777000, -0.90785271, ..., 0.41170627, 0.50061619, 0.08721463],\n",
" [-0.76404583, 1.35577726, -1.36125672, ..., 0.73170459, 0.67842603, 0.16851945],\n",
" [-0.95044655, 1.60376561, -1.30299675, ..., 0.57544005, 0.26769355, 0.33433008],\n",
" ...,\n",
" [-1.47567701, 2.53171301, -1.23207152, ..., 1.29967308, 0.50191855, -0.10343577],\n",
" [-1.17308092, 2.31722355, -1.25421047, ..., 1.73911047, 0.21709818, -0.44447583],\n",
" [-1.26996231, 3.22289634, -0.88719147, ..., 1.64605021, 0.09731755, -0.76786882]],\n",
"\n",
" [[-0.58725590, 1.42905438, -1.39500988, ..., 0.21024795, 0.10272825, 0.09179455],\n",
" [ 0.17428070, 1.78342295, -1.64217877, ..., 0.81127012, 0.31371105, 0.56344515],\n",
" [-0.34916472, 1.83103430, -1.06851172, ..., 0.69243336, 0.13782299, 0.45937473],\n",
" ...,\n",
" [-1.08686376, 2.30020404, -1.26384079, ..., 1.79982817, 0.51338923, -0.52227837],\n",
" [-1.26144814, 2.72396612, -1.37337780, ..., 1.44453299, 0.57420933, -0.33201432],\n",
" [-2.20676827, 4.34621811, -3.82886696, ..., 2.14260173, 1.20336640, -1.37951219]],\n",
"\n",
" ...,\n",
"\n",
" [[-0.39141566, 1.85533464, -0.57471782, ..., 1.00623512, 0.46320182, -1.04523599],\n",
" [-0.86054784, 2.01717925, -1.44368529, ..., 1.45262301, 0.16571884, 0.59231722],\n",
" [-0.73066384, 2.28405023, -1.06989920, ..., 1.58249414, -0.09795550, 0.55030036],\n",
" ...,\n",
" [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n",
" [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n",
" [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]],\n",
"\n",
" [[-0.16194311, 0.62550521, -1.13234293, ..., 0.07242929, -0.22042468, 0.46362036],\n",
" [-0.08306468, 0.57504302, -1.09298003, ..., 0.91096652, -0.06501988, 0.72986233],\n",
" [-0.28202093, 0.08014385, -0.94177192, ..., 0.33794850, -0.11664233, 0.44514441],\n",
" ...,\n",
" [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n",
" [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n",
" [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]],\n",
"\n",
" [[-0.54584920, -0.69092435, -1.35965478, ..., -0.78182435, 0.68747747, 0.98427159],\n",
" [ 0.04212743, -1.10618520, -1.43891501, ..., -0.02385022, 0.91146135, 0.52870303],\n",
" [-0.29093450, -0.18858244, -1.54873240, ..., -0.13923697, 0.05795169, 0.30663735],\n",
" ...,\n",
" [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n",
" [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n",
" [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]]])\n",
"Tensor(shape=[1, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[[ 0. , 1. , 0. , ..., 1. , 0. , 1. ],\n",
" [ 0.84147102, 0.54030228, 0.80196184, ..., 1. , 0.00010746, 1. ],\n",
" [ 0.90929747, -0.41614681, 0.95814437, ..., 1. , 0.00021492, 1. ],\n",
" ...,\n",
" [-0.76825470, -0.64014435, 0.63279730, ..., 0.99998462, 0.00515809, 0.99998671],\n",
" [-0.95375264, 0.30059254, 0.99899054, ..., 0.99998397, 0.00526555, 0.99998611],\n",
" [-0.26237485, 0.96496606, 0.56074661, ..., 0.99998331, 0.00537301, 0.99998558]]])\n"
]
}
],
"source": [
"# Finish the embed layer by hand: move the channel axis next to the feature\n",
"# axis, fold them into one ([b, t, c * f]), apply the linear projection, then\n",
"# add positional encoding at offset 0.\n",
"# New names are used so `x` (the conv output from the previous cell) is not\n",
"# overwritten and this cell can be re-run safely.\n",
"b, c, t, f = paddle.shape(x)\n",
"x_out = model.encoder.embed.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))\n",
"print(x_out)\n",
"x_pos, pos_emb = model.encoder.embed.pos_enc(x_out, 0)\n",
"print(x_pos)\n",
"print(pos_emb)"
]
},
{
"cell_type": "code",
"execution_count": 207,
"id": "guilty-cache",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(shape=[1, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[[ 0. , 1. , 0. , ..., 1. , 0. , 1. ],\n",
" [ 0.84147102, 0.54030228, 0.80196184, ..., 1. , 0.00010746, 1. ],\n",
" [ 0.90929747, -0.41614681, 0.95814437, ..., 1. , 0.00021492, 1. ],\n",
" ...,\n",
" [-0.76825470, -0.64014435, 0.63279730, ..., 0.99998462, 0.00515809, 0.99998671],\n",
" [-0.95375264, 0.30059254, 0.99899054, ..., 0.99998397, 0.00526555, 0.99998611],\n",
" [-0.26237485, 0.96496606, 0.56074661, ..., 0.99998331, 0.00537301, 0.99998558]]])\n"
]
}
],
"source": [
"# Re-display the positional embedding returned by embed.pos_enc.\n",
"print(str(pos_emb))"
]
},
{
"cell_type": "code",
"execution_count": 208,
"id": "iraqi-payday",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[[ 0.0000000e+00 1.0000000e+00 0.0000000e+00 ... 1.0000000e+00\n",
" 0.0000000e+00 1.0000000e+00]\n",
" [ 8.4147096e-01 5.4030234e-01 8.0196178e-01 ... 1.0000000e+00\n",
" 1.0746076e-04 1.0000000e+00]\n",
" [ 9.0929741e-01 -4.1614684e-01 9.5814437e-01 ... 1.0000000e+00\n",
" 2.1492151e-04 1.0000000e+00]\n",
" ...\n",
" [ 9.5625257e-01 -2.9254240e-01 4.8925215e-01 ... 8.3807874e-01\n",
" 5.1154459e-01 8.5925674e-01]\n",
" [ 2.7049953e-01 -9.6272010e-01 9.9170387e-01 ... 8.3801574e-01\n",
" 5.1163691e-01 8.5920173e-01]\n",
" [-6.6394955e-01 -7.4777740e-01 6.9544029e-01 ... 8.3795273e-01\n",
" 5.1172924e-01 8.5914677e-01]]]\n",
"[1, 5000, 256]\n"
]
}
],
"source": [
"import torch\n",
"import math\n",
"import numpy as np\n",
"\n",
"# Rebuild the sinusoidal positional-encoding table with torch and load it\n",
"# into the paddle model, so both frameworks use exactly the same table.\n",
"max_len = 5000\n",
"d_model = 256\n",
"\n",
"position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)\n",
"div_term = torch.exp(\n",
"    torch.arange(0, d_model, 2, dtype=torch.float32) *\n",
"    -(math.log(10000.0) / d_model))\n",
"angles = position * div_term  # computed once, shared by sin and cos\n",
"\n",
"pe = torch.zeros(max_len, d_model)\n",
"pe[:, 0::2] = torch.sin(angles)\n",
"pe[:, 1::2] = torch.cos(angles)\n",
"pe = pe.unsqueeze(0)  # [1, max_len, d_model]\n",
"torch_pe = pe.cpu().detach().numpy()\n",
"print(torch_pe)\n",
"\n",
"# Back up the original paddle table only once: re-running this cell must not\n",
"# replace the backup with the torch table that is patched in below.\n",
"# NOTE(review): if `model` is rebuilt, `del bak_pe` first to refresh it.\n",
"if 'bak_pe' not in globals():\n",
"    bak_pe = model.encoder.embed.pos_enc.pe\n",
"print(bak_pe.shape)\n",
"model.encoder.embed.pos_enc.pe = paddle.to_tensor(torch_pe)"
]
},
{
"cell_type": "code",
"execution_count": 210,
"id": "exempt-cloud",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n"
]
}
],
"source": [
"# Run the whole embed layer and compare its outputs with the torch dump.\n",
"xs = model.encoder.global_cmvn(feat)\n",
"masks = make_non_pad_mask(feat_len).unsqueeze(1)\n",
"xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n",
"\n",
"data = np.load(\".notebook/enc_embed.npz\")\n",
"torch_xs, torch_pos_emb = data['embed_out'], data['pos_emb']\n",
"for ours, ref in ((xs, torch_xs), (pos_emb, torch_pos_emb)):\n",
"    print(np.allclose(ours.numpy(), ref))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "composite-involvement",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 269,
"id": "handed-harris",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"True\n",
"True\n",
"True\n",
"True\n",
"True\n",
"False\n",
"True\n",
"[256, 2048]\n",
"[2048]\n",
"[2048, 256]\n",
"[256]\n",
"--------ff-------\n",
"True\n",
"False\n",
"False\n",
"False\n",
"False\n",
"True\n",
"linear_714.w_0 True\n",
"linear_714.b_0 True\n",
"linear_715.w_0 True\n",
"linear_715.b_0 True\n",
"False\n",
"True\n"
]
}
],
"source": [
"# Debug encoder layer 0 step by step against torch dumps: the macaron\n",
"# feed-forward first, then self-attention. Only the first layer is checked\n",
"# (the loop breaks after one pass) because the reference files are layer-0 dumps.\n",
"xs = model.encoder.global_cmvn(feat)\n",
"masks = make_non_pad_mask(feat_len).unsqueeze(1)\n",
"\n",
"xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n",
"masks = masks.astype(paddle.bool)\n",
"mask_pad = masks.logical_not()\n",
"decoding_chunk_size=0\n",
"num_decoding_left_chunks=-1\n",
"chunk_masks = add_optional_chunk_mask(\n",
"    xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n",
"    decoding_chunk_size, model.encoder.static_chunk_size,\n",
"    num_decoding_left_chunks)\n",
"\n",
"data = np.load(\".notebook/enc_embed.npz\")\n",
"torch_pos_emb=data['pos_emb']\n",
"torch_xs = data['embed_out']\n",
"torch_chunk_masks = data['chunk_masks']\n",
"torch_mask_pad = data['mask_pad']\n",
"print(np.allclose(xs.numpy(), torch_xs))\n",
"print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n",
"np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n",
"# the torch dump stores mask_pad with the opposite polarity\n",
"np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n",
"\n",
"for layer in model.encoder.encoders:\n",
"    print(layer.feed_forward_macaron is not None)\n",
"    print(layer.normalize_before)\n",
"\n",
"    # --- macaron feed-forward: pre-norm ---\n",
"    data = np.load('.notebook/enc_0_norm_ff.npz')\n",
"    t_norm_ff = data['norm_ff']\n",
"    t_xs = data['xs']\n",
"\n",
"    x = xs\n",
"    print(np.allclose(t_xs, x.numpy()))\n",
"    residual = x\n",
"    print(np.allclose(t_xs, residual.numpy()))\n",
"    x_norm = layer.norm_ff_macaron(x)\n",
"    print(np.allclose(t_norm_ff, x_norm.numpy()))\n",
"\n",
"    # NOTE(review): eval() stays in effect on this layer after the cell ends.\n",
"    layer.eval()\n",
"    # Feed the torch-normalized input so the feed-forward comparison is\n",
"    # isolated from any drift in the norm layer.\n",
"    x_norm = paddle.to_tensor(t_norm_ff)\n",
"    print(np.allclose(t_norm_ff, x_norm.numpy()))\n",
"    x = residual + layer.ff_scale * layer.feed_forward_macaron(x_norm)\n",
"\n",
"    ps = []\n",
"    for n, p in layer.feed_forward_macaron.state_dict().items():\n",
"        ps.append(p)\n",
"        print(p.shape)\n",
"\n",
"    # --- macaron feed-forward: w_1 -> activation -> w_2, one step at a time ---\n",
"    ff_l_x = layer.feed_forward_macaron.w_1(x_norm)\n",
"    ff_l_a_x = layer.feed_forward_macaron.activation(ff_l_x)\n",
"    ff_l_a_l_x = layer.feed_forward_macaron.w_2(ff_l_a_x)\n",
"    # allow_pickle: presumably needed for the object array 'ps'; only load trusted dumps\n",
"    data = np.load('.notebook/enc_0_ff_out.npz', allow_pickle=True)\n",
"    t_norm_ff = data['norm_ff']\n",
"    t_ff_out = data['ff_out']\n",
"    t_ff_l_x = data['ff_l_x']\n",
"    t_ff_l_a_x = data['ff_l_a_x']\n",
"    t_ff_l_a_l_x = data['ff_l_a_l_x']\n",
"    t_ps = data['ps']\n",
"\n",
"    print(\"--------ff-------\")\n",
"    print(np.allclose(x_norm.numpy(), t_norm_ff))\n",
"    print(np.allclose(x.numpy(), t_ff_out))\n",
"    print(np.allclose(ff_l_x.numpy(), t_ff_l_x))\n",
"    print(np.allclose(ff_l_a_x.numpy(), t_ff_l_a_x))\n",
"    print(np.allclose(ff_l_a_l_x.numpy(), t_ff_l_a_l_x))\n",
"    print(np.allclose(ff_l_x.numpy(), t_ff_l_x, atol=1e-6))\n",
"    # torch parameters are compared transposed (.T is a no-op for 1-D biases)\n",
"    for p, t_p in zip(ps, t_ps):\n",
"        print(p.name, np.allclose(p.numpy(), t_p.T))\n",
"\n",
"    # --- self-attention, fed with the torch inputs ---\n",
"    data = np.load('.notebook/enc_0_selattn_out.npz', allow_pickle=True)\n",
"    tx_q = data['x_q']\n",
"    tx = data['x']\n",
"    tpos_emb = data['pos_emb']\n",
"    tmask = data['mask']\n",
"    t_x_att = data['x_att']\n",
"    x_q = paddle.to_tensor(tx_q)\n",
"    x = paddle.to_tensor(tx)\n",
"    pos_emb = paddle.to_tensor(tpos_emb)\n",
"    mask = paddle.to_tensor(tmask)\n",
"\n",
"    x_att = layer.self_attn(x_q, x, x, pos_emb, mask)\n",
"    print(np.allclose(x_att.numpy(), t_x_att))\n",
"    print(np.allclose(x_att.numpy(), t_x_att, atol=1e-6))\n",
"\n",
"    break"
]
},
{
"cell_type": "code",
"execution_count": 270,
"id": "sonic-thumb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"False\n",
"True\n"
]
}
],
"source": [
"# Run only encoder layer 0 and compare it with the torch dump of that layer.\n",
"xs = model.encoder.global_cmvn(feat)\n",
"masks = make_non_pad_mask(feat_len).unsqueeze(1)\n",
"xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n",
"masks = masks.astype(paddle.bool)\n",
"mask_pad = masks.logical_not()\n",
"\n",
"# NOTE(review): 0 / -1 presumably select full-context (non-chunked) masks;\n",
"# confirm against add_optional_chunk_mask.\n",
"decoding_chunk_size, num_decoding_left_chunks = 0, -1\n",
"chunk_masks = add_optional_chunk_mask(\n",
"    xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n",
"    decoding_chunk_size, model.encoder.static_chunk_size,\n",
"    num_decoding_left_chunks)\n",
"\n",
"data = np.load(\".notebook/enc_embed.npz\")\n",
"torch_pos_emb, torch_xs, torch_chunk_masks, torch_mask_pad = (\n",
"    data[key] for key in ('pos_emb', 'embed_out', 'chunk_masks', 'mask_pad'))\n",
"print(np.allclose(xs.numpy(), torch_xs))\n",
"print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n",
"np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n",
"# the torch dump stores mask_pad with the opposite polarity\n",
"np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n",
"\n",
"for layer in model.encoder.encoders:\n",
"    xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n",
"    break  # layer 0 only\n",
"\n",
"data = np.load('.notebook/enc_0.npz')\n",
"torch_xs = data['enc_0']\n",
"for tol in ({}, {'atol': 1e-6}):\n",
"    print(np.allclose(xs.numpy(), torch_xs, **tol))"
]
},
{
"cell_type": "code",
"execution_count": 273,
"id": "brave-latino",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"--------layers_______\n",
"False\n",
"True\n",
"[[-0.70194244 0.56254214 0.6880346 ... 1.1237319 0.7803924\n",
" 1.1369387 ]\n",
" [-0.7787783 0.3912667 0.71887773 ... 1.251882 0.886168\n",
" 1.3173451 ]\n",
" [-0.95908964 0.6346029 0.87671334 ... 0.98183745 0.7440111\n",
" 1.2903278 ]\n",
" ...\n",
" [-1.0732255 0.67236906 0.92303115 ... 0.9075458 0.8176712\n",
" 1.3239655 ]\n",
" [-1.1654118 0.6819967 0.6939453 ... 1.2238353 0.8028295\n",
" 1.4506507 ]\n",
" [-1.2732092 0.7145806 0.75819594 ... 0.94154835 0.8774845\n",
" 1.2623049 ]]\n",
"xxxxxx\n",
"[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n",
" 1.1369387 ]\n",
" [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n",
" 1.3173454 ]\n",
" [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n",
" 1.2903274 ]\n",
" ...\n",
" [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n",
" 1.3239657 ]\n",
" [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n",
" 1.4506509 ]\n",
" [-1.273209 0.71458095 0.75819623 ... 0.9415484 0.8774842\n",
" 1.2623055 ]]\n"
]
}
],
"source": [
"# Run every encoder layer and compare the final output with the torch dump.\n",
"xs = model.encoder.global_cmvn(feat)\n",
"masks = make_non_pad_mask(feat_len).unsqueeze(1)\n",
"\n",
"xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n",
"masks = masks.astype(paddle.bool)\n",
"mask_pad = masks.logical_not()\n",
"decoding_chunk_size=0\n",
"num_decoding_left_chunks=-1\n",
"chunk_masks = add_optional_chunk_mask(\n",
"    xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n",
"    decoding_chunk_size, model.encoder.static_chunk_size,\n",
"    num_decoding_left_chunks)\n",
"\n",
"data = np.load(\".notebook/enc_embed.npz\")\n",
"torch_pos_emb=data['pos_emb']\n",
"torch_xs = data['embed_out']\n",
"torch_chunk_masks = data['chunk_masks']\n",
"torch_mask_pad = data['mask_pad']\n",
"print(np.allclose(xs.numpy(), torch_xs))\n",
"print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n",
"np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n",
"# the torch dump stores mask_pad with the opposite polarity\n",
"np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n",
"\n",
"print(\"--------layers_______\")\n",
"for layer in model.encoder.encoders:\n",
"    xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n",
"\n",
"# separate name so torch_xs keeps meaning 'embed_out' from above\n",
"data = np.load('.notebook/enc_all.npz')\n",
"torch_enc_all = data['enc_all']\n",
"print(np.allclose(xs.numpy(), torch_enc_all))\n",
"print(np.allclose(xs.numpy(), torch_enc_all, atol=1e-5))\n",
"print(xs[0].numpy())\n",
"print('xxxxxx')\n",
"print(torch_enc_all[0])"
]
},
{
"cell_type": "code",
"execution_count": 64,
"id": "municipal-stock",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 278,
"id": "macro-season",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-0.7019424 0.5625421 0.68803453 ... 1.1237317 0.7803923\n",
" 1.1369386 ]\n",
" [-0.7787783 0.39126673 0.71887773 ... 1.251882 0.886168\n",
" 1.3173451 ]\n",
" [-0.95908964 0.6346029 0.87671334 ... 0.98183745 0.7440111\n",
" 1.2903278 ]\n",
" ...\n",
" [-1.0732255 0.67236906 0.92303115 ... 0.9075458 0.8176712\n",
" 1.3239655 ]\n",
" [-1.1654117 0.68199664 0.6939452 ... 1.2238352 0.8028294\n",
" 1.4506506 ]\n",
" [-1.2732091 0.71458054 0.7581958 ... 0.9415482 0.8774844\n",
" 1.2623048 ]]\n",
"---\n",
"[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n",
" 1.1369387 ]\n",
" [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n",
" 1.3173454 ]\n",
" [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n",
" 1.2903274 ]\n",
" ...\n",
" [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n",
" 1.3239657 ]\n",
" [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n",
" 1.4506509 ]\n",
" [-1.2732087 0.71458083 0.7581961 ... 0.9415482 0.877484\n",
" 1.2623053 ]]\n",
"False\n",
"True\n",
"False\n"
]
}
],
"source": [
"# Compare the full paddle encoder forward with the torch reference output.\n",
"# NOTE(review): `torch_encoder_out` is not defined in this cell; it must be\n",
"# created by an earlier cell -- confirm before running on a fresh kernel.\n",
"encoder_out, mask = model.encoder(feat, feat_len)\n",
"encoder_out_np = encoder_out.numpy()\n",
"print(encoder_out_np[0])\n",
"print(\"---\")\n",
"print(torch_encoder_out[0])\n",
"for tol in ({}, {'atol': 1e-5}, {'atol': 1e-6}):\n",
"    print(np.allclose(torch_encoder_out, encoder_out_np, **tol))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "associate-sampling",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}