You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/.notebook/wenet_model.ipynb

5016 lines
299 KiB

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "cfb832c0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/wenet\n"
]
},
{
"data": {
"text/plain": [
"'/workspace/wenet'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd /workspace/wenet/\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "62277538",
"metadata": {},
"outputs": [],
"source": [
"\n",
"import argparse\n",
"import copy\n",
"import logging\n",
"import os\n",
"\n",
"import torch\n",
"import torch.distributed as dist\n",
"import torch.optim as optim\n",
"import yaml\n",
"from tensorboardX import SummaryWriter\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from wenet.dataset.dataset import AudioDataset, CollateFunc\n",
"from wenet.transformer.asr_model import init_asr_model\n",
"from wenet.utils.checkpoint import load_checkpoint, save_checkpoint\n",
"from wenet.utils.executor import Executor\n",
"from wenet.utils.scheduler import WarmupLR\n",
"\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = \"0\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2f6ea33a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'config': 'examples/aishell/s0/conf/train_conformer.yaml', 'train_data': 'examples/aishell/s0/raw_wav/train/format.data', 'cv_data': 'examples/aishell/s0/raw_wav/dev/format.data', 'gpu': -1, 'model_dir': None, 'checkpoint': None, 'tensorboard_dir': 'tensorboard', 'rank': 0, 'world_size': -1, 'dist_backend': 'nccl', 'init_method': None, 'num_workers': 0, 'pin_memory': False, 'cmvn': 'examples/aishell/s0/raw_wav/train/global_cmvn'}\n"
]
}
],
"source": [
"parser = argparse.ArgumentParser(description='training your network')\n",
"parser.add_argument('--config', default=\"examples/aishell/s0/conf/train_conformer.yaml\", help='config file')\n",
"parser.add_argument('--train_data', default=\"examples/aishell/s0/raw_wav/train/format.data\", help='train data file')\n",
"parser.add_argument('--cv_data', default=\"examples/aishell/s0/raw_wav/dev/format.data\", help='cv data file')\n",
"parser.add_argument('--gpu',\n",
" type=int,\n",
" default=-1,\n",
" help='gpu id for this local rank, -1 for cpu')\n",
"parser.add_argument('--model_dir' , help='save model dir')\n",
"parser.add_argument('--checkpoint', help='checkpoint model')\n",
"parser.add_argument('--tensorboard_dir',\n",
" default='tensorboard',\n",
" help='tensorboard log dir')\n",
"parser.add_argument('--ddp.rank',\n",
" dest='rank',\n",
" default=0,\n",
" type=int,\n",
" help='global rank for distributed training')\n",
"parser.add_argument('--ddp.world_size',\n",
" dest='world_size',\n",
" default=-1,\n",
" type=int,\n",
" help='''number of total processes/gpus for\n",
" distributed training''')\n",
"parser.add_argument('--ddp.dist_backend',\n",
" dest='dist_backend',\n",
" default='nccl',\n",
" choices=['nccl', 'gloo'],\n",
" help='distributed backend')\n",
"parser.add_argument('--ddp.init_method',\n",
" dest='init_method',\n",
" default=None,\n",
" help='ddp init method')\n",
"parser.add_argument('--num_workers',\n",
" default=0,\n",
" type=int,\n",
" help='num of subprocess workers for reading')\n",
"parser.add_argument('--pin_memory',\n",
" action='store_true',\n",
" default=False,\n",
" help='Use pinned memory buffers used for reading')\n",
"parser.add_argument('--cmvn', default=\"examples/aishell/s0/raw_wav/train/global_cmvn\", help='global cmvn file')\n",
"\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f5d6af9b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Namespace(checkpoint=None, cmvn='examples/aishell/s0/raw_wav/train/global_cmvn', config='examples/aishell/s0/conf/train_conformer.yaml', cv_data='examples/aishell/s0/raw_wav/dev/format.data', dist_backend='nccl', gpu=-1, init_method=None, model_dir=None, num_workers=0, pin_memory=False, rank=0, tensorboard_dir='tensorboard', train_data='examples/aishell/s0/raw_wav/train/format.data', world_size=-1)\n"
]
}
],
"source": [
"# Set random seed\n",
"torch.manual_seed(777)\n",
"print(args)\n",
"with open(args.config, 'r') as fin:\n",
" configs = yaml.load(fin, Loader=yaml.FullLoader)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "264bd353",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7507 batches\n",
"896\n"
]
}
],
"source": [
"raw_wav = configs['raw_wav']\n",
"\n",
"train_collate_func = CollateFunc(**configs['collate_conf'],\n",
" raw_wav=raw_wav)\n",
"\n",
"cv_collate_conf = copy.deepcopy(configs['collate_conf'])\n",
"# no augmenation on cv set\n",
"cv_collate_conf['spec_aug'] = False\n",
"cv_collate_conf['spec_sub'] = False\n",
"if raw_wav:\n",
" cv_collate_conf['feature_dither'] = 0.0\n",
" cv_collate_conf['speed_perturb'] = False\n",
" cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0\n",
"cv_collate_func = CollateFunc(**cv_collate_conf, raw_wav=raw_wav)\n",
"\n",
"dataset_conf = configs.get('dataset_conf', {})\n",
"train_dataset = AudioDataset(args.train_data,\n",
" **dataset_conf,\n",
" raw_wav=raw_wav)\n",
"cv_dataset = AudioDataset(args.cv_data, **dataset_conf, raw_wav=raw_wav)\n",
"# 120098 data/train/wav.scp\n",
"print(len(train_dataset), 'batches')\n",
"# 14326 data/dev/wav.scp\n",
"print(len(cv_dataset))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "88863d3c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"896\n"
]
}
],
"source": [
"train_sampler = None\n",
"cv_sampler = None\n",
"train_data_loader = DataLoader(train_dataset,\n",
" collate_fn=train_collate_func,\n",
" sampler=train_sampler,\n",
" #shuffle=(train_sampler is None),\n",
" shuffle=False,\n",
" pin_memory=args.pin_memory,\n",
" batch_size=1,\n",
" num_workers=args.num_workers)\n",
"cv_data_loader = DataLoader(cv_dataset,\n",
" collate_fn=cv_collate_func,\n",
" sampler=cv_sampler,\n",
" shuffle=False,\n",
" batch_size=1,\n",
" pin_memory=args.pin_memory,\n",
" num_workers=args.num_workers)\n",
"print(len(cv_data_loader))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "10d5acd4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4233 vocab\n",
"80 feat dim\n"
]
}
],
"source": [
"if raw_wav:\n",
" input_dim = configs['collate_conf']['feature_extraction_conf'][\n",
" 'mel_bins']\n",
"else:\n",
" input_dim = train_dataset.input_dim\n",
"vocab_size = train_dataset.output_dim\n",
"print(vocab_size, 'vocab')\n",
"print(input_dim , 'feat dim')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "0380ef5a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"examples/aishell/s0/raw_wav/train/global_cmvn\n"
]
}
],
"source": [
"# Save configs to model_dir/train.yaml for inference and export\n",
"configs['input_dim'] = input_dim\n",
"configs['output_dim'] = vocab_size\n",
"configs['cmvn_file'] = args.cmvn\n",
"configs['is_json_cmvn'] = raw_wav\n",
"print(args.cmvn)\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "15ebf2bf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(80,)\n",
"(80,)\n",
"[ 9.87176362 9.93891555 10.23818678 10.85971412 11.68652649 12.2548801\n",
" 12.65768161 12.86138996 12.80733912 12.56625574 12.32007066 12.13879205\n",
" 12.31318868 12.55255216 12.61223855 12.56974526 12.38972728 12.14383338\n",
" 12.09285066 11.79395822 11.62259065 11.9263303 11.8154422 11.95122567\n",
" 11.83180553 11.88788759 11.79014437 11.88072035 11.90005711 11.97348142\n",
" 12.00982189 12.00881339 12.02619706 12.10479646 12.21555081 12.34399304\n",
" 12.45014401 12.4966879 12.48653775 12.3550783 12.39291732 12.2553737\n",
" 12.26496277 12.25314244 12.32545763 12.43359839 12.54867439 12.6763342\n",
" 12.80920698 12.92934681 12.96115138 12.96883353 12.99593057 13.04728142\n",
" 13.0588804 13.05737948 12.99921175 12.93402238 12.87429219 12.71652995\n",
" 12.48942004 12.27478385 12.26163069 12.28631891 12.31956049 12.4229073\n",
" 12.51480191 12.5785164 12.64719411 12.73762568 12.80017069 12.86872766\n",
" 12.96666856 13.06478583 13.15915908 13.27284306 13.31081821 13.23904279\n",
" 12.87936075 11.18310185]\n",
"[0.61219383 0.49700994 0.33439025 0.31503119 0.29640823 0.28411759\n",
" 0.26972922 0.25610475 0.24632936 0.24610228 0.24733299 0.24426536\n",
" 0.23751781 0.22987273 0.22659963 0.2268427 0.23059031 0.23420722\n",
" 0.23771761 0.2411352 0.24404673 0.24557175 0.24724932 0.25055198\n",
" 0.25482755 0.2602407 0.26363878 0.26503898 0.2648467 0.26435072\n",
" 0.26353625 0.26364794 0.26411054 0.26339948 0.26212082 0.26146597\n",
" 0.26196556 0.26365859 0.26592959 0.26963884 0.27392766 0.27818809\n",
" 0.28313664 0.2863325 0.28713431 0.28649323 0.28636648 0.2867843\n",
" 0.28635904 0.28562022 0.28492711 0.28429201 0.28402977 0.28401045\n",
" 0.28560797 0.28728033 0.28969549 0.29351627 0.29826453 0.30572631\n",
" 0.31811682 0.32887739 0.33288219 0.33326245 0.33014147 0.32403202\n",
" 0.31903576 0.31316258 0.30741037 0.30370692 0.30204833 0.30049064\n",
" 0.29901079 0.29824511 0.29812308 0.29753329 0.29779342 0.30175296\n",
" 0.30955538 0.32904205]\n"
]
}
],
"source": [
"import json\n",
"import math\n",
"import numpy as np\n",
"def _load_json_cmvn(json_cmvn_file):\n",
" \"\"\" Load the json format cmvn stats file and calculate cmvn\n",
"\n",
" Args:\n",
" json_cmvn_file: cmvn stats file in json format\n",
"\n",
" Returns:\n",
" a numpy array of [means, vars]\n",
" \"\"\"\n",
" with open(json_cmvn_file) as f:\n",
" cmvn_stats = json.load(f)\n",
"\n",
" means = cmvn_stats['mean_stat']\n",
" variance = cmvn_stats['var_stat']\n",
" count = cmvn_stats['frame_num']\n",
" for i in range(len(means)):\n",
" means[i] /= count\n",
" variance[i] = variance[i] / count - means[i] * means[i]\n",
" if variance[i] < 1.0e-20:\n",
" variance[i] = 1.0e-20\n",
" variance[i] = 1.0 / math.sqrt(variance[i])\n",
" cmvn = np.array([means, variance])\n",
" return cmvn\n",
"\n",
"\n",
"def _load_kaldi_cmvn(kaldi_cmvn_file):\n",
" \"\"\" Load the kaldi format cmvn stats file and calculate cmvn\n",
"\n",
" Args:\n",
" kaldi_cmvn_file: kaldi text style global cmvn file, which\n",
" is generated by:\n",
" compute-cmvn-stats --binary=false scp:feats.scp global_cmvn\n",
"\n",
" Returns:\n",
" a numpy array of [means, vars]\n",
" \"\"\"\n",
" means = []\n",
" variance = []\n",
" with open(kaldi_cmvn_file, 'r') as fid:\n",
" # kaldi binary file start with '\\0B'\n",
" if fid.read(2) == '\\0B':\n",
" logger.error('kaldi cmvn binary file is not supported, please '\n",
" 'recompute it by: compute-cmvn-stats --binary=false '\n",
" ' scp:feats.scp global_cmvn')\n",
" sys.exit(1)\n",
" fid.seek(0)\n",
" arr = fid.read().split()\n",
" assert (arr[0] == '[')\n",
" assert (arr[-2] == '0')\n",
" assert (arr[-1] == ']')\n",
" feat_dim = int((len(arr) - 2 - 2) / 2)\n",
" for i in range(1, feat_dim + 1):\n",
" means.append(float(arr[i]))\n",
" count = float(arr[feat_dim + 1])\n",
" for i in range(feat_dim + 2, 2 * feat_dim + 2):\n",
" variance.append(float(arr[i]))\n",
"\n",
" for i in range(len(means)):\n",
" means[i] /= count\n",
" variance[i] = variance[i] / count - means[i] * means[i]\n",
" if variance[i] < 1.0e-20:\n",
" variance[i] = 1.0e-20\n",
" variance[i] = 1.0 / math.sqrt(variance[i])\n",
" cmvn = np.array([means, variance])\n",
" return cmvn\n",
"\n",
"\n",
"def _load_npz_cmvn(npz_cmvn_file, eps=1e-20):\n",
" npzfile = np.load(npz_cmvn_file)\n",
" means = npzfile[\"mean\"] #(1, D)\n",
" std = npzfile[\"std\"] #(1, D)\n",
" std = np.clip(std, eps, None)\n",
" variance = 1.0 / std\n",
" cmvn = np.array([means, variance])\n",
" return cmvn\n",
"\n",
"\n",
"def load_cmvn(cmvn_file: str, filetype: str):\n",
" \"\"\"load cmvn from file.\n",
"\n",
" Args:\n",
" cmvn_file (str): cmvn path.\n",
" filetype (str): file type, optional[npz, json, kaldi].\n",
"\n",
" Raises:\n",
" ValueError: file type not support.\n",
"\n",
" Returns:\n",
" Tuple[np.ndarray, np.ndarray]: mean, istd\n",
" \"\"\"\n",
" assert filetype in ['npz', 'json', 'kaldi'], filetype\n",
" filetype = filetype.lower()\n",
" if filetype == \"json\":\n",
" cmvn = _load_json_cmvn(cmvn_file)\n",
" elif filetype == \"kaldi\":\n",
" cmvn = _load_kaldi_cmvn(cmvn_file)\n",
" elif filetype == \"npz\":\n",
" cmvn = _load_npz_cmvn(cmvn_file)\n",
" else:\n",
" raise ValueError(f\"cmvn file type no support: {filetype}\")\n",
" return cmvn[0], cmvn[1]\n",
"\n",
"mean, istd = load_cmvn(args.cmvn, 'json')\n",
"print(mean.shape)\n",
"print(istd.shape)\n",
"print(mean)\n",
"print(istd)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3cfa5e23",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ASRModel(\n",
" (encoder): ConformerEncoder(\n",
" (global_cmvn): GlobalCMVN()\n",
" (embed): Conv2dSubsampling4(\n",
" (conv): Sequential(\n",
" (0): Conv2d(1, 256, kernel_size=(3, 3), stride=(2, 2))\n",
" (1): ReLU()\n",
" (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n",
" (3): ReLU()\n",
" )\n",
" (out): Sequential(\n",
" (0): Linear(in_features=4864, out_features=256, bias=True)\n",
" )\n",
" (pos_enc): RelPositionalEncoding(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (after_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (encoders): ModuleList(\n",
" (0): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
" (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
" (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (1): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
" (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
" (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (2): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
" (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
" (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (3): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
" (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
" (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (4): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
" (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
" (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (5): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
" (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
" (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (6): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
" (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
" (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (7): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
" (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
" (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (8): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
" (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
" (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (9): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
" (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
" (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (10): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
" (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
" (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (11): ConformerEncoderLayer(\n",
" (self_attn): RelPositionMultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (feed_forward_macaron): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): Swish()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (conv_module): ConvolutionModule(\n",
" (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
" (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n",
" (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
" (activation): Swish()\n",
" )\n",
" (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (decoder): TransformerDecoder(\n",
" (embed): Sequential(\n",
" (0): Embedding(4233, 256)\n",
" (1): PositionalEncoding(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (after_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (output_layer): Linear(in_features=256, out_features=4233, bias=True)\n",
" (decoders): ModuleList(\n",
" (0): DecoderLayer(\n",
" (self_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (src_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): ReLU()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n",
" (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (1): DecoderLayer(\n",
" (self_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (src_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): ReLU()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n",
" (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (2): DecoderLayer(\n",
" (self_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (src_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): ReLU()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n",
" (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (3): DecoderLayer(\n",
" (self_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (src_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): ReLU()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n",
" (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (4): DecoderLayer(\n",
" (self_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (src_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): ReLU()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n",
" (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" (5): DecoderLayer(\n",
" (self_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (src_attn): MultiHeadedAttention(\n",
" (linear_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_k): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_v): Linear(in_features=256, out_features=256, bias=True)\n",
" (linear_out): Linear(in_features=256, out_features=256, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (feed_forward): PositionwiseFeedForward(\n",
" (w_1): Linear(in_features=256, out_features=2048, bias=True)\n",
" (activation): ReLU()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n",
" )\n",
" (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n",
" (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (ctc): CTC(\n",
" (ctc_lo): Linear(in_features=256, out_features=4233, bias=True)\n",
" (ctc_loss): CTCLoss()\n",
" )\n",
" (criterion_att): LabelSmoothingLoss(\n",
" (criterion): KLDivLoss()\n",
" )\n",
")\n"
]
}
],
"source": [
"# Init asr model from configs\n",
"model = init_asr_model(configs)\n",
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "3c780af5",
"metadata": {},
"outputs": [],
"source": [
"\n",
"def summary(layer, print_func=print):\n",
" num_params = num_elements = 0\n",
" for name, param in layer.state_dict().items():\n",
" if print_func:\n",
" print_func(\n",
" \"{} | {} | {}\".format(name, param.shape, np.prod(param.shape)))\n",
" num_elements += np.prod(param.shape)\n",
" num_params += 1\n",
" if print_func:\n",
" print_func(\n",
" f\"Total parameters: {num_params}, {num_elements} elements.\"\n",
" )\n",
" \n",
"def print_params(model, print_func=print):\n",
" if print_func is None:\n",
" return\n",
" total = 0.0\n",
" num_params = 0.0\n",
" for n, p in model.named_parameters():\n",
" msg = f\"{n} | {p.shape} | {np.prod(p.shape)} | {p.requires_grad}\"\n",
" total += np.prod(p.shape)\n",
" num_params += 1\n",
" if print_func:\n",
" print_func(msg)\n",
" if print_func:\n",
" print_func(f\"Total parameters: {num_params}, {total} elements.\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e159a200",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"encoder.global_cmvn.mean | torch.Size([80]) | 80\n",
"encoder.global_cmvn.istd | torch.Size([80]) | 80\n",
"encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304\n",
"encoder.embed.conv.0.bias | torch.Size([256]) | 256\n",
"encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824\n",
"encoder.embed.conv.2.bias | torch.Size([256]) | 256\n",
"encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184\n",
"encoder.embed.out.0.bias | torch.Size([256]) | 256\n",
"encoder.after_norm.weight | torch.Size([256]) | 256\n",
"encoder.after_norm.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
"encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
"encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.0.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.0.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.0.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
"encoder.encoders.0.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
"encoder.encoders.0.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
"encoder.encoders.0.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.conv_module.norm.weight | torch.Size([256]) | 256\n",
"encoder.encoders.0.conv_module.norm.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
"encoder.encoders.0.conv_module.norm.running_var | torch.Size([256]) | 256\n",
"encoder.encoders.0.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
"encoder.encoders.0.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
"encoder.encoders.0.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.norm_ff.weight | torch.Size([256]) | 256\n",
"encoder.encoders.0.norm_ff.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.norm_mha.weight | torch.Size([256]) | 256\n",
"encoder.encoders.0.norm_mha.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
"encoder.encoders.0.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.norm_conv.weight | torch.Size([256]) | 256\n",
"encoder.encoders.0.norm_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.norm_final.weight | torch.Size([256]) | 256\n",
"encoder.encoders.0.norm_final.bias | torch.Size([256]) | 256\n",
"encoder.encoders.0.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
"encoder.encoders.0.concat_linear.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
"encoder.encoders.1.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
"encoder.encoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.1.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.1.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.1.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
"encoder.encoders.1.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
"encoder.encoders.1.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
"encoder.encoders.1.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.conv_module.norm.weight | torch.Size([256]) | 256\n",
"encoder.encoders.1.conv_module.norm.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
"encoder.encoders.1.conv_module.norm.running_var | torch.Size([256]) | 256\n",
"encoder.encoders.1.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
"encoder.encoders.1.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
"encoder.encoders.1.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.norm_ff.weight | torch.Size([256]) | 256\n",
"encoder.encoders.1.norm_ff.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.norm_mha.weight | torch.Size([256]) | 256\n",
"encoder.encoders.1.norm_mha.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
"encoder.encoders.1.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.norm_conv.weight | torch.Size([256]) | 256\n",
"encoder.encoders.1.norm_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.norm_final.weight | torch.Size([256]) | 256\n",
"encoder.encoders.1.norm_final.bias | torch.Size([256]) | 256\n",
"encoder.encoders.1.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
"encoder.encoders.1.concat_linear.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
"encoder.encoders.2.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
"encoder.encoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.2.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.2.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.2.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
"encoder.encoders.2.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
"encoder.encoders.2.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
"encoder.encoders.2.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.conv_module.norm.weight | torch.Size([256]) | 256\n",
"encoder.encoders.2.conv_module.norm.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
"encoder.encoders.2.conv_module.norm.running_var | torch.Size([256]) | 256\n",
"encoder.encoders.2.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
"encoder.encoders.2.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
"encoder.encoders.2.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.norm_ff.weight | torch.Size([256]) | 256\n",
"encoder.encoders.2.norm_ff.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.norm_mha.weight | torch.Size([256]) | 256\n",
"encoder.encoders.2.norm_mha.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
"encoder.encoders.2.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.norm_conv.weight | torch.Size([256]) | 256\n",
"encoder.encoders.2.norm_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.norm_final.weight | torch.Size([256]) | 256\n",
"encoder.encoders.2.norm_final.bias | torch.Size([256]) | 256\n",
"encoder.encoders.2.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
"encoder.encoders.2.concat_linear.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
"encoder.encoders.3.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
"encoder.encoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.3.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.3.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.3.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
"encoder.encoders.3.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
"encoder.encoders.3.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
"encoder.encoders.3.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.conv_module.norm.weight | torch.Size([256]) | 256\n",
"encoder.encoders.3.conv_module.norm.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
"encoder.encoders.3.conv_module.norm.running_var | torch.Size([256]) | 256\n",
"encoder.encoders.3.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
"encoder.encoders.3.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
"encoder.encoders.3.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.norm_ff.weight | torch.Size([256]) | 256\n",
"encoder.encoders.3.norm_ff.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.norm_mha.weight | torch.Size([256]) | 256\n",
"encoder.encoders.3.norm_mha.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
"encoder.encoders.3.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.norm_conv.weight | torch.Size([256]) | 256\n",
"encoder.encoders.3.norm_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.norm_final.weight | torch.Size([256]) | 256\n",
"encoder.encoders.3.norm_final.bias | torch.Size([256]) | 256\n",
"encoder.encoders.3.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
"encoder.encoders.3.concat_linear.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
"encoder.encoders.4.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
"encoder.encoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.4.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.4.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.4.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
"encoder.encoders.4.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
"encoder.encoders.4.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
"encoder.encoders.4.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.conv_module.norm.weight | torch.Size([256]) | 256\n",
"encoder.encoders.4.conv_module.norm.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
"encoder.encoders.4.conv_module.norm.running_var | torch.Size([256]) | 256\n",
"encoder.encoders.4.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
"encoder.encoders.4.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
"encoder.encoders.4.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.norm_ff.weight | torch.Size([256]) | 256\n",
"encoder.encoders.4.norm_ff.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.norm_mha.weight | torch.Size([256]) | 256\n",
"encoder.encoders.4.norm_mha.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
"encoder.encoders.4.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.norm_conv.weight | torch.Size([256]) | 256\n",
"encoder.encoders.4.norm_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.norm_final.weight | torch.Size([256]) | 256\n",
"encoder.encoders.4.norm_final.bias | torch.Size([256]) | 256\n",
"encoder.encoders.4.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
"encoder.encoders.4.concat_linear.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
"encoder.encoders.5.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
"encoder.encoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.5.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.5.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.5.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
"encoder.encoders.5.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
"encoder.encoders.5.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
"encoder.encoders.5.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.conv_module.norm.weight | torch.Size([256]) | 256\n",
"encoder.encoders.5.conv_module.norm.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
"encoder.encoders.5.conv_module.norm.running_var | torch.Size([256]) | 256\n",
"encoder.encoders.5.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
"encoder.encoders.5.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
"encoder.encoders.5.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.norm_ff.weight | torch.Size([256]) | 256\n",
"encoder.encoders.5.norm_ff.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.norm_mha.weight | torch.Size([256]) | 256\n",
"encoder.encoders.5.norm_mha.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
"encoder.encoders.5.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.norm_conv.weight | torch.Size([256]) | 256\n",
"encoder.encoders.5.norm_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.norm_final.weight | torch.Size([256]) | 256\n",
"encoder.encoders.5.norm_final.bias | torch.Size([256]) | 256\n",
"encoder.encoders.5.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
"encoder.encoders.5.concat_linear.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
"encoder.encoders.6.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
"encoder.encoders.6.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.6.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.6.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.6.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.6.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.6.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.6.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.6.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.6.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.6.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.6.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.6.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
"encoder.encoders.6.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
"encoder.encoders.6.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
"encoder.encoders.6.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.conv_module.norm.weight | torch.Size([256]) | 256\n",
"encoder.encoders.6.conv_module.norm.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
"encoder.encoders.6.conv_module.norm.running_var | torch.Size([256]) | 256\n",
"encoder.encoders.6.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
"encoder.encoders.6.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
"encoder.encoders.6.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.norm_ff.weight | torch.Size([256]) | 256\n",
"encoder.encoders.6.norm_ff.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.norm_mha.weight | torch.Size([256]) | 256\n",
"encoder.encoders.6.norm_mha.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
"encoder.encoders.6.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.norm_conv.weight | torch.Size([256]) | 256\n",
"encoder.encoders.6.norm_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.norm_final.weight | torch.Size([256]) | 256\n",
"encoder.encoders.6.norm_final.bias | torch.Size([256]) | 256\n",
"encoder.encoders.6.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
"encoder.encoders.6.concat_linear.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
"encoder.encoders.7.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
"encoder.encoders.7.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.7.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.7.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.7.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.7.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.7.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.7.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.7.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.7.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.7.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.7.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.7.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
"encoder.encoders.7.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
"encoder.encoders.7.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
"encoder.encoders.7.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.conv_module.norm.weight | torch.Size([256]) | 256\n",
"encoder.encoders.7.conv_module.norm.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
"encoder.encoders.7.conv_module.norm.running_var | torch.Size([256]) | 256\n",
"encoder.encoders.7.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
"encoder.encoders.7.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
"encoder.encoders.7.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.norm_ff.weight | torch.Size([256]) | 256\n",
"encoder.encoders.7.norm_ff.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.norm_mha.weight | torch.Size([256]) | 256\n",
"encoder.encoders.7.norm_mha.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
"encoder.encoders.7.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.norm_conv.weight | torch.Size([256]) | 256\n",
"encoder.encoders.7.norm_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.norm_final.weight | torch.Size([256]) | 256\n",
"encoder.encoders.7.norm_final.bias | torch.Size([256]) | 256\n",
"encoder.encoders.7.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
"encoder.encoders.7.concat_linear.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
"encoder.encoders.8.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
"encoder.encoders.8.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.8.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.8.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.8.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.8.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.8.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.8.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.8.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.8.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.8.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.8.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.8.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
"encoder.encoders.8.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
"encoder.encoders.8.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
"encoder.encoders.8.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.conv_module.norm.weight | torch.Size([256]) | 256\n",
"encoder.encoders.8.conv_module.norm.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
"encoder.encoders.8.conv_module.norm.running_var | torch.Size([256]) | 256\n",
"encoder.encoders.8.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
"encoder.encoders.8.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
"encoder.encoders.8.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.norm_ff.weight | torch.Size([256]) | 256\n",
"encoder.encoders.8.norm_ff.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.norm_mha.weight | torch.Size([256]) | 256\n",
"encoder.encoders.8.norm_mha.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
"encoder.encoders.8.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.norm_conv.weight | torch.Size([256]) | 256\n",
"encoder.encoders.8.norm_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.norm_final.weight | torch.Size([256]) | 256\n",
"encoder.encoders.8.norm_final.bias | torch.Size([256]) | 256\n",
"encoder.encoders.8.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
"encoder.encoders.8.concat_linear.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
"encoder.encoders.9.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
"encoder.encoders.9.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.9.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.9.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.9.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.9.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.9.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.9.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.9.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.9.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.9.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.9.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.9.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
"encoder.encoders.9.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
"encoder.encoders.9.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
"encoder.encoders.9.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.conv_module.norm.weight | torch.Size([256]) | 256\n",
"encoder.encoders.9.conv_module.norm.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
"encoder.encoders.9.conv_module.norm.running_var | torch.Size([256]) | 256\n",
"encoder.encoders.9.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
"encoder.encoders.9.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
"encoder.encoders.9.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.norm_ff.weight | torch.Size([256]) | 256\n",
"encoder.encoders.9.norm_ff.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.norm_mha.weight | torch.Size([256]) | 256\n",
"encoder.encoders.9.norm_mha.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
"encoder.encoders.9.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.norm_conv.weight | torch.Size([256]) | 256\n",
"encoder.encoders.9.norm_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.norm_final.weight | torch.Size([256]) | 256\n",
"encoder.encoders.9.norm_final.bias | torch.Size([256]) | 256\n",
"encoder.encoders.9.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
"encoder.encoders.9.concat_linear.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
"encoder.encoders.10.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
"encoder.encoders.10.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.10.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.10.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.10.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.10.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.10.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.10.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.10.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.10.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.10.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.10.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.10.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
"encoder.encoders.10.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
"encoder.encoders.10.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
"encoder.encoders.10.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.conv_module.norm.weight | torch.Size([256]) | 256\n",
"encoder.encoders.10.conv_module.norm.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
"encoder.encoders.10.conv_module.norm.running_var | torch.Size([256]) | 256\n",
"encoder.encoders.10.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
"encoder.encoders.10.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
"encoder.encoders.10.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.norm_ff.weight | torch.Size([256]) | 256\n",
"encoder.encoders.10.norm_ff.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.norm_mha.weight | torch.Size([256]) | 256\n",
"encoder.encoders.10.norm_mha.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
"encoder.encoders.10.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.norm_conv.weight | torch.Size([256]) | 256\n",
"encoder.encoders.10.norm_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.norm_final.weight | torch.Size([256]) | 256\n",
"encoder.encoders.10.norm_final.bias | torch.Size([256]) | 256\n",
"encoder.encoders.10.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
"encoder.encoders.10.concat_linear.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n",
"encoder.encoders.11.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n",
"encoder.encoders.11.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.11.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.11.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.11.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.11.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n",
"encoder.encoders.11.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.11.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.11.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.11.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"encoder.encoders.11.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n",
"encoder.encoders.11.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"encoder.encoders.11.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n",
"encoder.encoders.11.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n",
"encoder.encoders.11.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n",
"encoder.encoders.11.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.conv_module.norm.weight | torch.Size([256]) | 256\n",
"encoder.encoders.11.conv_module.norm.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.conv_module.norm.running_mean | torch.Size([256]) | 256\n",
"encoder.encoders.11.conv_module.norm.running_var | torch.Size([256]) | 256\n",
"encoder.encoders.11.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n",
"encoder.encoders.11.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n",
"encoder.encoders.11.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.norm_ff.weight | torch.Size([256]) | 256\n",
"encoder.encoders.11.norm_ff.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.norm_mha.weight | torch.Size([256]) | 256\n",
"encoder.encoders.11.norm_mha.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.norm_ff_macaron.weight | torch.Size([256]) | 256\n",
"encoder.encoders.11.norm_ff_macaron.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.norm_conv.weight | torch.Size([256]) | 256\n",
"encoder.encoders.11.norm_conv.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.norm_final.weight | torch.Size([256]) | 256\n",
"encoder.encoders.11.norm_final.bias | torch.Size([256]) | 256\n",
"encoder.encoders.11.concat_linear.weight | torch.Size([256, 512]) | 131072\n",
"encoder.encoders.11.concat_linear.bias | torch.Size([256]) | 256\n",
"decoder.embed.0.weight | torch.Size([4233, 256]) | 1083648\n",
"decoder.after_norm.weight | torch.Size([256]) | 256\n",
"decoder.after_norm.bias | torch.Size([256]) | 256\n",
"decoder.output_layer.weight | torch.Size([4233, 256]) | 1083648\n",
"decoder.output_layer.bias | torch.Size([4233]) | 4233\n",
"decoder.decoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"decoder.decoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"decoder.decoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"decoder.decoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"decoder.decoders.0.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.0.src_attn.linear_q.bias | torch.Size([256]) | 256\n",
"decoder.decoders.0.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.0.src_attn.linear_k.bias | torch.Size([256]) | 256\n",
"decoder.decoders.0.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.0.src_attn.linear_v.bias | torch.Size([256]) | 256\n",
"decoder.decoders.0.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.0.src_attn.linear_out.bias | torch.Size([256]) | 256\n",
"decoder.decoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"decoder.decoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"decoder.decoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"decoder.decoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.0.norm1.weight | torch.Size([256]) | 256\n",
"decoder.decoders.0.norm1.bias | torch.Size([256]) | 256\n",
"decoder.decoders.0.norm2.weight | torch.Size([256]) | 256\n",
"decoder.decoders.0.norm2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.0.norm3.weight | torch.Size([256]) | 256\n",
"decoder.decoders.0.norm3.bias | torch.Size([256]) | 256\n",
"decoder.decoders.0.concat_linear1.weight | torch.Size([256, 512]) | 131072\n",
"decoder.decoders.0.concat_linear1.bias | torch.Size([256]) | 256\n",
"decoder.decoders.0.concat_linear2.weight | torch.Size([256, 512]) | 131072\n",
"decoder.decoders.0.concat_linear2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"decoder.decoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"decoder.decoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"decoder.decoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"decoder.decoders.1.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.1.src_attn.linear_q.bias | torch.Size([256]) | 256\n",
"decoder.decoders.1.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.1.src_attn.linear_k.bias | torch.Size([256]) | 256\n",
"decoder.decoders.1.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.1.src_attn.linear_v.bias | torch.Size([256]) | 256\n",
"decoder.decoders.1.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.1.src_attn.linear_out.bias | torch.Size([256]) | 256\n",
"decoder.decoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"decoder.decoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"decoder.decoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"decoder.decoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.1.norm1.weight | torch.Size([256]) | 256\n",
"decoder.decoders.1.norm1.bias | torch.Size([256]) | 256\n",
"decoder.decoders.1.norm2.weight | torch.Size([256]) | 256\n",
"decoder.decoders.1.norm2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.1.norm3.weight | torch.Size([256]) | 256\n",
"decoder.decoders.1.norm3.bias | torch.Size([256]) | 256\n",
"decoder.decoders.1.concat_linear1.weight | torch.Size([256, 512]) | 131072\n",
"decoder.decoders.1.concat_linear1.bias | torch.Size([256]) | 256\n",
"decoder.decoders.1.concat_linear2.weight | torch.Size([256, 512]) | 131072\n",
"decoder.decoders.1.concat_linear2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"decoder.decoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"decoder.decoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"decoder.decoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"decoder.decoders.2.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.2.src_attn.linear_q.bias | torch.Size([256]) | 256\n",
"decoder.decoders.2.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.2.src_attn.linear_k.bias | torch.Size([256]) | 256\n",
"decoder.decoders.2.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.2.src_attn.linear_v.bias | torch.Size([256]) | 256\n",
"decoder.decoders.2.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.2.src_attn.linear_out.bias | torch.Size([256]) | 256\n",
"decoder.decoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"decoder.decoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"decoder.decoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"decoder.decoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.2.norm1.weight | torch.Size([256]) | 256\n",
"decoder.decoders.2.norm1.bias | torch.Size([256]) | 256\n",
"decoder.decoders.2.norm2.weight | torch.Size([256]) | 256\n",
"decoder.decoders.2.norm2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.2.norm3.weight | torch.Size([256]) | 256\n",
"decoder.decoders.2.norm3.bias | torch.Size([256]) | 256\n",
"decoder.decoders.2.concat_linear1.weight | torch.Size([256, 512]) | 131072\n",
"decoder.decoders.2.concat_linear1.bias | torch.Size([256]) | 256\n",
"decoder.decoders.2.concat_linear2.weight | torch.Size([256, 512]) | 131072\n",
"decoder.decoders.2.concat_linear2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"decoder.decoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"decoder.decoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"decoder.decoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"decoder.decoders.3.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.3.src_attn.linear_q.bias | torch.Size([256]) | 256\n",
"decoder.decoders.3.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.3.src_attn.linear_k.bias | torch.Size([256]) | 256\n",
"decoder.decoders.3.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.3.src_attn.linear_v.bias | torch.Size([256]) | 256\n",
"decoder.decoders.3.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.3.src_attn.linear_out.bias | torch.Size([256]) | 256\n",
"decoder.decoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"decoder.decoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"decoder.decoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"decoder.decoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.3.norm1.weight | torch.Size([256]) | 256\n",
"decoder.decoders.3.norm1.bias | torch.Size([256]) | 256\n",
"decoder.decoders.3.norm2.weight | torch.Size([256]) | 256\n",
"decoder.decoders.3.norm2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.3.norm3.weight | torch.Size([256]) | 256\n",
"decoder.decoders.3.norm3.bias | torch.Size([256]) | 256\n",
"decoder.decoders.3.concat_linear1.weight | torch.Size([256, 512]) | 131072\n",
"decoder.decoders.3.concat_linear1.bias | torch.Size([256]) | 256\n",
"decoder.decoders.3.concat_linear2.weight | torch.Size([256, 512]) | 131072\n",
"decoder.decoders.3.concat_linear2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"decoder.decoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"decoder.decoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"decoder.decoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"decoder.decoders.4.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.4.src_attn.linear_q.bias | torch.Size([256]) | 256\n",
"decoder.decoders.4.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.4.src_attn.linear_k.bias | torch.Size([256]) | 256\n",
"decoder.decoders.4.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.4.src_attn.linear_v.bias | torch.Size([256]) | 256\n",
"decoder.decoders.4.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.4.src_attn.linear_out.bias | torch.Size([256]) | 256\n",
"decoder.decoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"decoder.decoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"decoder.decoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"decoder.decoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.4.norm1.weight | torch.Size([256]) | 256\n",
"decoder.decoders.4.norm1.bias | torch.Size([256]) | 256\n",
"decoder.decoders.4.norm2.weight | torch.Size([256]) | 256\n",
"decoder.decoders.4.norm2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.4.norm3.weight | torch.Size([256]) | 256\n",
"decoder.decoders.4.norm3.bias | torch.Size([256]) | 256\n",
"decoder.decoders.4.concat_linear1.weight | torch.Size([256, 512]) | 131072\n",
"decoder.decoders.4.concat_linear1.bias | torch.Size([256]) | 256\n",
"decoder.decoders.4.concat_linear2.weight | torch.Size([256, 512]) | 131072\n",
"decoder.decoders.4.concat_linear2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256\n",
"decoder.decoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256\n",
"decoder.decoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256\n",
"decoder.decoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256\n",
"decoder.decoders.5.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.5.src_attn.linear_q.bias | torch.Size([256]) | 256\n",
"decoder.decoders.5.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.5.src_attn.linear_k.bias | torch.Size([256]) | 256\n",
"decoder.decoders.5.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.5.src_attn.linear_v.bias | torch.Size([256]) | 256\n",
"decoder.decoders.5.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n",
"decoder.decoders.5.src_attn.linear_out.bias | torch.Size([256]) | 256\n",
"decoder.decoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n",
"decoder.decoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n",
"decoder.decoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n",
"decoder.decoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.5.norm1.weight | torch.Size([256]) | 256\n",
"decoder.decoders.5.norm1.bias | torch.Size([256]) | 256\n",
"decoder.decoders.5.norm2.weight | torch.Size([256]) | 256\n",
"decoder.decoders.5.norm2.bias | torch.Size([256]) | 256\n",
"decoder.decoders.5.norm3.weight | torch.Size([256]) | 256\n",
"decoder.decoders.5.norm3.bias | torch.Size([256]) | 256\n",
"decoder.decoders.5.concat_linear1.weight | torch.Size([256, 512]) | 131072\n",
"decoder.decoders.5.concat_linear1.bias | torch.Size([256]) | 256\n",
"decoder.decoders.5.concat_linear2.weight | torch.Size([256, 512]) | 131072\n",
"decoder.decoders.5.concat_linear2.bias | torch.Size([256]) | 256\n",
"ctc.ctc_lo.weight | torch.Size([4233, 256]) | 1083648\n",
"ctc.ctc_lo.bias | torch.Size([4233]) | 4233\n",
"Total parameters: 701, 49355454.0 elements.\n"
]
}
],
"source": [
"summary(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8494c6ab",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 13,
"id": "0648a969",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304 | True\n",
"encoder.embed.conv.0.bias | torch.Size([256]) | 256 | True\n",
"encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824 | True\n",
"encoder.embed.conv.2.bias | torch.Size([256]) | 256 | True\n",
"encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184 | True\n",
"encoder.embed.out.0.bias | torch.Size([256]) | 256 | True\n",
"encoder.after_norm.weight | torch.Size([256]) | 256 | True\n",
"encoder.after_norm.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.0.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.0.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.0.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
"encoder.encoders.0.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
"encoder.encoders.0.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
"encoder.encoders.0.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
"encoder.encoders.0.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.norm_ff.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.norm_ff.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.norm_mha.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.norm_mha.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.norm_conv.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.norm_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.norm_final.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.norm_final.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.0.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
"encoder.encoders.0.concat_linear.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.1.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.1.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.1.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.1.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
"encoder.encoders.1.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
"encoder.encoders.1.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
"encoder.encoders.1.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
"encoder.encoders.1.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.norm_ff.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.norm_ff.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.norm_mha.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.norm_mha.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.norm_conv.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.norm_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.norm_final.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.norm_final.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.1.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
"encoder.encoders.1.concat_linear.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.2.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.2.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.2.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.2.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
"encoder.encoders.2.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
"encoder.encoders.2.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
"encoder.encoders.2.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
"encoder.encoders.2.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.norm_ff.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.norm_ff.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.norm_mha.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.norm_mha.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.norm_conv.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.norm_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.norm_final.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.norm_final.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.2.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
"encoder.encoders.2.concat_linear.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.3.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.3.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.3.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.3.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
"encoder.encoders.3.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
"encoder.encoders.3.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
"encoder.encoders.3.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
"encoder.encoders.3.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.norm_ff.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.norm_ff.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.norm_mha.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.norm_mha.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.norm_conv.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.norm_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.norm_final.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.norm_final.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.3.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
"encoder.encoders.3.concat_linear.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.4.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.4.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.4.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.4.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
"encoder.encoders.4.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
"encoder.encoders.4.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
"encoder.encoders.4.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
"encoder.encoders.4.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.norm_ff.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.norm_ff.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.norm_mha.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.norm_mha.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.norm_conv.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.norm_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.norm_final.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.norm_final.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.4.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
"encoder.encoders.4.concat_linear.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.5.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.5.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.5.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.5.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
"encoder.encoders.5.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
"encoder.encoders.5.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
"encoder.encoders.5.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
"encoder.encoders.5.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.norm_ff.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.norm_ff.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.norm_mha.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.norm_mha.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.norm_conv.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.norm_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.norm_final.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.norm_final.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.5.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
"encoder.encoders.5.concat_linear.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.6.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.6.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.6.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.6.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.6.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.6.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.6.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.6.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.6.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.6.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
"encoder.encoders.6.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
"encoder.encoders.6.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
"encoder.encoders.6.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
"encoder.encoders.6.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.norm_ff.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.norm_ff.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.norm_mha.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.norm_mha.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.norm_conv.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.norm_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.norm_final.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.norm_final.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.6.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
"encoder.encoders.6.concat_linear.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.7.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.7.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.7.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.7.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.7.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.7.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.7.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.7.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.7.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.7.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
"encoder.encoders.7.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
"encoder.encoders.7.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
"encoder.encoders.7.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
"encoder.encoders.7.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.norm_ff.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.norm_ff.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.norm_mha.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.norm_mha.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.norm_conv.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.norm_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.norm_final.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.norm_final.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.7.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
"encoder.encoders.7.concat_linear.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.8.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.8.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.8.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.8.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.8.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.8.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.8.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.8.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.8.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.8.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
"encoder.encoders.8.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
"encoder.encoders.8.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
"encoder.encoders.8.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
"encoder.encoders.8.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.norm_ff.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.norm_ff.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.norm_mha.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.norm_mha.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.norm_conv.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.norm_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.norm_final.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.norm_final.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.8.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
"encoder.encoders.8.concat_linear.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.9.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.9.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.9.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.9.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.9.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.9.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.9.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.9.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.9.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.9.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
"encoder.encoders.9.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
"encoder.encoders.9.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
"encoder.encoders.9.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
"encoder.encoders.9.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.norm_ff.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.norm_ff.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.norm_mha.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.norm_mha.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.norm_conv.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.norm_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.norm_final.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.norm_final.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.9.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
"encoder.encoders.9.concat_linear.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.10.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.10.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.10.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.10.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.10.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.10.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.10.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.10.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.10.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.10.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
"encoder.encoders.10.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
"encoder.encoders.10.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
"encoder.encoders.10.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
"encoder.encoders.10.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.norm_ff.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.norm_ff.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.norm_mha.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.norm_mha.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.norm_conv.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.norm_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.norm_final.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.norm_final.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.10.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
"encoder.encoders.10.concat_linear.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.11.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n",
"encoder.encoders.11.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.11.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n",
"encoder.encoders.11.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.11.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.11.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.11.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"encoder.encoders.11.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"encoder.encoders.11.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"encoder.encoders.11.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n",
"encoder.encoders.11.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n",
"encoder.encoders.11.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n",
"encoder.encoders.11.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.conv_module.norm.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.conv_module.norm.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n",
"encoder.encoders.11.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.norm_ff.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.norm_ff.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.norm_mha.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.norm_mha.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.norm_conv.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.norm_conv.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.norm_final.weight | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.norm_final.bias | torch.Size([256]) | 256 | True\n",
"encoder.encoders.11.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n",
"encoder.encoders.11.concat_linear.bias | torch.Size([256]) | 256 | True\n",
"decoder.embed.0.weight | torch.Size([4233, 256]) | 1083648 | True\n",
"decoder.after_norm.weight | torch.Size([256]) | 256 | True\n",
"decoder.after_norm.bias | torch.Size([256]) | 256 | True\n",
"decoder.output_layer.weight | torch.Size([4233, 256]) | 1083648 | True\n",
"decoder.output_layer.bias | torch.Size([4233]) | 4233 | True\n",
"decoder.decoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.0.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"decoder.decoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"decoder.decoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"decoder.decoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.norm1.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.norm1.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.norm2.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.norm2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.norm3.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.norm3.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n",
"decoder.decoders.0.concat_linear1.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.0.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n",
"decoder.decoders.0.concat_linear2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.1.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"decoder.decoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"decoder.decoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"decoder.decoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.norm1.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.norm1.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.norm2.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.norm2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.norm3.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.norm3.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n",
"decoder.decoders.1.concat_linear1.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.1.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n",
"decoder.decoders.1.concat_linear2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.2.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"decoder.decoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"decoder.decoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"decoder.decoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.norm1.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.norm1.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.norm2.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.norm2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.norm3.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.norm3.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n",
"decoder.decoders.2.concat_linear1.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.2.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n",
"decoder.decoders.2.concat_linear2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.3.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"decoder.decoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"decoder.decoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"decoder.decoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.norm1.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.norm1.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.norm2.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.norm2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.norm3.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.norm3.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n",
"decoder.decoders.3.concat_linear1.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.3.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n",
"decoder.decoders.3.concat_linear2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.4.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"decoder.decoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"decoder.decoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"decoder.decoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.norm1.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.norm1.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.norm2.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.norm2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.norm3.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.norm3.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n",
"decoder.decoders.4.concat_linear1.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.4.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n",
"decoder.decoders.4.concat_linear2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n",
"decoder.decoders.5.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n",
"decoder.decoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n",
"decoder.decoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n",
"decoder.decoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.norm1.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.norm1.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.norm2.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.norm2.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.norm3.weight | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.norm3.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n",
"decoder.decoders.5.concat_linear1.bias | torch.Size([256]) | 256 | True\n",
"decoder.decoders.5.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n",
"decoder.decoders.5.concat_linear2.bias | torch.Size([256]) | 256 | True\n",
"ctc.ctc_lo.weight | torch.Size([4233, 256]) | 1083648 | True\n",
"ctc.ctc_lo.bias | torch.Size([4233]) | 4233 | True\n",
"Total parameters: 663.0, 49349138.0 elements.\n"
]
}
],
"source": [
"print_params(model)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5ad6de2a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['BAC009S0739W0246', 'BAC009S0727W0424', 'BAC009S0753W0412', 'BAC009S0756W0206', 'BAC009S0740W0414', 'BAC009S0728W0426', 'BAC009S0739W0214', 'BAC009S0753W0423', 'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']\n",
"torch.Size([16, 207, 80])\n",
"tensor([[[ 8.9946, 9.5383, 9.1916, ..., 10.5074, 9.5633, 8.2564],\n",
" [ 9.7988, 10.4052, 9.2651, ..., 10.2512, 9.5440, 8.8738],\n",
" [10.6891, 10.3955, 8.0535, ..., 9.9067, 10.0649, 8.0509],\n",
" ...,\n",
" [ 9.2180, 9.6507, 8.5053, ..., 9.6872, 8.7425, 7.9865],\n",
" [10.1291, 9.9352, 9.3798, ..., 9.5639, 9.8260, 8.9795],\n",
" [ 9.0955, 7.1338, 9.4680, ..., 9.4727, 9.0212, 7.4479]],\n",
"\n",
" [[11.4310, 10.6719, 6.0841, ..., 9.3827, 8.7297, 7.5316],\n",
" [ 9.7317, 7.8105, 7.5715, ..., 10.0430, 9.2436, 7.3541],\n",
" [10.6502, 10.6006, 8.4678, ..., 9.2814, 9.1869, 8.0703],\n",
" ...,\n",
" [ 9.0970, 9.2637, 8.0753, ..., 8.4318, 8.3705, 8.0029],\n",
" [10.4617, 10.1478, 6.7693, ..., 9.7794, 9.5775, 8.0807],\n",
" [ 7.7944, 5.6211, 7.9751, ..., 9.9972, 9.8497, 8.0313]],\n",
"\n",
" [[ 7.3456, 7.8964, 7.5796, ..., 11.6310, 10.4513, 9.1236],\n",
" [ 8.6287, 8.4631, 7.4992, ..., 12.4160, 10.9757, 8.9426],\n",
" [ 9.8314, 10.2813, 8.9724, ..., 12.1387, 10.4017, 9.0055],\n",
" ...,\n",
" [ 7.0896, 7.4055, 6.8143, ..., 9.3252, 9.2732, 8.3534],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n",
"\n",
" ...,\n",
"\n",
" [[10.9332, 10.4644, 7.7203, ..., 10.3488, 9.3023, 7.1553],\n",
" [10.4499, 9.9070, 9.0293, ..., 9.9525, 9.4141, 7.5593],\n",
" [10.4877, 9.8126, 9.8952, ..., 9.5866, 9.3413, 7.7849],\n",
" ...,\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n",
"\n",
" [[ 9.9444, 9.5859, 8.2203, ..., 11.5886, 11.0450, 8.8171],\n",
" [ 7.6784, 8.3224, 7.5330, ..., 11.0551, 10.5357, 9.2746],\n",
" [ 8.6262, 9.6759, 9.8410, ..., 11.3788, 10.9221, 8.9914],\n",
" ...,\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n",
"\n",
" [[ 8.1079, 7.7590, 6.7103, ..., 12.6506, 11.4662, 11.0615],\n",
" [11.3803, 11.2220, 8.6589, ..., 12.8106, 12.2222, 11.6893],\n",
" [10.6777, 9.9206, 8.0461, ..., 13.5729, 12.5624, 11.1550],\n",
" ...,\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]])\n",
"tensor([207, 207, 205, 205, 203, 203, 198, 197, 195, 188, 186, 186, 185, 180,\n",
" 166, 163], dtype=torch.int32)\n",
"tensor([[2995, 3116, 1209, 565, -1, -1],\n",
" [ 236, 1176, 331, 66, 3925, 4077],\n",
" [2693, 524, 234, 1145, 366, -1],\n",
" [3875, 4211, 3062, 700, -1, -1],\n",
" [ 272, 987, 1134, 494, 2959, -1],\n",
" [1936, 3715, 120, 2553, 2695, 2710],\n",
" [ 25, 1149, 3930, -1, -1, -1],\n",
" [1753, 1778, 1237, 482, 3925, 110],\n",
" [3703, 2, 565, 3827, -1, -1],\n",
" [1150, 2734, 10, 2478, 3490, -1],\n",
" [ 426, 811, 95, 489, 144, -1],\n",
" [2313, 2006, 489, 975, -1, -1],\n",
" [3702, 3414, 205, 1488, 2966, 1347],\n",
" [ 70, 1741, 702, 1666, -1, -1],\n",
" [ 703, 1778, 1030, 849, -1, -1],\n",
" [ 814, 1674, 115, 3827, -1, -1]], dtype=torch.int32)\n",
"tensor([4, 6, 5, 4, 5, 6, 3, 6, 4, 5, 5, 4, 6, 4, 4, 4], dtype=torch.int32)\n"
]
}
],
"source": [
"for batch in cv_data_loader:\n",
" keys, feat, text, feat_len, text_len = batch\n",
" print(keys)\n",
" print(feat.shape)\n",
" print(feat)\n",
" print(feat_len)\n",
" print(text)\n",
" print(text_len)\n",
" np.savez('data.npz', keys=keys, feat=feat.numpy(), feat_len=feat_len.numpy(), text=text.numpy(), text_len=text_len.numpy())\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "852a9c95",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CODE_OF_CONDUCT.md data.npz install.sh README.md\t tools\r\n",
"CONTRIBUTING.md docs LICENSE\t requirements.txt venv\r\n",
"CPPLINT.cfg\t examples Makefile\t runtime\t wenet\r\n"
]
}
],
"source": [
"!ls\n",
"!cp data.npz /workspace/DeepSpeech-2.x/.notebook"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "cde24c4e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(111.9988)\n",
"tensor(830.9634, grad_fn=<SumBackward0>)\n",
"tensor([False, False, False, False, False, True, True, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" True, False, False, False, False, False, True, True, False, False,\n",
" False, False, False, False, True, False, False, False, False, False,\n",
" False, False, False, False, False, False, True, True, True, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, True, True, False, False, False, False, False, False, True,\n",
" False, False, False, False, False, False, True, False, False, False,\n",
" False, False, True, True, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, True, True, False, False,\n",
" False, False, False, True, True, False, False, False, False, False,\n",
" True, True])\n",
"tensor(669.4633, grad_fn=<SumBackward0>)\n",
"tensor(142.4888, grad_fn=<AddBackward0>) tensor(41.8415, grad_fn=<DivBackward0>) tensor(377.3326, grad_fn=<DivBackward0>)\n"
]
}
],
"source": [
"model.cpu().eval()\n",
"total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n",
" text, text_len)\n",
"print(total_loss, attention_loss, ctc_loss )"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "be5b2a2c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cpu\n"
]
}
],
"source": [
"print(total_loss.device)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "5b791771",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(112., device='cuda:0')\n",
"tensor(830.9634, device='cuda:0', grad_fn=<SumBackward0>)\n",
"tensor([False, False, False, False, False, True, True, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" True, False, False, False, False, False, True, True, False, False,\n",
" False, False, False, False, True, False, False, False, False, False,\n",
" False, False, False, False, False, False, True, True, True, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, True, True, False, False, False, False, False, False, True,\n",
" False, False, False, False, False, False, True, False, False, False,\n",
" False, False, True, True, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, True, True, False, False,\n",
" False, False, False, True, True, False, False, False, False, False,\n",
" True, True], device='cuda:0')\n",
"tensor(669.4634, device='cuda:0', grad_fn=<SumBackward0>)\n",
"cuda:0\n",
"142.4888 41.84146 377.33258\n"
]
}
],
"source": [
"model.cuda().eval()\n",
"feat=feat.cuda()\n",
"feat_len=feat_len.cuda()\n",
"text=text.cuda()\n",
"text_len=text_len.cuda()\n",
"\n",
"total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n",
" text, text_len)\n",
"print(total_loss.device)\n",
"print(total_loss.cpu().data.numpy(), attention_loss.cpu().data.numpy(), ctc_loss.cpu().data.numpy() )"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "1baef537",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([16, 51, 256])\n",
"torch.Size([16, 1, 51])\n",
"tensor([[-0.7019, 0.5625, 0.6880, ..., 1.1237, 0.7804, 1.1369],\n",
" [-0.7788, 0.3913, 0.7189, ..., 1.2519, 0.8862, 1.3173],\n",
" [-0.9591, 0.6346, 0.8767, ..., 0.9818, 0.7440, 1.2903],\n",
" ...,\n",
" [-1.0732, 0.6724, 0.9230, ..., 0.9075, 0.8177, 1.3240],\n",
" [-1.1654, 0.6820, 0.6939, ..., 1.2238, 0.8028, 1.4507],\n",
" [-1.2732, 0.7146, 0.7582, ..., 0.9415, 0.8775, 1.2623]],\n",
" device='cuda:0', grad_fn=<SelectBackward>)\n"
]
}
],
"source": [
"encoder_out, encoder_mask = model.encoder(feat, feat_len)\n",
"print(encoder_out.shape)\n",
"print(encoder_mask.shape)\n",
"print(encoder_out[0])\n",
"\n",
"np.savez('/workspace/DeepSpeech-2.x/.notebook/encoder.npz',\n",
" mask=encoder_mask.cpu().detach().numpy(), \n",
" out=encoder_out.cpu().detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e22c782",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 20,
"id": "30b6b946",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 9.871763 9.938915 10.238187 10.8597145 11.686526 12.25488\n",
" 12.657681 12.86139 12.807339 12.566256 12.32007 12.138792\n",
" 12.313189 12.552552 12.612239 12.569745 12.389728 12.143833\n",
" 12.092851 11.793959 11.622591 11.926331 11.815442 11.951225\n",
" 11.831805 11.887888 11.790144 11.88072 11.900057 11.973481\n",
" 12.009822 12.008814 12.026197 12.104796 12.21555 12.343993\n",
" 12.450144 12.496688 12.486538 12.355079 12.392918 12.255374\n",
" 12.264963 12.253142 12.325458 12.4335985 12.548675 12.676334\n",
" 12.809207 12.929347 12.961151 12.968834 12.995931 13.047281\n",
" 13.058881 13.05738 12.999211 12.934022 12.874292 12.71653\n",
" 12.48942 12.274784 12.261631 12.286319 12.31956 12.422907\n",
" 12.514802 12.578516 12.647194 12.737626 12.800171 12.868728\n",
" 12.966668 13.064786 13.159159 13.272843 13.310819 13.239043\n",
" 12.879361 11.183102 ] float32\n",
"encoder.embed.out.0.weight: (256, 4864) -> (4864, 256)\n",
"encoder.encoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.0.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.0.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.0.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.0.conv_module.norm.running_mean -> encoder.encoders.0.conv_module.norm._mean\n",
"encoder.encoders.0.conv_module.norm.running_var -> encoder.encoders.0.conv_module.norm._variance\n",
"encoder.encoders.0.concat_linear.weight: (256, 512) -> (512, 256)\n",
"encoder.encoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.1.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.1.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.1.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.1.conv_module.norm.running_mean -> encoder.encoders.1.conv_module.norm._mean\n",
"encoder.encoders.1.conv_module.norm.running_var -> encoder.encoders.1.conv_module.norm._variance\n",
"encoder.encoders.1.concat_linear.weight: (256, 512) -> (512, 256)\n",
"encoder.encoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.2.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.2.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.2.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.2.conv_module.norm.running_mean -> encoder.encoders.2.conv_module.norm._mean\n",
"encoder.encoders.2.conv_module.norm.running_var -> encoder.encoders.2.conv_module.norm._variance\n",
"encoder.encoders.2.concat_linear.weight: (256, 512) -> (512, 256)\n",
"encoder.encoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.3.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.3.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.3.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.3.conv_module.norm.running_mean -> encoder.encoders.3.conv_module.norm._mean\n",
"encoder.encoders.3.conv_module.norm.running_var -> encoder.encoders.3.conv_module.norm._variance\n",
"encoder.encoders.3.concat_linear.weight: (256, 512) -> (512, 256)\n",
"encoder.encoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.4.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.4.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.4.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.4.conv_module.norm.running_mean -> encoder.encoders.4.conv_module.norm._mean\n",
"encoder.encoders.4.conv_module.norm.running_var -> encoder.encoders.4.conv_module.norm._variance\n",
"encoder.encoders.4.concat_linear.weight: (256, 512) -> (512, 256)\n",
"encoder.encoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.5.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.5.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.5.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.5.conv_module.norm.running_mean -> encoder.encoders.5.conv_module.norm._mean\n",
"encoder.encoders.5.conv_module.norm.running_var -> encoder.encoders.5.conv_module.norm._variance\n",
"encoder.encoders.5.concat_linear.weight: (256, 512) -> (512, 256)\n",
"encoder.encoders.6.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.6.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.6.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.6.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.6.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.6.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.6.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.6.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.6.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.6.conv_module.norm.running_mean -> encoder.encoders.6.conv_module.norm._mean\n",
"encoder.encoders.6.conv_module.norm.running_var -> encoder.encoders.6.conv_module.norm._variance\n",
"encoder.encoders.6.concat_linear.weight: (256, 512) -> (512, 256)\n",
"encoder.encoders.7.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.7.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.7.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.7.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.7.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.7.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.7.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.7.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.7.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.7.conv_module.norm.running_mean -> encoder.encoders.7.conv_module.norm._mean\n",
"encoder.encoders.7.conv_module.norm.running_var -> encoder.encoders.7.conv_module.norm._variance\n",
"encoder.encoders.7.concat_linear.weight: (256, 512) -> (512, 256)\n",
"encoder.encoders.8.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.8.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.8.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.8.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.8.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.8.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.8.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.8.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.8.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.8.conv_module.norm.running_mean -> encoder.encoders.8.conv_module.norm._mean\n",
"encoder.encoders.8.conv_module.norm.running_var -> encoder.encoders.8.conv_module.norm._variance\n",
"encoder.encoders.8.concat_linear.weight: (256, 512) -> (512, 256)\n",
"encoder.encoders.9.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.9.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.9.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.9.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.9.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.9.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.9.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.9.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.9.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.9.conv_module.norm.running_mean -> encoder.encoders.9.conv_module.norm._mean\n",
"encoder.encoders.9.conv_module.norm.running_var -> encoder.encoders.9.conv_module.norm._variance\n",
"encoder.encoders.9.concat_linear.weight: (256, 512) -> (512, 256)\n",
"encoder.encoders.10.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.10.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.10.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.10.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.10.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.10.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.10.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.10.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.10.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.10.conv_module.norm.running_mean -> encoder.encoders.10.conv_module.norm._mean\n",
"encoder.encoders.10.conv_module.norm.running_var -> encoder.encoders.10.conv_module.norm._variance\n",
"encoder.encoders.10.concat_linear.weight: (256, 512) -> (512, 256)\n",
"encoder.encoders.11.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.11.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.11.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.11.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.11.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n",
"encoder.encoders.11.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.11.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.11.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n",
"encoder.encoders.11.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n",
"encoder.encoders.11.conv_module.norm.running_mean -> encoder.encoders.11.conv_module.norm._mean\n",
"encoder.encoders.11.conv_module.norm.running_var -> encoder.encoders.11.conv_module.norm._variance\n",
"encoder.encoders.11.concat_linear.weight: (256, 512) -> (512, 256)\n",
"decoder.output_layer.weight: (4233, 256) -> (256, 4233)\n",
"decoder.decoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.0.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.0.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.0.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.0.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"decoder.decoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"decoder.decoders.0.concat_linear1.weight: (256, 512) -> (512, 256)\n",
"decoder.decoders.0.concat_linear2.weight: (256, 512) -> (512, 256)\n",
"decoder.decoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.1.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.1.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.1.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.1.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"decoder.decoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"decoder.decoders.1.concat_linear1.weight: (256, 512) -> (512, 256)\n",
"decoder.decoders.1.concat_linear2.weight: (256, 512) -> (512, 256)\n",
"decoder.decoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.2.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.2.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.2.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.2.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"decoder.decoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"decoder.decoders.2.concat_linear1.weight: (256, 512) -> (512, 256)\n",
"decoder.decoders.2.concat_linear2.weight: (256, 512) -> (512, 256)\n",
"decoder.decoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.3.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.3.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.3.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.3.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"decoder.decoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"decoder.decoders.3.concat_linear1.weight: (256, 512) -> (512, 256)\n",
"decoder.decoders.3.concat_linear2.weight: (256, 512) -> (512, 256)\n",
"decoder.decoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.4.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.4.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.4.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.4.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"decoder.decoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"decoder.decoders.4.concat_linear1.weight: (256, 512) -> (512, 256)\n",
"decoder.decoders.4.concat_linear2.weight: (256, 512) -> (512, 256)\n",
"decoder.decoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.5.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.5.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.5.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.5.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n",
"decoder.decoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n",
"decoder.decoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n",
"decoder.decoders.5.concat_linear1.weight: (256, 512) -> (512, 256)\n",
"decoder.decoders.5.concat_linear2.weight: (256, 512) -> (512, 256)\n",
"ctc.ctc_lo.weight: (4233, 256) -> (256, 4233)\n"
]
}
],
"source": [
"# dump torch model to paddle\n",
"import numpy as np\n",
"state_dict = model.state_dict()\n",
"paddle_state_dict = {}\n",
"\n",
"for n, p in state_dict.items():\n",
" name_change=True\n",
"\n",
" if 'norm.running_mean' in n:\n",
" new_n = n.replace('norm.running_', 'norm._')\n",
" elif 'norm.running_var' in n:\n",
" new_n = n.replace('norm.running_var', 'norm._variance')\n",
" else:\n",
" name_change=False\n",
" new_n = n\n",
" \n",
" if name_change:\n",
" print(f\"{n} -> {new_n}\")\n",
" \n",
" p = p.cpu().detach().numpy()\n",
" if n.endswith('weight') and p.ndim == 2 and 'embed.0.weight' not in n:\n",
" new_p = p.T\n",
" print(f\"{n}: {p.shape} -> {new_p.shape}\")\n",
" else:\n",
" new_p = p\n",
" \n",
" if 'global_cmvn.mean' in n:\n",
" print(p, p.dtype)\n",
" \n",
" paddle_state_dict[new_n] = new_p\n",
" \n",
"np.savez('/workspace/DeepSpeech-2.x/.notebook/model',\n",
" state=paddle_state_dict)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7307dc5b",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 21,
"id": "d99b29bc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(377.3326, device='cuda:0', grad_fn=<DivBackward0>)\n",
"None\n",
"[[ 3.16902351e+00 -1.51765049e-02 4.91097234e-02 ... -2.47973716e-03\n",
" -5.93366381e-03 -7.26613170e-03]\n",
" [-1.74185038e+00 7.75875803e-03 -4.49435972e-02 ... 9.92415240e-04\n",
" 2.46338220e-03 2.31891591e-03]\n",
" [-2.33343077e+00 1.30476682e-02 -2.66557615e-02 ... 2.27533933e-03\n",
" 5.76929189e-03 7.48792710e-03]\n",
" ...\n",
" [-4.30356789e+00 2.46056803e-02 -9.00955945e-02 ... 4.43160534e-03\n",
" 1.16123557e-02 1.44716976e-02]\n",
" [-3.36919212e+00 1.73155665e-02 -6.36875406e-02 ... 3.28367390e-03\n",
" 8.58021621e-03 1.07796099e-02]\n",
" [-6.62039661e+00 3.49958315e-02 -1.23963736e-01 ... 6.36674836e-03\n",
" 1.60815325e-02 2.03892551e-02]]\n",
"[-4.3777566e+00 2.3245990e-02 -9.3339972e-02 ... 4.2569702e-03\n",
" 1.0920014e-02 1.3787906e-02]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"<ipython-input-21-622603015bd4>:6: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.\n",
" print(loss_ctc.grad)\n"
]
}
],
"source": [
"encoder_out_lens = encoder_mask.squeeze(1).sum(1)\n",
"loss_ctc = model.ctc(encoder_out, encoder_out_lens, text, text_len)\n",
"print(loss_ctc)\n",
"dir(loss_ctc)\n",
"loss_ctc.backward()\n",
"print(loss_ctc.grad)\n",
"#print(model.ctc.ctc_lo.weight.grad)\n",
"print(model.ctc.ctc_lo.weight.grad.T.cpu().data.numpy())\n",
"print(model.ctc.ctc_lo.bias.grad.cpu().data.numpy())"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "49b05d6d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(112., device='cuda:0')\n",
"tensor(830.9634, device='cuda:0', grad_fn=<SumBackward0>)\n",
"tensor([False, False, False, False, False, True, True, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" True, False, False, False, False, False, True, True, False, False,\n",
" False, False, False, False, True, False, False, False, False, False,\n",
" False, False, False, False, False, False, True, True, True, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, True, True, False, False, False, False, False, False, True,\n",
" False, False, False, False, False, False, True, False, False, False,\n",
" False, False, True, True, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, True, True, False, False,\n",
" False, False, False, True, True, False, False, False, False, False,\n",
" True, True], device='cuda:0')\n",
"tensor(669.4634, device='cuda:0', grad_fn=<SumBackward0>)\n",
"tensor(41.8415, device='cuda:0', grad_fn=<DivBackward0>) 0.0\n"
]
}
],
"source": [
"loss_att, acc_att = model._calc_att_loss(encoder_out, encoder_mask,\n",
" text, text_len)\n",
"print(loss_att, acc_att)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "413b413f",
"metadata": {},
"outputs": [],
"source": [
"def pad_list(xs, pad_value: int):\n",
" n_batch = len(xs)\n",
" max_len = max([x.size(0) for x in xs])\n",
" pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)\n",
" pad = pad.fill_(pad_value)\n",
" for i in range(n_batch):\n",
" pad[i, :xs[i].size(0)] = xs[i]\n",
"\n",
" return pad\n",
"\n",
"def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,\n",
" ignore_id: int):\n",
"\n",
" _sos = torch.tensor([sos],\n",
" dtype=torch.long,\n",
" requires_grad=False,\n",
" device=ys_pad.device)\n",
" _eos = torch.tensor([eos],\n",
" dtype=torch.long,\n",
" requires_grad=False,\n",
" device=ys_pad.device)\n",
" ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys\n",
" ys_in = [torch.cat([_sos, y], dim=0) for y in ys]\n",
" ys_out = [torch.cat([y, _eos], dim=0) for y in ys]\n",
" return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "ff0c2400",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[4232, 2995, 3116, 1209, 565, 4232, 4232],\n",
" [4232, 236, 1176, 331, 66, 3925, 4077],\n",
" [4232, 2693, 524, 234, 1145, 366, 4232],\n",
" [4232, 3875, 4211, 3062, 700, 4232, 4232],\n",
" [4232, 272, 987, 1134, 494, 2959, 4232],\n",
" [4232, 1936, 3715, 120, 2553, 2695, 2710],\n",
" [4232, 25, 1149, 3930, 4232, 4232, 4232],\n",
" [4232, 1753, 1778, 1237, 482, 3925, 110],\n",
" [4232, 3703, 2, 565, 3827, 4232, 4232],\n",
" [4232, 1150, 2734, 10, 2478, 3490, 4232],\n",
" [4232, 426, 811, 95, 489, 144, 4232],\n",
" [4232, 2313, 2006, 489, 975, 4232, 4232],\n",
" [4232, 3702, 3414, 205, 1488, 2966, 1347],\n",
" [4232, 70, 1741, 702, 1666, 4232, 4232],\n",
" [4232, 703, 1778, 1030, 849, 4232, 4232],\n",
" [4232, 814, 1674, 115, 3827, 4232, 4232]], device='cuda:0')\n",
"tensor([[2995, 3116, 1209, 565, 4232, -1, -1],\n",
" [ 236, 1176, 331, 66, 3925, 4077, 4232],\n",
" [2693, 524, 234, 1145, 366, 4232, -1],\n",
" [3875, 4211, 3062, 700, 4232, -1, -1],\n",
" [ 272, 987, 1134, 494, 2959, 4232, -1],\n",
" [1936, 3715, 120, 2553, 2695, 2710, 4232],\n",
" [ 25, 1149, 3930, 4232, -1, -1, -1],\n",
" [1753, 1778, 1237, 482, 3925, 110, 4232],\n",
" [3703, 2, 565, 3827, 4232, -1, -1],\n",
" [1150, 2734, 10, 2478, 3490, 4232, -1],\n",
" [ 426, 811, 95, 489, 144, 4232, -1],\n",
" [2313, 2006, 489, 975, 4232, -1, -1],\n",
" [3702, 3414, 205, 1488, 2966, 1347, 4232],\n",
" [ 70, 1741, 702, 1666, 4232, -1, -1],\n",
" [ 703, 1778, 1030, 849, 4232, -1, -1],\n",
" [ 814, 1674, 115, 3827, 4232, -1, -1]], device='cuda:0')\n"
]
}
],
"source": [
"ys_pad = text\n",
"ys_pad_lens = text_len\n",
"ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,\n",
" model.ignore_id)\n",
"ys_in_lens = ys_pad_lens + 1\n",
"print(ys_in_pad)\n",
"print(ys_out_pad)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "3e84da38",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([16, 7, 4233])\n",
"tensor([[-3.7639e-01, -8.2272e-01, 7.4276e-01, ..., 3.4201e-01,\n",
" 1.5035e-02, 4.0337e-01],\n",
" [-8.7386e-01, -3.1389e-01, 4.1988e-01, ..., 3.7724e-01,\n",
" -1.4353e-01, -1.0024e+00],\n",
" [-4.3505e-01, 3.4505e-02, -2.8710e-01, ..., 7.7274e-02,\n",
" -1.1672e+00, -2.6849e-01],\n",
" ...,\n",
" [ 4.2471e-01, 5.8886e-01, 2.0204e-02, ..., 3.7405e-01,\n",
" 4.5470e-02, -3.7139e-01],\n",
" [-3.7978e-01, -8.1084e-01, 7.5725e-01, ..., 2.6039e-01,\n",
" -7.9347e-04, 4.2538e-01],\n",
" [-3.8280e-01, -8.1207e-01, 7.4943e-01, ..., 2.6173e-01,\n",
" -1.0499e-03, 4.2679e-01]], device='cuda:0', grad_fn=<SelectBackward>)\n"
]
}
],
"source": [
"decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n",
" ys_in_lens)\n",
"print(decoder_out.shape)\n",
"print(decoder_out[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aac441ea",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 28,
"id": "5ddbca73",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.float32\n",
"torch.int64\n",
"tensor(112., device='cuda:0')\n",
"tensor(830.9634, device='cuda:0', grad_fn=<SumBackward0>)\n",
"tensor([False, False, False, False, False, True, True, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" True, False, False, False, False, False, True, True, False, False,\n",
" False, False, False, False, True, False, False, False, False, False,\n",
" False, False, False, False, False, False, True, True, True, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, True, True, False, False, False, False, False, False, True,\n",
" False, False, False, False, False, False, True, False, False, False,\n",
" False, False, True, True, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, True, True, False, False,\n",
" False, False, False, True, True, False, False, False, False, False,\n",
" True, True], device='cuda:0')\n",
"tensor(669.4634, device='cuda:0', grad_fn=<SumBackward0>)\n",
"tensor(41.8415, device='cuda:0', grad_fn=<DivBackward0>)\n",
"tensor([[2995, 3116, 1209, 565, 4232, -1, -1],\n",
" [ 236, 1176, 331, 66, 3925, 4077, 4232],\n",
" [2693, 524, 234, 1145, 366, 4232, -1],\n",
" [3875, 4211, 3062, 700, 4232, -1, -1],\n",
" [ 272, 987, 1134, 494, 2959, 4232, -1],\n",
" [1936, 3715, 120, 2553, 2695, 2710, 4232],\n",
" [ 25, 1149, 3930, 4232, -1, -1, -1],\n",
" [1753, 1778, 1237, 482, 3925, 110, 4232],\n",
" [3703, 2, 565, 3827, 4232, -1, -1],\n",
" [1150, 2734, 10, 2478, 3490, 4232, -1],\n",
" [ 426, 811, 95, 489, 144, 4232, -1],\n",
" [2313, 2006, 489, 975, 4232, -1, -1],\n",
" [3702, 3414, 205, 1488, 2966, 1347, 4232],\n",
" [ 70, 1741, 702, 1666, 4232, -1, -1],\n",
" [ 703, 1778, 1030, 849, 4232, -1, -1],\n",
" [ 814, 1674, 115, 3827, 4232, -1, -1]], device='cuda:0')\n",
"tensor([[-3.7639e-01, -8.2272e-01, 7.4276e-01, ..., 3.4201e-01,\n",
" 1.5035e-02, 4.0337e-01],\n",
" [-8.7386e-01, -3.1389e-01, 4.1988e-01, ..., 3.7724e-01,\n",
" -1.4353e-01, -1.0024e+00],\n",
" [-4.3505e-01, 3.4505e-02, -2.8710e-01, ..., 7.7274e-02,\n",
" -1.1672e+00, -2.6849e-01],\n",
" ...,\n",
" [ 4.2471e-01, 5.8886e-01, 2.0204e-02, ..., 3.7405e-01,\n",
" 4.5470e-02, -3.7139e-01],\n",
" [-3.7978e-01, -8.1084e-01, 7.5725e-01, ..., 2.6039e-01,\n",
" -7.9347e-04, 4.2538e-01],\n",
" [-3.8280e-01, -8.1207e-01, 7.4943e-01, ..., 2.6173e-01,\n",
" -1.0499e-03, 4.2679e-01]], device='cuda:0', grad_fn=<SelectBackward>)\n"
]
}
],
"source": [
"print(decoder_out.dtype)\n",
"print(ys_out_pad.dtype)\n",
"loss_att = model.criterion_att(decoder_out, ys_out_pad)\n",
"print(loss_att)\n",
"print(ys_out_pad)\n",
"print(decoder_out[0])\n",
"np.savez('/workspace/DeepSpeech-2.x/.notebook/decoder',\n",
" decoder_out=decoder_out.cpu().detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78f98c0b",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 29,
"id": "8d968cd3",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"\n",
"class LabelSmoothingLoss(nn.Module):\n",
" def __init__(self,\n",
" size: int,\n",
" padding_idx: int,\n",
" smoothing: float,\n",
" normalize_length: bool = False):\n",
" \"\"\"Construct an LabelSmoothingLoss object.\"\"\"\n",
" super(LabelSmoothingLoss, self).__init__()\n",
" self.criterion = nn.KLDivLoss(reduction=\"none\")\n",
" self.padding_idx = padding_idx\n",
" self.confidence = 1.0 - smoothing\n",
" self.smoothing = smoothing\n",
" self.size = size\n",
" self.normalize_length = normalize_length\n",
"\n",
" def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
" \"\"\"Compute loss between x and target.\n",
"\n",
" The model outputs and data labels tensors are flatten to\n",
" (batch*seqlen, class) shape and a mask is applied to the\n",
" padding part which should not be calculated for loss.\n",
"\n",
" Args:\n",
" x (torch.Tensor): prediction (batch, seqlen, class)\n",
" target (torch.Tensor):\n",
" target signal masked with self.padding_id (batch, seqlen)\n",
" Returns:\n",
" loss (torch.Tensor) : The KL loss, scalar float value\n",
" \"\"\"\n",
" assert x.size(2) == self.size\n",
" batch_size = x.size(0)\n",
" x = x.view(-1, self.size)\n",
" target = target.view(-1)\n",
" # use zeros_like instead of torch.no_grad() for true_dist,\n",
" # since no_grad() can not be exported by JIT\n",
" true_dist = torch.zeros_like(x)\n",
" true_dist.fill_(self.smoothing / (self.size - 1))\n",
" ignore = target == self.padding_idx # (B,)\n",
" print(self.smoothing / (self.size - 1))\n",
" print(true_dist)\n",
" total = len(target) - ignore.sum().item()\n",
" target = target.masked_fill(ignore, 0) # avoid -1 index\n",
" true_dist.scatter_(1, target.unsqueeze(1), self.confidence)\n",
" print(true_dist.dtype)\n",
" print(true_dist.square().sum())\n",
" kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)\n",
" print(kl.sum())\n",
" denom = total if self.normalize_length else batch_size\n",
" print(ignore)\n",
" numer= kl.masked_fill(ignore.unsqueeze(1), 0).sum()\n",
" print(numer)\n",
" return numer /denom"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "3df340ec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.3629489603024576e-05\n",
"tensor([[2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n",
" 2.3629e-05],\n",
" [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n",
" 2.3629e-05],\n",
" [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n",
" 2.3629e-05],\n",
" ...,\n",
" [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n",
" 2.3629e-05],\n",
" [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n",
" 2.3629e-05],\n",
" [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n",
" 2.3629e-05]], device='cuda:0')\n",
"torch.float32\n",
"tensor(90.7203, device='cuda:0')\n",
"tensor(830.9634, device='cuda:0', grad_fn=<SumBackward0>)\n",
"tensor([False, False, False, False, False, True, True, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" True, False, False, False, False, False, True, True, False, False,\n",
" False, False, False, False, True, False, False, False, False, False,\n",
" False, False, False, False, False, False, True, True, True, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, True, True, False, False, False, False, False, False, True,\n",
" False, False, False, False, False, False, True, False, False, False,\n",
" False, False, True, True, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, True, True, False, False,\n",
" False, False, False, True, True, False, False, False, False, False,\n",
" True, True], device='cuda:0')\n",
"tensor(669.4634, device='cuda:0', grad_fn=<SumBackward0>)\n",
"tensor(41.8415, device='cuda:0', grad_fn=<DivBackward0>)\n",
"torch.int64\n"
]
}
],
"source": [
"criteron = LabelSmoothingLoss(4233, -1, 0.1, False)\n",
"loss_att = criteron(decoder_out, ys_out_pad)\n",
"print(loss_att)\n",
"print(ys_out_pad.dtype)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "badc410d",
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-31-e578f72cfac5>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mloss_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdecoder_out\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/workspace/wenet/venv/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \"\"\"\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/workspace/wenet/venv/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 125\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 126\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 127\u001b[0m allow_unreachable=True) # allow_unreachable flag\n",
"\u001b[0;31mRuntimeError\u001b[0m: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time."
]
}
],
"source": [
"loss_att.backward()\n",
"print(loss_att.grad)\n",
"print(decoder_out.grad)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "219eb41f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([ 0.0024, 0.0019, -0.1098, ..., 0.0028, 0.0020, -1.7978],\n",
" device='cuda:0')\n",
"tensor([[ 6.5052e-04, 6.4419e-05, -6.1955e-06, ..., 9.8220e-04,\n",
" -2.5918e-05, 3.3754e-04],\n",
" [ 3.9305e-04, 4.5799e-04, 1.4362e-04, ..., 4.6800e-04,\n",
" 1.6911e-04, 2.7067e-04],\n",
" [-1.3593e-01, 5.2201e-02, 3.2895e-02, ..., 2.4580e-02,\n",
" 1.4590e-01, -4.6850e-02],\n",
" ...,\n",
" [ 1.0434e-03, 4.2251e-04, 6.5688e-04, ..., 1.2144e-03,\n",
" 2.1159e-04, 6.6838e-04],\n",
" [ 6.4997e-04, 4.4301e-04, 4.1550e-04, ..., 1.0420e-03,\n",
" 2.4114e-04, 1.5338e-04],\n",
" [-9.9337e-01, 5.4573e-01, -1.1371e-02, ..., -4.3175e-01,\n",
" -2.7850e-01, -4.4679e-01]], device='cuda:0')\n"
]
}
],
"source": [
"print(model.decoder.output_layer.bias.grad)\n",
"print(model.decoder.output_layer.weight.grad)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "40d00a54",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[-5.3698e-01, -1.9911e-01, -3.4997e-01, ..., -8.2428e-01,\n",
" -1.0265e+00, -9.6301e-01],\n",
" [-4.4642e-02, 2.3176e-01, -3.2539e-01, ..., -9.0159e-01,\n",
" -1.0325e+00, -7.5987e-01],\n",
" [ 5.0035e-01, 2.2691e-01, -7.3052e-01, ..., -1.0055e+00,\n",
" -8.7123e-01, -1.0306e+00],\n",
" ...,\n",
" [-4.0024e-01, -1.4325e-01, -5.7947e-01, ..., -1.0718e+00,\n",
" -1.2806e+00, -1.0518e+00],\n",
" [ 1.5755e-01, -1.8495e-03, -2.8703e-01, ..., -1.1090e+00,\n",
" -9.4519e-01, -7.2506e-01],\n",
" [-4.7520e-01, -1.3942e+00, -2.5754e-01, ..., -1.1365e+00,\n",
" -1.1943e+00, -1.2290e+00]],\n",
"\n",
" [[ 9.5454e-01, 3.6428e-01, -1.3891e+00, ..., -1.1637e+00,\n",
" -1.2845e+00, -1.2015e+00],\n",
" [-8.5735e-02, -1.0579e+00, -8.9173e-01, ..., -9.6441e-01,\n",
" -1.1255e+00, -1.2599e+00],\n",
" [ 4.7654e-01, 3.2887e-01, -5.9201e-01, ..., -1.1942e+00,\n",
" -1.1430e+00, -1.0242e+00],\n",
" ...,\n",
" [-4.7431e-01, -3.3559e-01, -7.2326e-01, ..., -1.4506e+00,\n",
" -1.3957e+00, -1.0464e+00],\n",
" [ 3.6113e-01, 1.0381e-01, -1.1599e+00, ..., -1.0439e+00,\n",
" -1.0221e+00, -1.0208e+00],\n",
" [-1.2717e+00, -2.1460e+00, -7.5677e-01, ..., -9.7822e-01,\n",
" -9.3785e-01, -1.0371e+00]],\n",
"\n",
" [[-1.5465e+00, -1.0152e+00, -8.8901e-01, ..., -4.8522e-01,\n",
" -7.5163e-01, -6.7765e-01],\n",
" [-7.6101e-01, -7.3352e-01, -9.1588e-01, ..., -2.4836e-01,\n",
" -5.8927e-01, -7.3723e-01],\n",
" [-2.4714e-02, 1.7016e-01, -4.2326e-01, ..., -3.3204e-01,\n",
" -7.6696e-01, -7.1652e-01],\n",
" ...,\n",
" [-1.7032e+00, -1.2591e+00, -1.1449e+00, ..., -1.1810e+00,\n",
" -1.1163e+00, -9.3108e-01],\n",
" [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n",
" -3.9869e+00, -3.6797e+00],\n",
" [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n",
" -3.9869e+00, -3.6797e+00]],\n",
"\n",
" ...,\n",
"\n",
" [[ 6.4983e-01, 2.6117e-01, -8.4197e-01, ..., -8.7213e-01,\n",
" -1.1073e+00, -1.3253e+00],\n",
" [ 3.5391e-01, -1.5846e-02, -4.0425e-01, ..., -9.9173e-01,\n",
" -1.0727e+00, -1.1924e+00],\n",
" [ 3.7704e-01, -6.2785e-02, -1.1468e-01, ..., -1.1021e+00,\n",
" -1.0952e+00, -1.1182e+00],\n",
" ...,\n",
" [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n",
" -3.9869e+00, -3.6797e+00],\n",
" [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n",
" -3.9869e+00, -3.6797e+00],\n",
" [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n",
" -3.9869e+00, -3.6797e+00]],\n",
"\n",
" [[ 4.4458e-02, -1.7547e-01, -6.7475e-01, ..., -4.9801e-01,\n",
" -5.6783e-01, -7.7852e-01],\n",
" [-1.3428e+00, -8.0343e-01, -9.0457e-01, ..., -6.5902e-01,\n",
" -7.2550e-01, -6.2796e-01],\n",
" [-7.6253e-01, -1.3071e-01, -1.3280e-01, ..., -5.6133e-01,\n",
" -6.0588e-01, -7.2115e-01],\n",
" ...,\n",
" [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n",
" -3.9869e+00, -3.6797e+00],\n",
" [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n",
" -3.9869e+00, -3.6797e+00],\n",
" [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n",
" -3.9869e+00, -3.6797e+00]],\n",
"\n",
" [[-1.0798e+00, -1.0834e+00, -1.1797e+00, ..., -1.7757e-01,\n",
" -4.3747e-01, -4.0007e-02],\n",
" [ 9.2354e-01, 6.3771e-01, -5.2810e-01, ..., -1.2928e-01,\n",
" -2.0342e-01, 1.6656e-01],\n",
" [ 4.9337e-01, -9.1133e-03, -7.3302e-01, ..., 1.0074e-01,\n",
" -9.8115e-02, -9.2357e-03],\n",
" ...,\n",
" [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n",
" -3.9869e+00, -3.6797e+00],\n",
" [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n",
" -3.9869e+00, -3.6797e+00],\n",
" [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n",
" -3.9869e+00, -3.6797e+00]]], device='cuda:0')\n"
]
}
],
"source": [
"xs = model.encoder.global_cmvn(feat)\n",
"print(xs)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "505ca294",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[ True, True, True, ..., True, True, True]],\n",
"\n",
" [[ True, True, True, ..., True, True, True]],\n",
"\n",
" [[ True, True, True, ..., True, False, False]],\n",
"\n",
" ...,\n",
"\n",
" [[ True, True, True, ..., False, False, False]],\n",
"\n",
" [[ True, True, True, ..., False, False, False]],\n",
"\n",
" [[ True, True, True, ..., False, False, False]]], device='cuda:0')\n"
]
}
],
"source": [
"from wenet.utils.mask import make_pad_mask\n",
"masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n",
"print(masks)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "aa03c2b9",
"metadata": {},
"outputs": [],
"source": [
"xs, pos_emb, masks = model.encoder.embed(xs, masks)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "ebc0ea12",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[-0.5482, 2.2866, -1.0750, ..., 1.4504, 0.2895, -0.6945],\n",
" [-0.8013, 1.7688, -1.6639, ..., 1.8332, 0.6791, -0.2000],\n",
" [-1.7112, 2.7057, -1.3363, ..., 1.2336, 0.1870, -0.5735],\n",
" ...,\n",
" [-0.9697, 2.3129, -0.8752, ..., 0.8584, 0.4853, -0.4177],\n",
" [-1.3609, 2.1779, -1.7813, ..., 2.0928, 0.2528, -0.3650],\n",
" [-1.6967, 2.3544, -1.7417, ..., 1.3670, 0.5951, -0.7415]],\n",
"\n",
" [[-1.9828, 2.3178, -0.9079, ..., 0.4117, 0.5006, 0.0872],\n",
" [-0.7640, 1.3558, -1.3613, ..., 0.7317, 0.6784, 0.1685],\n",
" [-0.9504, 1.6038, -1.3030, ..., 0.5754, 0.2677, 0.3343],\n",
" ...,\n",
" [-1.4757, 2.5317, -1.2321, ..., 1.2997, 0.5019, -0.1034],\n",
" [-1.1731, 2.3172, -1.2542, ..., 1.7391, 0.2171, -0.4445],\n",
" [-1.2700, 3.2229, -0.8872, ..., 1.6461, 0.0973, -0.7679]],\n",
"\n",
" [[-0.5873, 1.4291, -1.3950, ..., 0.2102, 0.1027, 0.0918],\n",
" [ 0.1743, 1.7834, -1.6422, ..., 0.8113, 0.3137, 0.5634],\n",
" [-0.3492, 1.8310, -1.0685, ..., 0.6924, 0.1378, 0.4594],\n",
" ...,\n",
" [-1.0869, 2.3002, -1.2638, ..., 1.7998, 0.5134, -0.5223],\n",
" [-1.2614, 2.7240, -1.3734, ..., 1.4445, 0.5742, -0.3320],\n",
" [-2.2068, 4.3462, -3.8289, ..., 2.1426, 1.2034, -1.3795]],\n",
"\n",
" ...,\n",
"\n",
" [[-0.3914, 1.8553, -0.5747, ..., 1.0062, 0.4632, -1.0452],\n",
" [-0.8605, 2.0172, -1.4437, ..., 1.4526, 0.1657, 0.5923],\n",
" [-0.7307, 2.2841, -1.0699, ..., 1.5825, -0.0980, 0.5503],\n",
" ...,\n",
" [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n",
" [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n",
" [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]],\n",
"\n",
" [[-0.1619, 0.6255, -1.1323, ..., 0.0724, -0.2204, 0.4636],\n",
" [-0.0831, 0.5750, -1.0930, ..., 0.9110, -0.0650, 0.7299],\n",
" [-0.2820, 0.0801, -0.9418, ..., 0.3379, -0.1166, 0.4451],\n",
" ...,\n",
" [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n",
" [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n",
" [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]],\n",
"\n",
" [[-0.5458, -0.6909, -1.3597, ..., -0.7818, 0.6875, 0.9843],\n",
" [ 0.0421, -1.1062, -1.4389, ..., -0.0239, 0.9115, 0.5287],\n",
" [-0.2909, -0.1886, -1.5487, ..., -0.1392, 0.0580, 0.3066],\n",
" ...,\n",
" [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n",
" [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n",
" [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]]],\n",
" device='cuda:0', grad_fn=<MulBackward0>)\n",
"tensor([[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, ..., 1.0000e+00,\n",
" 0.0000e+00, 1.0000e+00],\n",
" [ 8.4147e-01, 5.4030e-01, 8.0196e-01, ..., 1.0000e+00,\n",
" 1.0746e-04, 1.0000e+00],\n",
" [ 9.0930e-01, -4.1615e-01, 9.5814e-01, ..., 1.0000e+00,\n",
" 2.1492e-04, 1.0000e+00],\n",
" ...,\n",
" [-7.6825e-01, -6.4014e-01, 6.3280e-01, ..., 9.9998e-01,\n",
" 5.1581e-03, 9.9999e-01],\n",
" [-9.5375e-01, 3.0059e-01, 9.9899e-01, ..., 9.9998e-01,\n",
" 5.2656e-03, 9.9999e-01],\n",
" [-2.6237e-01, 9.6497e-01, 5.6075e-01, ..., 9.9998e-01,\n",
" 5.3730e-03, 9.9999e-01]]], device='cuda:0')\n",
"tensor([[[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" False]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" False]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, False,\n",
" False]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, False, False, False,\n",
" False]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, False, False, False,\n",
" False]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, False, False, False,\n",
" False]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, False, False, False,\n",
" False]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, False, False, False, False, False,\n",
" False]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, False, False, False, False, False, False, False, False,\n",
" False]],\n",
"\n",
" [[ True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, False, False, False, False, False, False, False, False, False,\n",
" False]]], device='cuda:0')\n",
"torch.Size([16, 1, 51])\n"
]
}
],
"source": [
"print(xs)\n",
"print(pos_emb)\n",
"print(masks)\n",
"print(masks.shape)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "4289461b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[[-0.54822 2.2866027 -1.0750197 ... 1.4503604 0.28950194\n",
" -0.6945408 ]\n",
" [-0.8012542 1.7687558 -1.6638877 ... 1.833158 0.6791494\n",
" -0.1999542 ]\n",
" [-1.7112465 2.7057455 -1.3363413 ... 1.2336441 0.18697014\n",
" -0.5735198 ]\n",
" ...\n",
" [-0.96968573 2.312949 -0.87524825 ... 0.85838526 0.4853347\n",
" -0.41773027]\n",
" [-1.3609431 2.1778803 -1.7812773 ... 2.0927877 0.25282228\n",
" -0.36496443]\n",
" [-1.6967483 2.3543842 -1.7416853 ... 1.366951 0.59511113\n",
" -0.74147725]]\n",
"\n",
" [[-1.9828408 2.31777 -0.9078527 ... 0.41170627 0.5006162\n",
" 0.08721463]\n",
" [-0.76404583 1.3557773 -1.3612567 ... 0.7317046 0.678426\n",
" 0.16851945]\n",
" [-0.95044655 1.6037656 -1.3029968 ... 0.57544005 0.26769355\n",
" 0.33433008]\n",
" ...\n",
" [-1.475677 2.531713 -1.2320715 ... 1.2996731 0.50191855\n",
" -0.10343577]\n",
" [-1.1730809 2.3172235 -1.2542105 ... 1.7391105 0.21709818\n",
" -0.44447583]\n",
" [-1.2699623 3.2228963 -0.8871915 ... 1.6460502 0.09731755\n",
" -0.7678688 ]]\n",
"\n",
" [[-0.5872559 1.4290544 -1.3950099 ... 0.21024795 0.10272825\n",
" 0.09179455]\n",
" [ 0.1742807 1.783423 -1.6421788 ... 0.8112701 0.31371105\n",
" 0.56344515]\n",
" [-0.34916472 1.8310343 -1.0685117 ... 0.69243336 0.13782299\n",
" 0.45937473]\n",
" ...\n",
" [-1.0868638 2.300204 -1.2638408 ... 1.7998282 0.5133892\n",
" -0.52227837]\n",
" [-1.2614481 2.7239661 -1.3733778 ... 1.444533 0.57420933\n",
" -0.33201432]\n",
" [-2.2067683 4.346218 -3.828867 ... 2.1426017 1.2033664\n",
" -1.3795122 ]]\n",
"\n",
" ...\n",
"\n",
" [[-0.39141566 1.8553346 -0.5747178 ... 1.0062351 0.46320182\n",
" -1.045236 ]\n",
" [-0.86054784 2.0171793 -1.4436853 ... 1.452623 0.16571884\n",
" 0.5923172 ]\n",
" [-0.73066384 2.2840502 -1.0698992 ... 1.5824941 -0.0979555\n",
" 0.55030036]\n",
" ...\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]]\n",
"\n",
" [[-0.16194311 0.6255052 -1.1323429 ... 0.07242929 -0.22042468\n",
" 0.46362036]\n",
" [-0.08306468 0.575043 -1.09298 ... 0.9109665 -0.06501988\n",
" 0.72986233]\n",
" [-0.28202093 0.08014385 -0.9417719 ... 0.3379485 -0.11664233\n",
" 0.44514441]\n",
" ...\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]]\n",
"\n",
" [[-0.5458492 -0.69092435 -1.3596548 ... -0.78182435 0.68747747\n",
" 0.9842716 ]\n",
" [ 0.04212743 -1.1061852 -1.438915 ... -0.02385022 0.91146135\n",
" 0.52870303]\n",
" [-0.2909345 -0.18858244 -1.5487324 ... -0.13923697 0.05795169\n",
" 0.30663735]\n",
" ...\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]]]\n"
]
}
],
"source": [
"xs = model.encoder.global_cmvn(feat)\n",
"masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n",
"xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)\n",
"print(xs.cpu().detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "67e10d73",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 2.0908e-03],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 1.1943e-02, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 4.6105e-02, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 9.6723e-03,\n",
" 4.6135e-02, 0.0000e+00]],\n",
"\n",
" [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" ...,\n",
"\n",
" [[2.2816e-01, 2.4615e-01, 2.5304e-01, ..., 2.0402e-01,\n",
" 2.3248e-01, 3.1191e-01],\n",
" [1.3587e-01, 2.8877e-01, 2.7991e-01, ..., 1.9210e-01,\n",
" 2.0346e-01, 1.9934e-01],\n",
" [2.5739e-01, 3.9348e-01, 2.7877e-01, ..., 2.7483e-01,\n",
" 1.9302e-01, 2.3810e-01],\n",
" ...,\n",
" [1.1939e-01, 2.8473e-01, 3.3082e-01, ..., 2.3838e-01,\n",
" 2.2104e-01, 2.3906e-01],\n",
" [1.7388e-01, 2.0402e-01, 4.0263e-01, ..., 2.4782e-01,\n",
" 2.6742e-01, 1.5427e-01],\n",
" [0.0000e+00, 2.9081e-01, 2.7726e-01, ..., 1.7540e-01,\n",
" 1.8479e-01, 2.2483e-01]],\n",
"\n",
" [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[3.5447e-01, 3.8861e-01, 3.9724e-01, ..., 3.8680e-01,\n",
" 3.3568e-01, 3.4552e-01],\n",
" [4.1739e-01, 5.1039e-01, 4.1730e-01, ..., 3.3993e-01,\n",
" 3.7082e-01, 3.5110e-01],\n",
" [3.6117e-01, 4.0745e-01, 4.8491e-01, ..., 3.4849e-01,\n",
" 3.2321e-01, 3.5189e-01],\n",
" ...,\n",
" [2.3144e-01, 3.8021e-01, 5.1526e-01, ..., 3.6499e-01,\n",
" 3.7412e-01, 3.9986e-01],\n",
" [3.4679e-01, 4.0238e-01, 5.0077e-01, ..., 3.6185e-01,\n",
" 3.1597e-01, 3.6335e-01],\n",
" [3.6498e-01, 3.7943e-01, 5.1719e-01, ..., 3.1798e-01,\n",
" 3.3657e-01, 3.4130e-01]]],\n",
"\n",
"\n",
" [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[1.4560e-02, 9.4475e-02, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [1.5002e-02, 2.9632e-02, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [3.2952e-02, 0.0000e+00, 0.0000e+00, ..., 4.5850e-02,\n",
" 2.0439e-02, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 4.4258e-02],\n",
" [0.0000e+00, 0.0000e+00, 2.5565e-02, ..., 0.0000e+00,\n",
" 9.0044e-03, 4.9084e-02]],\n",
"\n",
" [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [1.1141e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" ...,\n",
"\n",
" [[3.3697e-01, 3.8527e-01, 3.2900e-01, ..., 2.8704e-01,\n",
" 2.3351e-01, 1.9004e-01],\n",
" [1.3575e-01, 3.5783e-01, 3.3573e-01, ..., 2.2082e-01,\n",
" 1.5855e-01, 1.3587e-01],\n",
" [2.1929e-01, 2.8900e-01, 2.8255e-01, ..., 2.0603e-01,\n",
" 2.3927e-01, 2.1909e-01],\n",
" ...,\n",
" [2.3292e-01, 3.9097e-01, 3.6399e-01, ..., 2.0598e-01,\n",
" 2.5374e-01, 2.3137e-01],\n",
" [1.8739e-01, 3.0794e-01, 3.0297e-01, ..., 2.7251e-01,\n",
" 2.5192e-01, 2.0837e-01],\n",
" [2.2454e-01, 4.1402e-01, 5.4083e-01, ..., 3.1875e-01,\n",
" 2.5080e-01, 2.5939e-01]],\n",
"\n",
" [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[2.6457e-01, 4.9519e-01, 5.6702e-01, ..., 3.0955e-01,\n",
" 3.5292e-01, 3.2669e-01],\n",
" [2.1577e-01, 5.1833e-01, 4.9183e-01, ..., 3.6043e-01,\n",
" 3.8524e-01, 3.6155e-01],\n",
" [2.0068e-01, 4.2784e-01, 5.2818e-01, ..., 3.1871e-01,\n",
" 3.2452e-01, 3.1036e-01],\n",
" ...,\n",
" [4.9855e-01, 5.1001e-01, 5.2279e-01, ..., 3.6450e-01,\n",
" 3.4338e-01, 3.3603e-01],\n",
" [4.1233e-01, 5.5518e-01, 5.2828e-01, ..., 4.0676e-01,\n",
" 3.3873e-01, 3.6724e-01],\n",
" [4.0820e-01, 4.6187e-01, 4.7338e-01, ..., 3.8691e-01,\n",
" 3.6039e-01, 3.8022e-01]]],\n",
"\n",
"\n",
" [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[0.0000e+00, 5.7852e-03, 0.0000e+00, ..., 7.4838e-03,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 3.0351e-02,\n",
" 0.0000e+00, 2.6720e-04],\n",
" [9.4807e-04, 0.0000e+00, 0.0000e+00, ..., 7.9551e-03,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [2.0326e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 1.0801e-02, 0.0000e+00],\n",
" [1.8470e-01, 0.0000e+00, 0.0000e+00, ..., 5.0584e-02,\n",
" 9.4758e-02, 5.9146e-02]],\n",
"\n",
" [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" ...,\n",
"\n",
" [[3.8708e-01, 2.8022e-01, 3.5893e-01, ..., 1.6595e-01,\n",
" 1.6031e-01, 2.1136e-01],\n",
" [1.5595e-01, 3.0544e-01, 2.4666e-01, ..., 2.2675e-01,\n",
" 2.5765e-01, 1.9682e-01],\n",
" [2.9518e-01, 4.1210e-01, 2.0063e-01, ..., 1.7595e-01,\n",
" 2.2537e-01, 2.2214e-01],\n",
" ...,\n",
" [2.4745e-01, 2.6259e-01, 3.8654e-01, ..., 2.3620e-01,\n",
" 2.3157e-01, 1.8514e-01],\n",
" [2.5715e-01, 2.9593e-01, 4.7745e-01, ..., 2.3546e-01,\n",
" 2.5073e-01, 2.0976e-01],\n",
" [1.2015e+00, 8.4644e-01, 7.3386e-01, ..., 1.0252e+00,\n",
" 9.5310e-01, 1.0013e+00]],\n",
"\n",
" [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[4.5013e-01, 4.7484e-01, 4.0540e-01, ..., 1.9346e-01,\n",
" 1.7826e-01, 1.4777e-01],\n",
" [4.7546e-01, 4.8187e-01, 3.6760e-01, ..., 2.7809e-01,\n",
" 3.2997e-01, 3.2337e-01],\n",
" [4.6160e-01, 4.0050e-01, 3.9061e-01, ..., 3.6613e-01,\n",
" 3.5243e-01, 2.9739e-01],\n",
" ...,\n",
" [5.5148e-01, 5.1018e-01, 4.0132e-01, ..., 3.8948e-01,\n",
" 3.5737e-01, 3.3088e-01],\n",
" [4.1973e-01, 4.5475e-01, 4.5320e-01, ..., 3.8343e-01,\n",
" 4.0126e-01, 3.6181e-01],\n",
" [3.4280e-01, 3.1606e-01, 4.4701e-01, ..., 2.1665e-01,\n",
" 2.3985e-01, 2.3903e-01]]],\n",
"\n",
"\n",
" ...,\n",
"\n",
"\n",
" [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [4.1783e-02, 0.0000e+00, 1.5805e-02, ..., 0.0000e+00,\n",
" 2.2508e-02, 0.0000e+00],\n",
" [4.3234e-02, 7.7864e-02, 0.0000e+00, ..., 1.6347e-02,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[3.2092e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [1.3563e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" ...,\n",
"\n",
" [[0.0000e+00, 2.5187e-01, 2.4979e-01, ..., 2.4775e-01,\n",
" 2.2354e-01, 1.9149e-01],\n",
" [1.6541e-01, 1.9586e-01, 1.9813e-01, ..., 2.7344e-01,\n",
" 2.0928e-01, 2.6150e-01],\n",
" [1.0495e-01, 6.3299e-02, 3.3844e-01, ..., 2.5138e-01,\n",
" 1.2470e-01, 2.3927e-01],\n",
" ...,\n",
" [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n",
" 1.0094e+00, 1.0221e+00],\n",
" [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n",
" 1.0094e+00, 1.0221e+00],\n",
" [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n",
" 1.0094e+00, 1.0221e+00]],\n",
"\n",
" [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[1.1428e-01, 4.5667e-01, 4.6821e-01, ..., 3.2058e-01,\n",
" 3.3579e-01, 3.9013e-01],\n",
" [1.0441e-01, 4.5739e-01, 4.6107e-01, ..., 3.8468e-01,\n",
" 3.8291e-01, 3.6686e-01],\n",
" [1.9868e-01, 3.5520e-01, 4.4313e-01, ..., 4.0679e-01,\n",
" 3.8068e-01, 3.0646e-01],\n",
" ...,\n",
" [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n",
" 1.2189e+00, 1.2380e+00],\n",
" [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n",
" 1.2189e+00, 1.2380e+00],\n",
" [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n",
" 1.2189e+00, 1.2380e+00]]],\n",
"\n",
"\n",
" [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[2.4654e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 3.3902e-02],\n",
" [0.0000e+00, 0.0000e+00, 1.8307e-02, ..., 5.1669e-02,\n",
" 9.4838e-03, 7.4535e-02],\n",
" [9.9215e-02, 0.0000e+00, 1.5872e-02, ..., 1.6203e-02,\n",
" 5.1401e-02, 1.9239e-03],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" ...,\n",
"\n",
" [[4.0034e-01, 2.5306e-01, 2.0218e-01, ..., 9.8162e-02,\n",
" 7.0643e-02, 4.9741e-02],\n",
" [1.2568e-01, 2.1031e-01, 1.1182e-01, ..., 4.2781e-02,\n",
" 1.1969e-01, 1.2005e-01],\n",
" [2.8787e-01, 2.4031e-01, 2.2566e-01, ..., 0.0000e+00,\n",
" 6.4181e-02, 5.8730e-02],\n",
" ...,\n",
" [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n",
" 1.0094e+00, 1.0221e+00],\n",
" [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n",
" 1.0094e+00, 1.0221e+00],\n",
" [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n",
" 1.0094e+00, 1.0221e+00]],\n",
"\n",
" [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[3.8405e-01, 3.0990e-01, 3.7156e-01, ..., 1.8125e-01,\n",
" 1.5051e-01, 1.9620e-01],\n",
" [4.7286e-01, 4.0529e-01, 3.9718e-01, ..., 2.4710e-01,\n",
" 4.5657e-02, 1.1501e-01],\n",
" [3.2621e-01, 3.0073e-01, 3.0477e-01, ..., 2.3529e-01,\n",
" 2.1357e-01, 1.6986e-01],\n",
" ...,\n",
" [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n",
" 1.2189e+00, 1.2380e+00],\n",
" [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n",
" 1.2189e+00, 1.2380e+00],\n",
" [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n",
" 1.2189e+00, 1.2380e+00]]],\n",
"\n",
"\n",
" [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[3.3438e-02, 1.2378e-03, 5.2972e-02, ..., 7.2712e-02,\n",
" 8.6563e-02, 1.4494e-01],\n",
" [1.1043e-01, 6.1431e-02, 6.3630e-02, ..., 8.1278e-02,\n",
" 6.2590e-02, 8.3154e-02],\n",
" [1.7677e-02, 2.0111e-03, 7.8750e-02, ..., 6.9633e-02,\n",
" 8.9799e-02, 5.3263e-02],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[1.0034e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [1.5627e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [5.1447e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 4.3641e-03],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" ...,\n",
"\n",
" [[2.5142e-01, 4.5964e-01, 3.7346e-01, ..., 4.7631e-02,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [1.9760e-01, 2.6627e-01, 1.1191e-01, ..., 3.0450e-02,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [1.6341e-01, 3.2938e-01, 2.5690e-01, ..., 5.5694e-02,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n",
" 1.0094e+00, 1.0221e+00],\n",
" [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n",
" 1.0094e+00, 1.0221e+00],\n",
" [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n",
" 1.0094e+00, 1.0221e+00]],\n",
"\n",
" [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 2.2189e-02, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 2.8490e-02],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
" 0.0000e+00, 0.0000e+00]],\n",
"\n",
" [[2.5810e-01, 6.3017e-01, 3.7038e-01, ..., 1.8704e-01,\n",
" 8.2694e-02, 9.9127e-02],\n",
" [1.7293e-01, 5.0679e-01, 4.0739e-01, ..., 1.6006e-01,\n",
" 1.1725e-01, 9.9405e-02],\n",
" [2.4175e-01, 4.1616e-01, 4.1257e-01, ..., 1.3520e-01,\n",
" 7.9126e-02, 1.2846e-01],\n",
" ...,\n",
" [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n",
" 1.2189e+00, 1.2380e+00],\n",
" [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n",
" 1.2189e+00, 1.2380e+00],\n",
" [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n",
" 1.2189e+00, 1.2380e+00]]]], device='cuda:0',\n",
" grad_fn=<ReluBackward0>)\n"
]
}
],
"source": [
"xs = model.encoder.global_cmvn(feat)\n",
"masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n",
"\n",
"x = xs.unsqueeze(1)\n",
"x = model.encoder.embed.conv(x)\n",
"print(x)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "9a9478ad",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[[-0.03426375 0.14291267 -0.06718873 ... 0.09064753 0.01809387\n",
" -0.0434088 ]\n",
" [-0.05007839 0.11054724 -0.10399298 ... 0.11457238 0.04244684\n",
" -0.01249714]\n",
" [-0.10695291 0.16910909 -0.08352133 ... 0.07710276 0.01168563\n",
" -0.03584499]\n",
" ...\n",
" [-0.06060536 0.14455931 -0.05470302 ... 0.05364908 0.03033342\n",
" -0.02610814]\n",
" [-0.08505894 0.13611752 -0.11132983 ... 0.13079923 0.01580139\n",
" -0.02281028]\n",
" [-0.10604677 0.14714901 -0.10885533 ... 0.08543444 0.03719445\n",
" -0.04634233]]\n",
"\n",
" [[-0.12392755 0.14486063 -0.05674079 ... 0.02573164 0.03128851\n",
" 0.00545091]\n",
" [-0.04775286 0.08473608 -0.08507854 ... 0.04573154 0.04240163\n",
" 0.01053247]\n",
" [-0.05940291 0.10023535 -0.0814373 ... 0.035965 0.01673085\n",
" 0.02089563]\n",
" ...\n",
" [-0.09222981 0.15823206 -0.07700447 ... 0.08122957 0.03136991\n",
" -0.00646474]\n",
" [-0.07331756 0.14482647 -0.07838815 ... 0.1086944 0.01356864\n",
" -0.02777974]\n",
" [-0.07937264 0.20143102 -0.05544947 ... 0.10287814 0.00608235\n",
" -0.0479918 ]]\n",
"\n",
" [[-0.03670349 0.0893159 -0.08718812 ... 0.0131405 0.00642052\n",
" 0.00573716]\n",
" [ 0.01089254 0.11146393 -0.10263617 ... 0.05070438 0.01960694\n",
" 0.03521532]\n",
" [-0.0218228 0.11443964 -0.06678198 ... 0.04327708 0.00861394\n",
" 0.02871092]\n",
" ...\n",
" [-0.06792898 0.14376275 -0.07899005 ... 0.11248926 0.03208683\n",
" -0.0326424 ]\n",
" [-0.07884051 0.17024788 -0.08583611 ... 0.09028331 0.03588808\n",
" -0.0207509 ]\n",
" [-0.13792302 0.27163863 -0.23930418 ... 0.13391261 0.0752104\n",
" -0.08621951]]\n",
"\n",
" ...\n",
"\n",
" [[-0.02446348 0.11595841 -0.03591986 ... 0.0628897 0.02895011\n",
" -0.06532725]\n",
" [-0.05378424 0.1260737 -0.09023033 ... 0.09078894 0.01035743\n",
" 0.03701983]\n",
" [-0.04566649 0.14275314 -0.0668687 ... 0.09890588 -0.00612222\n",
" 0.03439377]\n",
" ...\n",
" [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n",
" -0.18293698]\n",
" [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n",
" -0.18293698]\n",
" [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n",
" -0.18293698]]\n",
"\n",
" [[-0.01012144 0.03909408 -0.07077143 ... 0.00452683 -0.01377654\n",
" 0.02897627]\n",
" [-0.00519154 0.03594019 -0.06831125 ... 0.05693541 -0.00406374\n",
" 0.0456164 ]\n",
" [-0.01762631 0.00500899 -0.05886075 ... 0.02112178 -0.00729015\n",
" 0.02782153]\n",
" ...\n",
" [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n",
" -0.18293698]\n",
" [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n",
" -0.18293698]\n",
" [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n",
" -0.18293698]]\n",
"\n",
" [[-0.03411558 -0.04318277 -0.08497842 ... -0.04886402 0.04296734\n",
" 0.06151697]\n",
" [ 0.00263296 -0.06913657 -0.08993219 ... -0.00149064 0.05696633\n",
" 0.03304394]\n",
" [-0.01818341 -0.0117864 -0.09679577 ... -0.00870231 0.00362198\n",
" 0.01916483]\n",
" ...\n",
" [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n",
" -0.18293698]\n",
" [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n",
" -0.18293698]\n",
" [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n",
" -0.18293698]]]\n",
"torch.Size([16, 51, 256])\n"
]
}
],
"source": [
"b, c, t, f = x.size()\n",
"x = model.encoder.embed.out(x.transpose(1, 2).contiguous().view(b, t, c * f))\n",
"print(x.cpu().detach().numpy())\n",
"print(x.shape)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "fd69003f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[[-0.54822 2.2866027 -1.0750197 ... 1.4503604 0.28950194\n",
" -0.6945408 ]\n",
" [-0.8012542 1.7687558 -1.6638877 ... 1.833158 0.6791494\n",
" -0.1999542 ]\n",
" [-1.7112465 2.7057455 -1.3363413 ... 1.2336441 0.18697014\n",
" -0.5735198 ]\n",
" ...\n",
" [-0.96968573 2.312949 -0.87524825 ... 0.85838526 0.4853347\n",
" -0.41773027]\n",
" [-1.3609431 2.1778803 -1.7812773 ... 2.0927877 0.25282228\n",
" -0.36496443]\n",
" [-1.6967483 2.3543842 -1.7416853 ... 1.366951 0.59511113\n",
" -0.74147725]]\n",
"\n",
" [[-1.9828408 2.31777 -0.9078527 ... 0.41170627 0.5006162\n",
" 0.08721463]\n",
" [-0.76404583 1.3557773 -1.3612567 ... 0.7317046 0.678426\n",
" 0.16851945]\n",
" [-0.95044655 1.6037656 -1.3029968 ... 0.57544005 0.26769355\n",
" 0.33433008]\n",
" ...\n",
" [-1.475677 2.531713 -1.2320715 ... 1.2996731 0.50191855\n",
" -0.10343577]\n",
" [-1.1730809 2.3172235 -1.2542105 ... 1.7391105 0.21709818\n",
" -0.44447583]\n",
" [-1.2699623 3.2228963 -0.8871915 ... 1.6460502 0.09731755\n",
" -0.7678688 ]]\n",
"\n",
" [[-0.5872559 1.4290544 -1.3950099 ... 0.21024795 0.10272825\n",
" 0.09179455]\n",
" [ 0.1742807 1.783423 -1.6421788 ... 0.8112701 0.31371105\n",
" 0.56344515]\n",
" [-0.34916472 1.8310343 -1.0685117 ... 0.69243336 0.13782299\n",
" 0.45937473]\n",
" ...\n",
" [-1.0868638 2.300204 -1.2638408 ... 1.7998282 0.5133892\n",
" -0.52227837]\n",
" [-1.2614481 2.7239661 -1.3733778 ... 1.444533 0.57420933\n",
" -0.33201432]\n",
" [-2.2067683 4.346218 -3.828867 ... 2.1426017 1.2033664\n",
" -1.3795122 ]]\n",
"\n",
" ...\n",
"\n",
" [[-0.39141566 1.8553346 -0.5747178 ... 1.0062351 0.46320182\n",
" -1.045236 ]\n",
" [-0.86054784 2.0171793 -1.4436853 ... 1.452623 0.16571884\n",
" 0.5923172 ]\n",
" [-0.73066384 2.2840502 -1.0698992 ... 1.5824941 -0.0979555\n",
" 0.55030036]\n",
" ...\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]]\n",
"\n",
" [[-0.16194311 0.6255052 -1.1323429 ... 0.07242929 -0.22042468\n",
" 0.46362036]\n",
" [-0.08306468 0.575043 -1.09298 ... 0.9109665 -0.06501988\n",
" 0.72986233]\n",
" [-0.28202093 0.08014385 -0.9417719 ... 0.3379485 -0.11664233\n",
" 0.44514441]\n",
" ...\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]]\n",
"\n",
" [[-0.5458492 -0.69092435 -1.3596548 ... -0.78182435 0.68747747\n",
" 0.9842716 ]\n",
" [ 0.04212743 -1.1061852 -1.438915 ... -0.02385022 0.91146135\n",
" 0.52870303]\n",
" [-0.2909345 -0.18858244 -1.5487324 ... -0.13923697 0.05795169\n",
" 0.30663735]\n",
" ...\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]\n",
" [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n",
" -2.9269917 ]]]\n"
]
}
],
"source": [
"x, pos_emb = model.encoder.embed.pos_enc(x, 0)\n",
"print(x.cpu().detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "8ed88489",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.float32\n",
"[[[ 0.0000000e+00 1.0000000e+00 0.0000000e+00 ... 1.0000000e+00\n",
" 0.0000000e+00 1.0000000e+00]\n",
" [ 8.4147096e-01 5.4030234e-01 8.0196178e-01 ... 1.0000000e+00\n",
" 1.0746076e-04 1.0000000e+00]\n",
" [ 9.0929741e-01 -4.1614684e-01 9.5814437e-01 ... 1.0000000e+00\n",
" 2.1492151e-04 1.0000000e+00]\n",
" ...\n",
" [-7.6825464e-01 -6.4014435e-01 6.3279724e-01 ... 9.9998462e-01\n",
" 5.1580933e-03 9.9998671e-01]\n",
" [-9.5375264e-01 3.0059254e-01 9.9899054e-01 ... 9.9998397e-01\n",
" 5.2655530e-03 9.9998611e-01]\n",
" [-2.6237485e-01 9.6496606e-01 5.6074661e-01 ... 9.9998331e-01\n",
" 5.3730118e-03 9.9998558e-01]]]\n"
]
}
],
"source": [
"print(pos_emb.dtype)\n",
"print(pos_emb.cpu().detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "5e277881",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([16, 51, 256])\n"
]
},
{
"ename": "NameError",
"evalue": "name 'mask' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-54-bcc891cacca8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0mpos_emb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpos_emb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 143\u001b[0;31m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 144\u001b[0m \u001b[0mx_att\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 145\u001b[0m )\n",
"\u001b[0;31mNameError\u001b[0m: name 'mask' is not defined"
]
}
],
"source": [
"def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,\n",
" use_dynamic_chunk: bool,\n",
" use_dynamic_left_chunk: bool,\n",
" decoding_chunk_size: int, static_chunk_size: int,\n",
" num_decoding_left_chunks: int):\n",
" \"\"\" Apply optional mask for encoder.\n",
" Args:\n",
" xs (torch.Tensor): padded input, (B, L, D), L for max length\n",
" mask (torch.Tensor): mask for xs, (B, 1, L)\n",
" use_dynamic_chunk (bool): whether to use dynamic chunk or not\n",
" use_dynamic_left_chunk (bool): whether to use dynamic left chunk for\n",
" training.\n",
" decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's\n",
" 0: default for training, use random dynamic chunk.\n",
" <0: for decoding, use full chunk.\n",
" >0: for decoding, use fixed chunk size as set.\n",
" static_chunk_size (int): chunk size for static chunk training/decoding\n",
" if it's greater than 0, if use_dynamic_chunk is true,\n",
" this parameter will be ignored\n",
" num_decoding_left_chunks: number of left chunks, this is for decoding,\n",
" the chunk size is decoding_chunk_size.\n",
" >=0: use num_decoding_left_chunks\n",
" <0: use all left chunks\n",
" Returns:\n",
" torch.Tensor: chunk mask of the input xs.\n",
" \"\"\"\n",
" # Whether to use chunk mask or not\n",
" if use_dynamic_chunk:\n",
" max_len = xs.size(1)\n",
" if decoding_chunk_size < 0:\n",
" chunk_size = max_len\n",
" num_left_chunks = -1\n",
" elif decoding_chunk_size > 0:\n",
" chunk_size = decoding_chunk_size\n",
" num_left_chunks = num_decoding_left_chunks\n",
" else:\n",
" # chunk size is either [1, 25] or full context(max_len).\n",
" # Since we use 4 times subsampling and allow up to 1s(100 frames)\n",
" # delay, the maximum frame is 100 / 4 = 25.\n",
" chunk_size = torch.randint(1, max_len, (1, )).item()\n",
" num_left_chunks = -1\n",
" if chunk_size > max_len // 2:\n",
" chunk_size = max_len\n",
" else:\n",
" chunk_size = chunk_size % 25 + 1\n",
" if use_dynamic_left_chunk:\n",
" max_left_chunks = (max_len - 1) // chunk_size\n",
" num_left_chunks = torch.randint(0, max_left_chunks,\n",
" (1, )).item()\n",
" chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,\n",
" num_left_chunks,\n",
" xs.device) # (L, L)\n",
" chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)\n",
" chunk_masks = masks & chunk_masks # (B, L, L)\n",
" elif static_chunk_size > 0:\n",
" num_left_chunks = num_decoding_left_chunks\n",
" chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,\n",
" num_left_chunks,\n",
" xs.device) # (L, L)\n",
" chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)\n",
" chunk_masks = masks & chunk_masks # (B, L, L)\n",
" else:\n",
" chunk_masks = masks\n",
" return chunk_masks\n",
"\n",
"from wenet.utils.mask import make_pad_mask\n",
"\n",
"\n",
"masks = ~make_pad_mask(feat_len).unsqueeze(1)\n",
"xs = model.encoder.global_cmvn(feat)\n",
"xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)\n",
"\n",
"mask_pad = masks\n",
"decoding_chunk_size=0\n",
"num_decoding_left_chunks=-1\n",
"use_dynamic_left_chunk=-1\n",
"use_dynamic_chunk=False\n",
"static_chunk_size=-1\n",
"chunk_masks = add_optional_chunk_mask(\n",
" xs, \n",
" masks, \n",
" use_dynamic_chunk,\n",
" use_dynamic_left_chunk,\n",
" decoding_chunk_size, \n",
" static_chunk_size,\n",
" num_decoding_left_chunks)\n",
"\n",
"np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_embed', \n",
" embed_out=xs.cpu().detach().numpy(), \n",
" pos_emb=pos_emb.cpu().detach().numpy(),\n",
" chunk_masks=chunk_masks.cpu().detach().numpy(),\n",
" mask_pad=mask_pad.cpu().detach().numpy())\n",
"\n",
"model.eval()\n",
"# print(chunk_masks)\n",
"print(xs.shape)\n",
"for layer in model.encoder.encoders:\n",
" #xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n",
" #np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0', enc_0=xs.cpu().detach().numpy())\n",
" \n",
" x = xs\n",
" residual = x\n",
" x_norm = layer.norm_ff_macaron(x)\n",
" !rm /workspace/DeepSpeech-2.x/.notebook/enc_0_norm_ff.npz\n",
" np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_norm_ff', \n",
" norm_ff=x_norm.cpu().detach().numpy(),\n",
" xs=xs.cpu().detach().numpy())\n",
" #print(x.cpu().detach().numpy())\n",
" for p in layer.norm_ff_macaron.parameters():\n",
" #print(p, p.sum())\n",
" pass\n",
" \n",
" x = residual + layer.ff_scale * layer.feed_forward_macaron(x_norm)\n",
" \n",
" ps = []\n",
" for n, p in layer.feed_forward_macaron.state_dict().items():\n",
" #print(n, p.cpu().data.numpy())\n",
" ps.append(p.cpu().data.numpy())\n",
" pass\n",
"\n",
" ff_l_x = layer.feed_forward_macaron.w_1(x_norm)\n",
" ff_l_a_x = layer.feed_forward_macaron.activation(ff_l_x)\n",
" ff_l_a_l_x = layer.feed_forward_macaron.w_2(ff_l_a_x)\n",
" np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_ff_out', \n",
" norm_ff=x_norm.cpu().detach().numpy(),\n",
" ff_out=x.cpu().detach().numpy(),\n",
" ff_l_x = ff_l_x.cpu().detach().numpy(),\n",
" ff_l_a_x=ff_l_a_x.cpu().detach().numpy(),\n",
" ff_l_a_l_x=ff_l_a_l_x.cpu().detach().numpy(),\n",
" ps=ps,\n",
" )\n",
" \n",
" \n",
" residual = x\n",
" x = layer.norm_mha(x)\n",
" x_q = x\n",
" \n",
" x_att = layer.self_attn(x_q, x, x, pos_emb, masks)\n",
" np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_selattn_out', \n",
" x_q=x_q.cpu().detach().numpy(),\n",
" x=x.cpu().detach().numpy(),\n",
" pos_emb = pos_emb.cpu().detach().numpy(),\n",
" mask=mask.cpu().detach().numpy(),\n",
" x_att=x_att.cpu().detach().numpy(),\n",
" )\n",
" \n",
" break\n",
"#print(xs.cpu().detach().numpy())\n",
"\n",
"\n",
"i = 0\n",
"for layer in model.encoder.encoders:\n",
" xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n",
" i += 1\n",
" if i == 2:\n",
" np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_2', enc_2=xs.cpu().detach().numpy())\n",
" \n",
"np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_all', enc_all=xs.cpu().detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c43fd4f1",
"metadata": {},
"outputs": [],
"source": [
"out, mask = model.encoder(feat, feat_len)\n",
"#print(out.cpu().detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e73db22",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f506114",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}