{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "cfb832c0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/workspace/wenet\n" ] }, { "data": { "text/plain": [ "'/workspace/wenet'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%cd /workspace/wenet/\n", "%pwd" ] }, { "cell_type": "code", "execution_count": 2, "id": "62277538", "metadata": {}, "outputs": [], "source": [ "\n", "import argparse\n", "import copy\n", "import logging\n", "import os\n", "\n", "import torch\n", "import torch.distributed as dist\n", "import torch.optim as optim\n", "import yaml\n", "from tensorboardX import SummaryWriter\n", "from torch.utils.data import DataLoader\n", "\n", "from wenet.dataset.dataset import AudioDataset, CollateFunc\n", "from wenet.transformer.asr_model import init_asr_model\n", "from wenet.utils.checkpoint import load_checkpoint, save_checkpoint\n", "from wenet.utils.executor import Executor\n", "from wenet.utils.scheduler import WarmupLR\n", "\n", "os.environ['CUDA_VISIBLE_DEVICES'] = \"0\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "2f6ea33a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'config': 'examples/aishell/s0/conf/train_conformer.yaml', 'train_data': 'examples/aishell/s0/raw_wav/train/format.data', 'cv_data': 'examples/aishell/s0/raw_wav/dev/format.data', 'gpu': -1, 'model_dir': None, 'checkpoint': None, 'tensorboard_dir': 'tensorboard', 'rank': 0, 'world_size': -1, 'dist_backend': 'nccl', 'init_method': None, 'num_workers': 0, 'pin_memory': False, 'cmvn': 'examples/aishell/s0/raw_wav/train/global_cmvn'}\n" ] } ], "source": [ "parser = argparse.ArgumentParser(description='training your network')\n", "parser.add_argument('--config', default=\"examples/aishell/s0/conf/train_conformer.yaml\", help='config file')\n", "parser.add_argument('--train_data', default=\"examples/aishell/s0/raw_wav/train/format.data\", 
help='train data file')\n", "parser.add_argument('--cv_data', default=\"examples/aishell/s0/raw_wav/dev/format.data\", help='cv data file')\n", "parser.add_argument('--gpu',\n", " type=int,\n", " default=-1,\n", " help='gpu id for this local rank, -1 for cpu')\n", "parser.add_argument('--model_dir' , help='save model dir')\n", "parser.add_argument('--checkpoint', help='checkpoint model')\n", "parser.add_argument('--tensorboard_dir',\n", " default='tensorboard',\n", " help='tensorboard log dir')\n", "parser.add_argument('--ddp.rank',\n", " dest='rank',\n", " default=0,\n", " type=int,\n", " help='global rank for distributed training')\n", "parser.add_argument('--ddp.world_size',\n", " dest='world_size',\n", " default=-1,\n", " type=int,\n", " help='''number of total processes/gpus for\n", " distributed training''')\n", "parser.add_argument('--ddp.dist_backend',\n", " dest='dist_backend',\n", " default='nccl',\n", " choices=['nccl', 'gloo'],\n", " help='distributed backend')\n", "parser.add_argument('--ddp.init_method',\n", " dest='init_method',\n", " default=None,\n", " help='ddp init method')\n", "parser.add_argument('--num_workers',\n", " default=0,\n", " type=int,\n", " help='num of subprocess workers for reading')\n", "parser.add_argument('--pin_memory',\n", " action='store_true',\n", " default=False,\n", " help='Use pinned memory buffers used for reading')\n", "parser.add_argument('--cmvn', default=\"examples/aishell/s0/raw_wav/train/global_cmvn\", help='global cmvn file')\n", "\n", "args = parser.parse_args([])\n", "print(vars(args))" ] }, { "cell_type": "code", "execution_count": 4, "id": "f5d6af9b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Namespace(checkpoint=None, cmvn='examples/aishell/s0/raw_wav/train/global_cmvn', config='examples/aishell/s0/conf/train_conformer.yaml', cv_data='examples/aishell/s0/raw_wav/dev/format.data', dist_backend='nccl', gpu=-1, init_method=None, model_dir=None, num_workers=0, 
# %% Reproducibility and config loading.
# Fix the torch seed so model init and any sampling are repeatable, then
# load the conformer training config (yaml) selected by --config.
# Set random seed
torch.manual_seed(777)
print(args)
with open(args.config, 'r') as fin:
    configs = yaml.load(fin, Loader=yaml.FullLoader)

# %% Collate functions and datasets.
# raw_wav selects the on-the-fly feature pipeline (fbank from wav) vs
# precomputed features; it controls which augmentation knobs exist below.
raw_wav = configs['raw_wav']

train_collate_func = CollateFunc(**configs['collate_conf'],
                                 raw_wav=raw_wav)

# The cv collate conf is a deep copy of the training conf with every
# augmentation switched off, so validation is deterministic.
cv_collate_conf = copy.deepcopy(configs['collate_conf'])
# no augmenation on cv set
cv_collate_conf['spec_aug'] = False
cv_collate_conf['spec_sub'] = False
if raw_wav:
    # These knobs only exist in the raw-wav pipeline.
    cv_collate_conf['feature_dither'] = 0.0
    cv_collate_conf['speed_perturb'] = False
    cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0
cv_collate_func = CollateFunc(**cv_collate_conf, raw_wav=raw_wav)

dataset_conf = configs.get('dataset_conf', {})
train_dataset = AudioDataset(args.train_data,
                             **dataset_conf,
                             raw_wav=raw_wav)
cv_dataset = AudioDataset(args.cv_data, **dataset_conf, raw_wav=raw_wav)
# len(dataset) counts pre-formed batches, not utterances
# (120098 utterances in data/train/wav.scp -> 7507 batches here).
# 120098 data/train/wav.scp
print(len(train_dataset), 'batches')
# 14326 data/dev/wav.scp
print(len(cv_dataset))

# %% Data loaders.
# Single-process debugging setup: no DistributedSampler, batch_size=1
# (AudioDataset already yields pre-formed batches) and shuffle disabled
# so notebook runs are deterministic.
train_sampler = None
cv_sampler = None
train_data_loader = DataLoader(train_dataset,
                               collate_fn=train_collate_func,
                               sampler=train_sampler,
                               # train.py uses shuffle=(train_sampler is None);
                               # disabled here for a deterministic walkthrough.
                               shuffle=False,
                               pin_memory=args.pin_memory,
                               batch_size=1,
                               num_workers=args.num_workers)
cv_data_loader = DataLoader(cv_dataset,
                            collate_fn=cv_collate_func,
                            sampler=cv_sampler,
                            shuffle=False,
                            batch_size=1,
                            pin_memory=args.pin_memory,
                            num_workers=args.num_workers)
print(len(cv_data_loader))

# %% Model input/output dimensions.
if raw_wav:
    # Raw-wav pipeline: feature dim is the fbank mel-bin count from the
    # collate conf (80 here).
    input_dim = configs['collate_conf']['feature_extraction_conf'][
        'mel_bins']
else:
    input_dim = train_dataset.input_dim
vocab_size = train_dataset.output_dim
print(vocab_size, 'vocab')
print(input_dim , 'feat dim')

# %% Record derived settings into the config dict so it can later be
# saved as model_dir/train.yaml for inference and export.
# Save configs to model_dir/train.yaml for inference and export
configs['input_dim'] = input_dim
configs['output_dim'] = vocab_size
configs['cmvn_file'] = args.cmvn
configs['is_json_cmvn'] = raw_wav
print(args.cmvn)
# %% CMVN loading utilities: compute per-dimension mean and inverse
# standard deviation (istd) from stats stored in json / kaldi-text / npz.
import json
import math

import numpy as np


def _load_json_cmvn(json_cmvn_file):
    """ Load the json format cmvn stats file and calculate cmvn

    Args:
        json_cmvn_file: cmvn stats file in json format, with keys
            'mean_stat' (sum of features), 'var_stat' (sum of squares)
            and 'frame_num' (frame count)

    Returns:
        a numpy array of shape (2, D): row 0 is the mean, row 1 the
        inverse standard deviation
    """
    with open(json_cmvn_file) as f:
        cmvn_stats = json.load(f)

    means = cmvn_stats['mean_stat']
    variance = cmvn_stats['var_stat']
    count = cmvn_stats['frame_num']
    for i in range(len(means)):
        means[i] /= count
        # E[x^2] - E[x]^2, floored to avoid dividing by ~0 below.
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def _load_kaldi_cmvn(kaldi_cmvn_file):
    """ Load the kaldi format cmvn stats file and calculate cmvn

    Args:
        kaldi_cmvn_file: kaldi text style global cmvn file, which
            is generated by:
                compute-cmvn-stats --binary=false scp:feats.scp global_cmvn

    Returns:
        a numpy array of shape (2, D): row 0 is the mean, row 1 the
        inverse standard deviation

    Raises:
        ValueError: if the file is a kaldi *binary* archive.
    """
    means = []
    variance = []
    with open(kaldi_cmvn_file, 'r') as fid:
        # kaldi binary files start with '\0B'; only text format is
        # supported.  (The original error path referenced an undefined
        # `logger` and an un-imported `sys`, so it crashed with NameError
        # instead of reporting the problem — raise a real error instead.)
        if fid.read(2) == '\0B':
            raise ValueError(
                'kaldi cmvn binary file is not supported, please '
                'recompute it by: compute-cmvn-stats --binary=false '
                ' scp:feats.scp global_cmvn')
        fid.seek(0)
        # Layout: '[' m_1 ... m_D count v_1 ... v_D 0 ']'
        arr = fid.read().split()
        assert (arr[0] == '[')
        assert (arr[-2] == '0')
        assert (arr[-1] == ']')
        feat_dim = int((len(arr) - 2 - 2) / 2)
        for i in range(1, feat_dim + 1):
            means.append(float(arr[i]))
        count = float(arr[feat_dim + 1])
        for i in range(feat_dim + 2, 2 * feat_dim + 2):
            variance.append(float(arr[i]))

    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def _load_npz_cmvn(npz_cmvn_file, eps=1e-20):
    """Load cmvn from an npz file with 'mean' and 'std' arrays of
    shape (1, D); std is clipped to `eps` before inversion."""
    npzfile = np.load(npz_cmvn_file)
    means = npzfile["mean"]  # (1, D)
    std = npzfile["std"]  # (1, D)
    std = np.clip(std, eps, None)
    variance = 1.0 / std
    cmvn = np.array([means, variance])
    return cmvn


def load_cmvn(cmvn_file: str, filetype: str):
    """load cmvn from file.

    Args:
        cmvn_file (str): cmvn path.
        filetype (str): file type, optional[npz, json, kaldi],
            case-insensitive.

    Raises:
        ValueError: file type not support.

    Returns:
        Tuple[np.ndarray, np.ndarray]: mean, istd
    """
    # Normalize case *before* validating: the original asserted membership
    # first, which rejected e.g. 'JSON' and made the ValueError branch
    # unreachable.
    filetype = filetype.lower()
    if filetype == "json":
        cmvn = _load_json_cmvn(cmvn_file)
    elif filetype == "kaldi":
        cmvn = _load_kaldi_cmvn(cmvn_file)
    elif filetype == "npz":
        cmvn = _load_npz_cmvn(cmvn_file)
    else:
        raise ValueError(f"cmvn file type no support: {filetype}")
    return cmvn[0], cmvn[1]


# Notebook demo: inspect the aishell global cmvn (json format).  Guarded
# so importing this code elsewhere does not require the notebook-global
# `args`; in a notebook cell __name__ is '__main__', so it still runs.
if __name__ == '__main__':
    mean, istd = load_cmvn(args.cmvn, 'json')
    print(mean.shape)
    print(istd.shape)
    print(mean)
    print(istd)
PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (1): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): 
Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (2): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, 
bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (3): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " 
(dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (4): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " 
(w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (5): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, 
out_features=256, bias=True)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (6): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " 
(feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (7): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (feed_forward_macaron): 
PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (8): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): 
Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (9): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, 
bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (10): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " 
(dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (11): ConformerEncoderLayer(\n", " (self_attn): RelPositionMultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (feed_forward_macaron): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): Swish()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", 
" (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (conv_module): ConvolutionModule(\n", " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", " (activation): Swish()\n", " )\n", " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " )\n", " )\n", " (decoder): TransformerDecoder(\n", " (embed): Sequential(\n", " (0): Embedding(4233, 256)\n", " (1): PositionalEncoding(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (after_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (output_layer): Linear(in_features=256, out_features=4233, bias=True)\n", " (decoders): ModuleList(\n", " (0): DecoderLayer(\n", " (self_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " (src_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): 
Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): ReLU()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (1): DecoderLayer(\n", " (self_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " (src_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): ReLU()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm2): 
LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (2): DecoderLayer(\n", " (self_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " (src_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): ReLU()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (3): DecoderLayer(\n", " (self_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, 
bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " (src_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): ReLU()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (4): DecoderLayer(\n", " (self_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " (src_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, 
bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): ReLU()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", " )\n", " (5): DecoderLayer(\n", " (self_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " (src_attn): MultiHeadedAttention(\n", " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " (feed_forward): PositionwiseFeedForward(\n", " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", " (activation): ReLU()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", " )\n", " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", " (norm3): LayerNorm((256,), eps=1e-12, 
def summary(layer, print_func=print):
    """Print every entry of ``layer.state_dict()`` with its shape and element count.

    Args:
        layer: any object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``);
            includes buffers such as BatchNorm running stats, unlike
            ``print_params`` below.
        print_func: callable that receives each formatted line (default:
            ``print``). Pass ``None`` to suppress all output.
    """
    num_params = num_elements = 0
    for name, param in layer.state_dict().items():
        # Tensor.numel() returns an int element count. This replaces
        # np.prod(param.shape): numpy was never imported in this notebook
        # (NameError on a fresh kernel), and np.prod of an empty torch.Size
        # yielded the float 1.0 for scalar entries like num_batches_tracked.
        n = param.numel()
        if print_func:
            print_func("{} | {} | {}".format(name, param.shape, n))
        num_elements += n
        num_params += 1
    if print_func:
        print_func(
            f"Total parameters: {num_params}, {num_elements} elements."
        )


def print_params(model, print_func=print):
    """Print every registered parameter of ``model`` with shape, size and
    ``requires_grad`` flag, followed by a one-line total.

    Args:
        model: any object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module``). Buffers are NOT included (contrast with
            ``summary``).
        print_func: callable that receives each formatted line (default:
            ``print``). Pass ``None`` to do nothing.
    """
    if print_func is None:
        # Early exit makes the repeated `if print_func:` guards of the
        # original redundant; nothing below runs without a sink.
        return
    # Integer accumulators: the original used 0.0 floats, which printed
    # totals like "2.0" for what are discrete counts.
    total = 0
    num_params = 0
    for n, p in model.named_parameters():
        size = p.numel()  # int; avoids the undefined `np` (see summary above)
        print_func(f"{n} | {p.shape} | {size} | {p.requires_grad}")
        total += size
        num_params += 1
    print_func(f"Total parameters: {num_params}, {total} elements.")
3]) | 589824\n", "encoder.embed.conv.2.bias | torch.Size([256]) | 256\n", "encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184\n", "encoder.embed.out.0.bias | torch.Size([256]) | 256\n", "encoder.after_norm.weight | torch.Size([256]) | 256\n", "encoder.after_norm.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", "encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", "encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.0.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.0.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.0.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", "encoder.encoders.0.conv_module.pointwise_conv1.bias | torch.Size([512]) | 
512\n", "encoder.encoders.0.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", "encoder.encoders.0.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.conv_module.norm.weight | torch.Size([256]) | 256\n", "encoder.encoders.0.conv_module.norm.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.conv_module.norm.running_mean | torch.Size([256]) | 256\n", "encoder.encoders.0.conv_module.norm.running_var | torch.Size([256]) | 256\n", "encoder.encoders.0.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", "encoder.encoders.0.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", "encoder.encoders.0.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.norm_ff.weight | torch.Size([256]) | 256\n", "encoder.encoders.0.norm_ff.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.norm_mha.weight | torch.Size([256]) | 256\n", "encoder.encoders.0.norm_mha.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.norm_ff_macaron.weight | torch.Size([256]) | 256\n", "encoder.encoders.0.norm_ff_macaron.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.norm_conv.weight | torch.Size([256]) | 256\n", "encoder.encoders.0.norm_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.norm_final.weight | torch.Size([256]) | 256\n", "encoder.encoders.0.norm_final.bias | torch.Size([256]) | 256\n", "encoder.encoders.0.concat_linear.weight | torch.Size([256, 512]) | 131072\n", "encoder.encoders.0.concat_linear.bias | torch.Size([256]) | 256\n", "encoder.encoders.1.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", "encoder.encoders.1.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", "encoder.encoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "encoder.encoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.1.self_attn.linear_k.bias | 
torch.Size([256]) | 256\n", "encoder.encoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "encoder.encoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "encoder.encoders.1.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.1.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.1.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.1.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.1.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.1.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", "encoder.encoders.1.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", "encoder.encoders.1.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", "encoder.encoders.1.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.1.conv_module.norm.weight | torch.Size([256]) | 256\n", "encoder.encoders.1.conv_module.norm.bias | torch.Size([256]) | 256\n", "encoder.encoders.1.conv_module.norm.running_mean | torch.Size([256]) | 256\n", "encoder.encoders.1.conv_module.norm.running_var | torch.Size([256]) | 256\n", "encoder.encoders.1.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", "encoder.encoders.1.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", "encoder.encoders.1.conv_module.pointwise_conv2.bias | torch.Size([256]) | 
256\n", "encoder.encoders.1.norm_ff.weight | torch.Size([256]) | 256\n", "encoder.encoders.1.norm_ff.bias | torch.Size([256]) | 256\n", "encoder.encoders.1.norm_mha.weight | torch.Size([256]) | 256\n", "encoder.encoders.1.norm_mha.bias | torch.Size([256]) | 256\n", "encoder.encoders.1.norm_ff_macaron.weight | torch.Size([256]) | 256\n", "encoder.encoders.1.norm_ff_macaron.bias | torch.Size([256]) | 256\n", "encoder.encoders.1.norm_conv.weight | torch.Size([256]) | 256\n", "encoder.encoders.1.norm_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.1.norm_final.weight | torch.Size([256]) | 256\n", "encoder.encoders.1.norm_final.bias | torch.Size([256]) | 256\n", "encoder.encoders.1.concat_linear.weight | torch.Size([256, 512]) | 131072\n", "encoder.encoders.1.concat_linear.bias | torch.Size([256]) | 256\n", "encoder.encoders.2.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", "encoder.encoders.2.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", "encoder.encoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "encoder.encoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "encoder.encoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "encoder.encoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "encoder.encoders.2.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256\n", 
"encoder.encoders.2.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.2.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.2.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.2.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.2.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", "encoder.encoders.2.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", "encoder.encoders.2.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", "encoder.encoders.2.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.2.conv_module.norm.weight | torch.Size([256]) | 256\n", "encoder.encoders.2.conv_module.norm.bias | torch.Size([256]) | 256\n", "encoder.encoders.2.conv_module.norm.running_mean | torch.Size([256]) | 256\n", "encoder.encoders.2.conv_module.norm.running_var | torch.Size([256]) | 256\n", "encoder.encoders.2.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", "encoder.encoders.2.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", "encoder.encoders.2.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", "encoder.encoders.2.norm_ff.weight | torch.Size([256]) | 256\n", "encoder.encoders.2.norm_ff.bias | torch.Size([256]) | 256\n", "encoder.encoders.2.norm_mha.weight | torch.Size([256]) | 256\n", "encoder.encoders.2.norm_mha.bias | torch.Size([256]) | 256\n", "encoder.encoders.2.norm_ff_macaron.weight | torch.Size([256]) | 256\n", "encoder.encoders.2.norm_ff_macaron.bias | torch.Size([256]) | 256\n", "encoder.encoders.2.norm_conv.weight | torch.Size([256]) | 256\n", "encoder.encoders.2.norm_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.2.norm_final.weight | torch.Size([256]) | 256\n", "encoder.encoders.2.norm_final.bias | torch.Size([256]) | 256\n", "encoder.encoders.2.concat_linear.weight | torch.Size([256, 512]) | 
131072\n", "encoder.encoders.2.concat_linear.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", "encoder.encoders.3.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", "encoder.encoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.3.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.3.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.3.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", "encoder.encoders.3.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", "encoder.encoders.3.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", "encoder.encoders.3.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.conv_module.norm.weight | 
torch.Size([256]) | 256\n", "encoder.encoders.3.conv_module.norm.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.conv_module.norm.running_mean | torch.Size([256]) | 256\n", "encoder.encoders.3.conv_module.norm.running_var | torch.Size([256]) | 256\n", "encoder.encoders.3.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", "encoder.encoders.3.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", "encoder.encoders.3.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.norm_ff.weight | torch.Size([256]) | 256\n", "encoder.encoders.3.norm_ff.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.norm_mha.weight | torch.Size([256]) | 256\n", "encoder.encoders.3.norm_mha.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.norm_ff_macaron.weight | torch.Size([256]) | 256\n", "encoder.encoders.3.norm_ff_macaron.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.norm_conv.weight | torch.Size([256]) | 256\n", "encoder.encoders.3.norm_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.norm_final.weight | torch.Size([256]) | 256\n", "encoder.encoders.3.norm_final.bias | torch.Size([256]) | 256\n", "encoder.encoders.3.concat_linear.weight | torch.Size([256, 512]) | 131072\n", "encoder.encoders.3.concat_linear.bias | torch.Size([256]) | 256\n", "encoder.encoders.4.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", "encoder.encoders.4.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", "encoder.encoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "encoder.encoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "encoder.encoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256\n", 
"encoder.encoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "encoder.encoders.4.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.4.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.4.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.4.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.4.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.4.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", "encoder.encoders.4.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", "encoder.encoders.4.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", "encoder.encoders.4.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.4.conv_module.norm.weight | torch.Size([256]) | 256\n", "encoder.encoders.4.conv_module.norm.bias | torch.Size([256]) | 256\n", "encoder.encoders.4.conv_module.norm.running_mean | torch.Size([256]) | 256\n", "encoder.encoders.4.conv_module.norm.running_var | torch.Size([256]) | 256\n", "encoder.encoders.4.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", "encoder.encoders.4.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", "encoder.encoders.4.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", "encoder.encoders.4.norm_ff.weight | torch.Size([256]) | 256\n", "encoder.encoders.4.norm_ff.bias | torch.Size([256]) | 256\n", "encoder.encoders.4.norm_mha.weight | 
torch.Size([256]) | 256\n", "encoder.encoders.4.norm_mha.bias | torch.Size([256]) | 256\n", "encoder.encoders.4.norm_ff_macaron.weight | torch.Size([256]) | 256\n", "encoder.encoders.4.norm_ff_macaron.bias | torch.Size([256]) | 256\n", "encoder.encoders.4.norm_conv.weight | torch.Size([256]) | 256\n", "encoder.encoders.4.norm_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.4.norm_final.weight | torch.Size([256]) | 256\n", "encoder.encoders.4.norm_final.bias | torch.Size([256]) | 256\n", "encoder.encoders.4.concat_linear.weight | torch.Size([256, 512]) | 131072\n", "encoder.encoders.4.concat_linear.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", "encoder.encoders.5.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", "encoder.encoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.5.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", 
"encoder.encoders.5.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.5.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", "encoder.encoders.5.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", "encoder.encoders.5.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", "encoder.encoders.5.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.conv_module.norm.weight | torch.Size([256]) | 256\n", "encoder.encoders.5.conv_module.norm.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.conv_module.norm.running_mean | torch.Size([256]) | 256\n", "encoder.encoders.5.conv_module.norm.running_var | torch.Size([256]) | 256\n", "encoder.encoders.5.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", "encoder.encoders.5.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", "encoder.encoders.5.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.norm_ff.weight | torch.Size([256]) | 256\n", "encoder.encoders.5.norm_ff.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.norm_mha.weight | torch.Size([256]) | 256\n", "encoder.encoders.5.norm_mha.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.norm_ff_macaron.weight | torch.Size([256]) | 256\n", "encoder.encoders.5.norm_ff_macaron.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.norm_conv.weight | torch.Size([256]) | 256\n", "encoder.encoders.5.norm_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.norm_final.weight | torch.Size([256]) | 256\n", "encoder.encoders.5.norm_final.bias | torch.Size([256]) | 256\n", "encoder.encoders.5.concat_linear.weight | torch.Size([256, 512]) | 131072\n", "encoder.encoders.5.concat_linear.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", 
"encoder.encoders.6.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", "encoder.encoders.6.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.6.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.6.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.6.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.6.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.6.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.6.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.6.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.6.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.6.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.6.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.6.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", "encoder.encoders.6.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", "encoder.encoders.6.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", "encoder.encoders.6.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.conv_module.norm.weight | torch.Size([256]) | 256\n", "encoder.encoders.6.conv_module.norm.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.conv_module.norm.running_mean | 
torch.Size([256]) | 256\n", "encoder.encoders.6.conv_module.norm.running_var | torch.Size([256]) | 256\n", "encoder.encoders.6.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", "encoder.encoders.6.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", "encoder.encoders.6.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.norm_ff.weight | torch.Size([256]) | 256\n", "encoder.encoders.6.norm_ff.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.norm_mha.weight | torch.Size([256]) | 256\n", "encoder.encoders.6.norm_mha.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.norm_ff_macaron.weight | torch.Size([256]) | 256\n", "encoder.encoders.6.norm_ff_macaron.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.norm_conv.weight | torch.Size([256]) | 256\n", "encoder.encoders.6.norm_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.norm_final.weight | torch.Size([256]) | 256\n", "encoder.encoders.6.norm_final.bias | torch.Size([256]) | 256\n", "encoder.encoders.6.concat_linear.weight | torch.Size([256, 512]) | 131072\n", "encoder.encoders.6.concat_linear.bias | torch.Size([256]) | 256\n", "encoder.encoders.7.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", "encoder.encoders.7.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", "encoder.encoders.7.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.7.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "encoder.encoders.7.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.7.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "encoder.encoders.7.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.7.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "encoder.encoders.7.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.7.self_attn.linear_out.bias | torch.Size([256]) | 256\n", 
"encoder.encoders.7.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.7.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.7.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.7.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.7.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.7.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.7.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.7.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.7.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.7.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", "encoder.encoders.7.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", "encoder.encoders.7.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", "encoder.encoders.7.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.7.conv_module.norm.weight | torch.Size([256]) | 256\n", "encoder.encoders.7.conv_module.norm.bias | torch.Size([256]) | 256\n", "encoder.encoders.7.conv_module.norm.running_mean | torch.Size([256]) | 256\n", "encoder.encoders.7.conv_module.norm.running_var | torch.Size([256]) | 256\n", "encoder.encoders.7.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", "encoder.encoders.7.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", "encoder.encoders.7.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", "encoder.encoders.7.norm_ff.weight | torch.Size([256]) | 256\n", "encoder.encoders.7.norm_ff.bias | torch.Size([256]) | 256\n", "encoder.encoders.7.norm_mha.weight | torch.Size([256]) | 256\n", "encoder.encoders.7.norm_mha.bias | torch.Size([256]) | 256\n", "encoder.encoders.7.norm_ff_macaron.weight | torch.Size([256]) | 256\n", 
"encoder.encoders.7.norm_ff_macaron.bias | torch.Size([256]) | 256\n", "encoder.encoders.7.norm_conv.weight | torch.Size([256]) | 256\n", "encoder.encoders.7.norm_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.7.norm_final.weight | torch.Size([256]) | 256\n", "encoder.encoders.7.norm_final.bias | torch.Size([256]) | 256\n", "encoder.encoders.7.concat_linear.weight | torch.Size([256, 512]) | 131072\n", "encoder.encoders.7.concat_linear.bias | torch.Size([256]) | 256\n", "encoder.encoders.8.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", "encoder.encoders.8.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", "encoder.encoders.8.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.8.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "encoder.encoders.8.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.8.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "encoder.encoders.8.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.8.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "encoder.encoders.8.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.8.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "encoder.encoders.8.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.8.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.8.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.8.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.8.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.8.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.8.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.8.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.8.feed_forward_macaron.w_2.bias | torch.Size([256]) | 
256\n", "encoder.encoders.8.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", "encoder.encoders.8.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", "encoder.encoders.8.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", "encoder.encoders.8.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.8.conv_module.norm.weight | torch.Size([256]) | 256\n", "encoder.encoders.8.conv_module.norm.bias | torch.Size([256]) | 256\n", "encoder.encoders.8.conv_module.norm.running_mean | torch.Size([256]) | 256\n", "encoder.encoders.8.conv_module.norm.running_var | torch.Size([256]) | 256\n", "encoder.encoders.8.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", "encoder.encoders.8.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", "encoder.encoders.8.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", "encoder.encoders.8.norm_ff.weight | torch.Size([256]) | 256\n", "encoder.encoders.8.norm_ff.bias | torch.Size([256]) | 256\n", "encoder.encoders.8.norm_mha.weight | torch.Size([256]) | 256\n", "encoder.encoders.8.norm_mha.bias | torch.Size([256]) | 256\n", "encoder.encoders.8.norm_ff_macaron.weight | torch.Size([256]) | 256\n", "encoder.encoders.8.norm_ff_macaron.bias | torch.Size([256]) | 256\n", "encoder.encoders.8.norm_conv.weight | torch.Size([256]) | 256\n", "encoder.encoders.8.norm_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.8.norm_final.weight | torch.Size([256]) | 256\n", "encoder.encoders.8.norm_final.bias | torch.Size([256]) | 256\n", "encoder.encoders.8.concat_linear.weight | torch.Size([256, 512]) | 131072\n", "encoder.encoders.8.concat_linear.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", "encoder.encoders.9.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", "encoder.encoders.9.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", 
"encoder.encoders.9.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.9.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.9.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.9.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.9.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.9.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.9.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.9.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.9.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.9.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.9.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", "encoder.encoders.9.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", "encoder.encoders.9.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", "encoder.encoders.9.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.conv_module.norm.weight | torch.Size([256]) | 256\n", "encoder.encoders.9.conv_module.norm.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.conv_module.norm.running_mean | torch.Size([256]) | 256\n", "encoder.encoders.9.conv_module.norm.running_var | torch.Size([256]) | 256\n", 
"encoder.encoders.9.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", "encoder.encoders.9.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", "encoder.encoders.9.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.norm_ff.weight | torch.Size([256]) | 256\n", "encoder.encoders.9.norm_ff.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.norm_mha.weight | torch.Size([256]) | 256\n", "encoder.encoders.9.norm_mha.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.norm_ff_macaron.weight | torch.Size([256]) | 256\n", "encoder.encoders.9.norm_ff_macaron.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.norm_conv.weight | torch.Size([256]) | 256\n", "encoder.encoders.9.norm_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.norm_final.weight | torch.Size([256]) | 256\n", "encoder.encoders.9.norm_final.bias | torch.Size([256]) | 256\n", "encoder.encoders.9.concat_linear.weight | torch.Size([256, 512]) | 131072\n", "encoder.encoders.9.concat_linear.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", "encoder.encoders.10.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", "encoder.encoders.10.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.10.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.10.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.10.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.10.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.10.feed_forward.w_1.weight | 
torch.Size([2048, 256]) | 524288\n", "encoder.encoders.10.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.10.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.10.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.10.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.10.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.10.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", "encoder.encoders.10.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", "encoder.encoders.10.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", "encoder.encoders.10.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.conv_module.norm.weight | torch.Size([256]) | 256\n", "encoder.encoders.10.conv_module.norm.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.conv_module.norm.running_mean | torch.Size([256]) | 256\n", "encoder.encoders.10.conv_module.norm.running_var | torch.Size([256]) | 256\n", "encoder.encoders.10.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", "encoder.encoders.10.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", "encoder.encoders.10.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.norm_ff.weight | torch.Size([256]) | 256\n", "encoder.encoders.10.norm_ff.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.norm_mha.weight | torch.Size([256]) | 256\n", "encoder.encoders.10.norm_mha.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.norm_ff_macaron.weight | torch.Size([256]) | 256\n", "encoder.encoders.10.norm_ff_macaron.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.norm_conv.weight | 
torch.Size([256]) | 256\n", "encoder.encoders.10.norm_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.norm_final.weight | torch.Size([256]) | 256\n", "encoder.encoders.10.norm_final.bias | torch.Size([256]) | 256\n", "encoder.encoders.10.concat_linear.weight | torch.Size([256, 512]) | 131072\n", "encoder.encoders.10.concat_linear.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", "encoder.encoders.11.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", "encoder.encoders.11.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.11.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.11.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.11.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.11.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", "encoder.encoders.11.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.11.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.11.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.11.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", "encoder.encoders.11.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", "encoder.encoders.11.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", "encoder.encoders.11.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 
1]) | 131072\n", "encoder.encoders.11.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", "encoder.encoders.11.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", "encoder.encoders.11.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.conv_module.norm.weight | torch.Size([256]) | 256\n", "encoder.encoders.11.conv_module.norm.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.conv_module.norm.running_mean | torch.Size([256]) | 256\n", "encoder.encoders.11.conv_module.norm.running_var | torch.Size([256]) | 256\n", "encoder.encoders.11.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", "encoder.encoders.11.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", "encoder.encoders.11.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.norm_ff.weight | torch.Size([256]) | 256\n", "encoder.encoders.11.norm_ff.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.norm_mha.weight | torch.Size([256]) | 256\n", "encoder.encoders.11.norm_mha.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.norm_ff_macaron.weight | torch.Size([256]) | 256\n", "encoder.encoders.11.norm_ff_macaron.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.norm_conv.weight | torch.Size([256]) | 256\n", "encoder.encoders.11.norm_conv.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.norm_final.weight | torch.Size([256]) | 256\n", "encoder.encoders.11.norm_final.bias | torch.Size([256]) | 256\n", "encoder.encoders.11.concat_linear.weight | torch.Size([256, 512]) | 131072\n", "encoder.encoders.11.concat_linear.bias | torch.Size([256]) | 256\n", "decoder.embed.0.weight | torch.Size([4233, 256]) | 1083648\n", "decoder.after_norm.weight | torch.Size([256]) | 256\n", "decoder.after_norm.bias | torch.Size([256]) | 256\n", "decoder.output_layer.weight | torch.Size([4233, 256]) | 1083648\n", "decoder.output_layer.bias | torch.Size([4233]) | 4233\n", 
"decoder.decoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "decoder.decoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "decoder.decoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "decoder.decoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "decoder.decoders.0.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.0.src_attn.linear_q.bias | torch.Size([256]) | 256\n", "decoder.decoders.0.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.0.src_attn.linear_k.bias | torch.Size([256]) | 256\n", "decoder.decoders.0.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.0.src_attn.linear_v.bias | torch.Size([256]) | 256\n", "decoder.decoders.0.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.0.src_attn.linear_out.bias | torch.Size([256]) | 256\n", "decoder.decoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "decoder.decoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "decoder.decoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "decoder.decoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "decoder.decoders.0.norm1.weight | torch.Size([256]) | 256\n", "decoder.decoders.0.norm1.bias | torch.Size([256]) | 256\n", "decoder.decoders.0.norm2.weight | torch.Size([256]) | 256\n", "decoder.decoders.0.norm2.bias | torch.Size([256]) | 256\n", "decoder.decoders.0.norm3.weight | torch.Size([256]) | 256\n", "decoder.decoders.0.norm3.bias | torch.Size([256]) | 256\n", "decoder.decoders.0.concat_linear1.weight | 
torch.Size([256, 512]) | 131072\n", "decoder.decoders.0.concat_linear1.bias | torch.Size([256]) | 256\n", "decoder.decoders.0.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", "decoder.decoders.0.concat_linear2.bias | torch.Size([256]) | 256\n", "decoder.decoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "decoder.decoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "decoder.decoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "decoder.decoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "decoder.decoders.1.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.1.src_attn.linear_q.bias | torch.Size([256]) | 256\n", "decoder.decoders.1.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.1.src_attn.linear_k.bias | torch.Size([256]) | 256\n", "decoder.decoders.1.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.1.src_attn.linear_v.bias | torch.Size([256]) | 256\n", "decoder.decoders.1.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.1.src_attn.linear_out.bias | torch.Size([256]) | 256\n", "decoder.decoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "decoder.decoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "decoder.decoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "decoder.decoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "decoder.decoders.1.norm1.weight | torch.Size([256]) | 256\n", "decoder.decoders.1.norm1.bias | torch.Size([256]) | 256\n", "decoder.decoders.1.norm2.weight | 
torch.Size([256]) | 256\n", "decoder.decoders.1.norm2.bias | torch.Size([256]) | 256\n", "decoder.decoders.1.norm3.weight | torch.Size([256]) | 256\n", "decoder.decoders.1.norm3.bias | torch.Size([256]) | 256\n", "decoder.decoders.1.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", "decoder.decoders.1.concat_linear1.bias | torch.Size([256]) | 256\n", "decoder.decoders.1.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", "decoder.decoders.1.concat_linear2.bias | torch.Size([256]) | 256\n", "decoder.decoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "decoder.decoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "decoder.decoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "decoder.decoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "decoder.decoders.2.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.2.src_attn.linear_q.bias | torch.Size([256]) | 256\n", "decoder.decoders.2.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.2.src_attn.linear_k.bias | torch.Size([256]) | 256\n", "decoder.decoders.2.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.2.src_attn.linear_v.bias | torch.Size([256]) | 256\n", "decoder.decoders.2.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.2.src_attn.linear_out.bias | torch.Size([256]) | 256\n", "decoder.decoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "decoder.decoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "decoder.decoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 
524288\n", "decoder.decoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "decoder.decoders.2.norm1.weight | torch.Size([256]) | 256\n", "decoder.decoders.2.norm1.bias | torch.Size([256]) | 256\n", "decoder.decoders.2.norm2.weight | torch.Size([256]) | 256\n", "decoder.decoders.2.norm2.bias | torch.Size([256]) | 256\n", "decoder.decoders.2.norm3.weight | torch.Size([256]) | 256\n", "decoder.decoders.2.norm3.bias | torch.Size([256]) | 256\n", "decoder.decoders.2.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", "decoder.decoders.2.concat_linear1.bias | torch.Size([256]) | 256\n", "decoder.decoders.2.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", "decoder.decoders.2.concat_linear2.bias | torch.Size([256]) | 256\n", "decoder.decoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "decoder.decoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "decoder.decoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "decoder.decoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "decoder.decoders.3.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.3.src_attn.linear_q.bias | torch.Size([256]) | 256\n", "decoder.decoders.3.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.3.src_attn.linear_k.bias | torch.Size([256]) | 256\n", "decoder.decoders.3.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.3.src_attn.linear_v.bias | torch.Size([256]) | 256\n", "decoder.decoders.3.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.3.src_attn.linear_out.bias | 
torch.Size([256]) | 256\n", "decoder.decoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "decoder.decoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "decoder.decoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "decoder.decoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "decoder.decoders.3.norm1.weight | torch.Size([256]) | 256\n", "decoder.decoders.3.norm1.bias | torch.Size([256]) | 256\n", "decoder.decoders.3.norm2.weight | torch.Size([256]) | 256\n", "decoder.decoders.3.norm2.bias | torch.Size([256]) | 256\n", "decoder.decoders.3.norm3.weight | torch.Size([256]) | 256\n", "decoder.decoders.3.norm3.bias | torch.Size([256]) | 256\n", "decoder.decoders.3.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", "decoder.decoders.3.concat_linear1.bias | torch.Size([256]) | 256\n", "decoder.decoders.3.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", "decoder.decoders.3.concat_linear2.bias | torch.Size([256]) | 256\n", "decoder.decoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "decoder.decoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "decoder.decoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "decoder.decoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "decoder.decoders.4.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.4.src_attn.linear_q.bias | torch.Size([256]) | 256\n", "decoder.decoders.4.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.4.src_attn.linear_k.bias | torch.Size([256]) | 256\n", "decoder.decoders.4.src_attn.linear_v.weight 
| torch.Size([256, 256]) | 65536\n", "decoder.decoders.4.src_attn.linear_v.bias | torch.Size([256]) | 256\n", "decoder.decoders.4.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.4.src_attn.linear_out.bias | torch.Size([256]) | 256\n", "decoder.decoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "decoder.decoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "decoder.decoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "decoder.decoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "decoder.decoders.4.norm1.weight | torch.Size([256]) | 256\n", "decoder.decoders.4.norm1.bias | torch.Size([256]) | 256\n", "decoder.decoders.4.norm2.weight | torch.Size([256]) | 256\n", "decoder.decoders.4.norm2.bias | torch.Size([256]) | 256\n", "decoder.decoders.4.norm3.weight | torch.Size([256]) | 256\n", "decoder.decoders.4.norm3.bias | torch.Size([256]) | 256\n", "decoder.decoders.4.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", "decoder.decoders.4.concat_linear1.bias | torch.Size([256]) | 256\n", "decoder.decoders.4.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", "decoder.decoders.4.concat_linear2.bias | torch.Size([256]) | 256\n", "decoder.decoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256\n", "decoder.decoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256\n", "decoder.decoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256\n", "decoder.decoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256\n", "decoder.decoders.5.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", 
"decoder.decoders.5.src_attn.linear_q.bias | torch.Size([256]) | 256\n", "decoder.decoders.5.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.5.src_attn.linear_k.bias | torch.Size([256]) | 256\n", "decoder.decoders.5.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.5.src_attn.linear_v.bias | torch.Size([256]) | 256\n", "decoder.decoders.5.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", "decoder.decoders.5.src_attn.linear_out.bias | torch.Size([256]) | 256\n", "decoder.decoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", "decoder.decoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", "decoder.decoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", "decoder.decoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256\n", "decoder.decoders.5.norm1.weight | torch.Size([256]) | 256\n", "decoder.decoders.5.norm1.bias | torch.Size([256]) | 256\n", "decoder.decoders.5.norm2.weight | torch.Size([256]) | 256\n", "decoder.decoders.5.norm2.bias | torch.Size([256]) | 256\n", "decoder.decoders.5.norm3.weight | torch.Size([256]) | 256\n", "decoder.decoders.5.norm3.bias | torch.Size([256]) | 256\n", "decoder.decoders.5.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", "decoder.decoders.5.concat_linear1.bias | torch.Size([256]) | 256\n", "decoder.decoders.5.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", "decoder.decoders.5.concat_linear2.bias | torch.Size([256]) | 256\n", "ctc.ctc_lo.weight | torch.Size([4233, 256]) | 1083648\n", "ctc.ctc_lo.bias | torch.Size([4233]) | 4233\n", "Total parameters: 701, 49355454.0 elements.\n" ] } ], "source": [ "summary(model)" ] }, { "cell_type": "code", "execution_count": null, "id": "8494c6ab", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 13, "id": "0648a969", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304 | True\n", "encoder.embed.conv.0.bias | torch.Size([256]) | 256 | True\n", "encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824 | True\n", "encoder.embed.conv.2.bias | torch.Size([256]) | 256 | True\n", "encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184 | True\n", "encoder.embed.out.0.bias | torch.Size([256]) | 256 | True\n", "encoder.after_norm.weight | torch.Size([256]) | 256 | True\n", "encoder.after_norm.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.0.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 
2048 | True\n", "encoder.encoders.0.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.0.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", "encoder.encoders.0.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", "encoder.encoders.0.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", "encoder.encoders.0.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", "encoder.encoders.0.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.norm_ff.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.norm_ff.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.norm_mha.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.norm_mha.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.norm_conv.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.norm_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.norm_final.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.norm_final.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.0.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", "encoder.encoders.0.concat_linear.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.1.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", 
"encoder.encoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.1.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.1.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.1.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", "encoder.encoders.1.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", "encoder.encoders.1.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", "encoder.encoders.1.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.conv_module.norm.bias | 
torch.Size([256]) | 256 | True\n", "encoder.encoders.1.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", "encoder.encoders.1.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.norm_ff.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.norm_ff.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.norm_mha.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.norm_mha.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.norm_conv.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.norm_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.norm_final.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.norm_final.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.1.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", "encoder.encoders.1.concat_linear.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.2.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", 
"encoder.encoders.2.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.2.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.2.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.2.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", "encoder.encoders.2.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", "encoder.encoders.2.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", "encoder.encoders.2.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", "encoder.encoders.2.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.norm_ff.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.norm_ff.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.norm_mha.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.norm_mha.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", 
"encoder.encoders.2.norm_conv.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.norm_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.norm_final.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.norm_final.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.2.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", "encoder.encoders.2.concat_linear.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.3.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.3.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.3.feed_forward_macaron.w_2.weight | torch.Size([256, 
2048]) | 524288 | True\n", "encoder.encoders.3.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", "encoder.encoders.3.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", "encoder.encoders.3.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", "encoder.encoders.3.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", "encoder.encoders.3.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.norm_ff.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.norm_ff.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.norm_mha.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.norm_mha.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.norm_conv.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.norm_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.norm_final.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.norm_final.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.3.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", "encoder.encoders.3.concat_linear.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.4.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", 
"encoder.encoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.4.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.4.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.4.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", "encoder.encoders.4.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", "encoder.encoders.4.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", "encoder.encoders.4.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.conv_module.pointwise_conv2.weight | 
torch.Size([256, 256, 1]) | 65536 | True\n", "encoder.encoders.4.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.norm_ff.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.norm_ff.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.norm_mha.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.norm_mha.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.norm_conv.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.norm_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.norm_final.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.norm_final.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.4.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", "encoder.encoders.4.concat_linear.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.5.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", 
"encoder.encoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.5.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.5.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.5.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", "encoder.encoders.5.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", "encoder.encoders.5.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", "encoder.encoders.5.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", "encoder.encoders.5.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.norm_ff.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.norm_ff.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.norm_mha.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.norm_mha.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.norm_conv.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.norm_conv.bias | 
torch.Size([256]) | 256 | True\n", "encoder.encoders.5.norm_final.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.norm_final.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.5.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", "encoder.encoders.5.concat_linear.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.6.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.6.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.6.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.6.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.6.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.6.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.6.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.6.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.6.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.6.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.6.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.6.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.6.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", 
"encoder.encoders.6.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", "encoder.encoders.6.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", "encoder.encoders.6.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", "encoder.encoders.6.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", "encoder.encoders.6.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.norm_ff.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.norm_ff.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.norm_mha.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.norm_mha.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.norm_conv.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.norm_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.norm_final.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.norm_final.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.6.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", "encoder.encoders.6.concat_linear.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.7.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.7.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.7.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", 
"encoder.encoders.7.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.7.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.7.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.7.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.7.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.7.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.7.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.7.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.7.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.7.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.7.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", "encoder.encoders.7.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", "encoder.encoders.7.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", "encoder.encoders.7.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", 
"encoder.encoders.7.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.norm_ff.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.norm_ff.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.norm_mha.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.norm_mha.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.norm_conv.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.norm_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.norm_final.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.norm_final.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.7.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", "encoder.encoders.7.concat_linear.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.8.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.8.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.8.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.8.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.8.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.8.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.8.feed_forward.w_1.weight | torch.Size([2048, 256]) 
| 524288 | True\n", "encoder.encoders.8.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.8.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.8.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.8.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.8.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.8.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", "encoder.encoders.8.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", "encoder.encoders.8.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", "encoder.encoders.8.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", "encoder.encoders.8.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.norm_ff.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.norm_ff.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.norm_mha.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.norm_mha.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.norm_conv.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.norm_conv.bias | torch.Size([256]) | 256 | True\n", 
"encoder.encoders.8.norm_final.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.norm_final.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.8.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", "encoder.encoders.8.concat_linear.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.9.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.9.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.9.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.9.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.9.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.9.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.9.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.9.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.9.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.9.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.9.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.9.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.9.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", 
"encoder.encoders.9.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", "encoder.encoders.9.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", "encoder.encoders.9.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", "encoder.encoders.9.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", "encoder.encoders.9.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.norm_ff.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.norm_ff.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.norm_mha.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.norm_mha.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.norm_conv.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.norm_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.norm_final.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.norm_final.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.9.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", "encoder.encoders.9.concat_linear.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.10.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.10.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.10.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", 
"encoder.encoders.10.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.10.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.10.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.10.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.10.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.10.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.10.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.10.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.10.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.10.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.10.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", "encoder.encoders.10.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", "encoder.encoders.10.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", "encoder.encoders.10.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", 
"encoder.encoders.10.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.norm_ff.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.norm_ff.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.norm_mha.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.norm_mha.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.norm_conv.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.norm_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.norm_final.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.norm_final.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.10.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", "encoder.encoders.10.concat_linear.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.11.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", "encoder.encoders.11.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.11.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.11.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.11.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.11.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", "encoder.encoders.11.feed_forward.w_1.weight 
| torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.11.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.11.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.11.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "encoder.encoders.11.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", "encoder.encoders.11.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "encoder.encoders.11.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", "encoder.encoders.11.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", "encoder.encoders.11.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", "encoder.encoders.11.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", "encoder.encoders.11.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.norm_ff.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.norm_ff.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.norm_mha.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.norm_mha.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.norm_conv.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.norm_conv.bias | torch.Size([256]) | 
256 | True\n", "encoder.encoders.11.norm_final.weight | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.norm_final.bias | torch.Size([256]) | 256 | True\n", "encoder.encoders.11.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", "encoder.encoders.11.concat_linear.bias | torch.Size([256]) | 256 | True\n", "decoder.embed.0.weight | torch.Size([4233, 256]) | 1083648 | True\n", "decoder.after_norm.weight | torch.Size([256]) | 256 | True\n", "decoder.after_norm.bias | torch.Size([256]) | 256 | True\n", "decoder.output_layer.weight | torch.Size([4233, 256]) | 1083648 | True\n", "decoder.output_layer.bias | torch.Size([4233]) | 4233 | True\n", "decoder.decoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.0.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.0.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.0.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.0.src_attn.linear_out.bias 
| torch.Size([256]) | 256 | True\n", "decoder.decoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "decoder.decoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "decoder.decoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "decoder.decoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.norm1.weight | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.norm1.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.norm2.weight | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.norm2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.norm3.weight | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.norm3.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", "decoder.decoders.0.concat_linear1.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.0.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", "decoder.decoders.0.concat_linear2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.1.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", 
"decoder.decoders.1.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.1.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.1.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.1.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "decoder.decoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "decoder.decoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "decoder.decoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.norm1.weight | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.norm1.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.norm2.weight | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.norm2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.norm3.weight | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.norm3.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", "decoder.decoders.1.concat_linear1.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.1.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", "decoder.decoders.1.concat_linear2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.self_attn.linear_v.weight | torch.Size([256, 
256]) | 65536 | True\n", "decoder.decoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.2.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.2.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.2.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.2.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "decoder.decoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "decoder.decoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "decoder.decoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.norm1.weight | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.norm1.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.norm2.weight | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.norm2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.norm3.weight | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.norm3.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", "decoder.decoders.2.concat_linear1.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.2.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", 
"decoder.decoders.2.concat_linear2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.3.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.3.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.3.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.3.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", "decoder.decoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "decoder.decoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "decoder.decoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.norm1.weight | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.norm1.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.norm2.weight | torch.Size([256]) | 256 | True\n", 
"decoder.decoders.3.norm2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.norm3.weight | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.norm3.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", "decoder.decoders.3.concat_linear1.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.3.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", "decoder.decoders.3.concat_linear2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.4.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.4.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.4.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.4.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", 
"decoder.decoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", "decoder.decoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", "decoder.decoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.norm1.weight | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.norm1.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.norm2.weight | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.norm2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.norm3.weight | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.norm3.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", "decoder.decoders.4.concat_linear1.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.4.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", "decoder.decoders.4.concat_linear2.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.5.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.5.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", "decoder.decoders.5.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", "decoder.decoders.5.src_attn.linear_k.bias | 
# Dump every model parameter as "name | shape | element count | requires_grad"
# plus a final total (663 tensors, ~49.3M elements per the cell output).
# `print_params` is defined in an earlier cell of this notebook.
print_params(model)
'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']\n", "torch.Size([16, 207, 80])\n", "tensor([[[ 8.9946, 9.5383, 9.1916, ..., 10.5074, 9.5633, 8.2564],\n", " [ 9.7988, 10.4052, 9.2651, ..., 10.2512, 9.5440, 8.8738],\n", " [10.6891, 10.3955, 8.0535, ..., 9.9067, 10.0649, 8.0509],\n", " ...,\n", " [ 9.2180, 9.6507, 8.5053, ..., 9.6872, 8.7425, 7.9865],\n", " [10.1291, 9.9352, 9.3798, ..., 9.5639, 9.8260, 8.9795],\n", " [ 9.0955, 7.1338, 9.4680, ..., 9.4727, 9.0212, 7.4479]],\n", "\n", " [[11.4310, 10.6719, 6.0841, ..., 9.3827, 8.7297, 7.5316],\n", " [ 9.7317, 7.8105, 7.5715, ..., 10.0430, 9.2436, 7.3541],\n", " [10.6502, 10.6006, 8.4678, ..., 9.2814, 9.1869, 8.0703],\n", " ...,\n", " [ 9.0970, 9.2637, 8.0753, ..., 8.4318, 8.3705, 8.0029],\n", " [10.4617, 10.1478, 6.7693, ..., 9.7794, 9.5775, 8.0807],\n", " [ 7.7944, 5.6211, 7.9751, ..., 9.9972, 9.8497, 8.0313]],\n", "\n", " [[ 7.3456, 7.8964, 7.5796, ..., 11.6310, 10.4513, 9.1236],\n", " [ 8.6287, 8.4631, 7.4992, ..., 12.4160, 10.9757, 8.9426],\n", " [ 9.8314, 10.2813, 8.9724, ..., 12.1387, 10.4017, 9.0055],\n", " ...,\n", " [ 7.0896, 7.4055, 6.8143, ..., 9.3252, 9.2732, 8.3534],\n", " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", "\n", " ...,\n", "\n", " [[10.9332, 10.4644, 7.7203, ..., 10.3488, 9.3023, 7.1553],\n", " [10.4499, 9.9070, 9.0293, ..., 9.9525, 9.4141, 7.5593],\n", " [10.4877, 9.8126, 9.8952, ..., 9.5866, 9.3413, 7.7849],\n", " ...,\n", " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", "\n", " [[ 9.9444, 9.5859, 8.2203, ..., 11.5886, 11.0450, 8.8171],\n", " [ 7.6784, 8.3224, 7.5330, ..., 11.0551, 10.5357, 9.2746],\n", " [ 8.6262, 9.6759, 9.8410, ..., 11.3788, 10.9221, 8.9914],\n", " 
...,\n", " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", "\n", " [[ 8.1079, 7.7590, 6.7103, ..., 12.6506, 11.4662, 11.0615],\n", " [11.3803, 11.2220, 8.6589, ..., 12.8106, 12.2222, 11.6893],\n", " [10.6777, 9.9206, 8.0461, ..., 13.5729, 12.5624, 11.1550],\n", " ...,\n", " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]])\n", "tensor([207, 207, 205, 205, 203, 203, 198, 197, 195, 188, 186, 186, 185, 180,\n", " 166, 163], dtype=torch.int32)\n", "tensor([[2995, 3116, 1209, 565, -1, -1],\n", " [ 236, 1176, 331, 66, 3925, 4077],\n", " [2693, 524, 234, 1145, 366, -1],\n", " [3875, 4211, 3062, 700, -1, -1],\n", " [ 272, 987, 1134, 494, 2959, -1],\n", " [1936, 3715, 120, 2553, 2695, 2710],\n", " [ 25, 1149, 3930, -1, -1, -1],\n", " [1753, 1778, 1237, 482, 3925, 110],\n", " [3703, 2, 565, 3827, -1, -1],\n", " [1150, 2734, 10, 2478, 3490, -1],\n", " [ 426, 811, 95, 489, 144, -1],\n", " [2313, 2006, 489, 975, -1, -1],\n", " [3702, 3414, 205, 1488, 2966, 1347],\n", " [ 70, 1741, 702, 1666, -1, -1],\n", " [ 703, 1778, 1030, 849, -1, -1],\n", " [ 814, 1674, 115, 3827, -1, -1]], dtype=torch.int32)\n", "tensor([4, 6, 5, 4, 5, 6, 3, 6, 4, 5, 5, 4, 6, 4, 4, 4], dtype=torch.int32)\n" ] } ], "source": [ "for batch in cv_data_loader:\n", " keys, feat, text, feat_len, text_len = batch\n", " print(keys)\n", " print(feat.shape)\n", " print(feat)\n", " print(feat_len)\n", " print(text)\n", " print(text_len)\n", " np.savez('data.npz', keys=keys, feat=feat.numpy(), feat_len=feat_len.numpy(), text=text.numpy(), text_len=text_len.numpy())\n", " break" ] }, { "cell_type": "code", "execution_count": 15, "id": "852a9c95", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CODE_OF_CONDUCT.md 
# Show the workspace, then copy the dumped batch into the DeepSpeech-2.x
# notebook directory for cross-framework comparison.
# NOTE(review): hardcoded absolute path — breaks outside this container.
!ls
!cp data.npz /workspace/DeepSpeech-2.x/.notebook

# Forward pass on CPU in eval mode.  There is no torch.no_grad() guard, so
# the returned losses still carry grad_fn (visible in the printed output).
model.cpu().eval()
total_loss, attention_loss, ctc_loss = model(feat, feat_len,
                                             text, text_len)
print(total_loss, attention_loss, ctc_loss )

# Sanity check: confirm the losses were computed on the CPU device.
print(total_loss.device)
# Repeat the forward pass on GPU and print the losses as numpy scalars so
# they can be compared against the CPU run in the previous cell.
model.cuda().eval()
feat = feat.cuda()
feat_len = feat_len.cuda()
text = text.cuda()
text_len = text_len.cuda()

total_loss, attention_loss, ctc_loss = model(feat, feat_len,
                                             text, text_len)
print(total_loss.device)
# Fix: use .detach() instead of the deprecated .data attribute — .data
# bypasses autograd's version tracking; detach() is the supported way to
# drop the graph before moving to CPU/numpy.
print(total_loss.detach().cpu().numpy(),
      attention_loss.detach().cpu().numpy(),
      ctc_loss.detach().cpu().numpy())
], "source": [ "encoder_out, encoder_mask = model.encoder(feat, feat_len)\n", "print(encoder_out.shape)\n", "print(encoder_mask.shape)\n", "print(encoder_out[0])\n", "\n", "np.savez('/workspace/DeepSpeech-2.x/.notebook/encoder.npz',\n", " mask=encoder_mask.cpu().detach().numpy(), \n", " out=encoder_out.cpu().detach().numpy())" ] }, { "cell_type": "code", "execution_count": null, "id": "3e22c782", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 20, "id": "30b6b946", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 9.871763 9.938915 10.238187 10.8597145 11.686526 12.25488\n", " 12.657681 12.86139 12.807339 12.566256 12.32007 12.138792\n", " 12.313189 12.552552 12.612239 12.569745 12.389728 12.143833\n", " 12.092851 11.793959 11.622591 11.926331 11.815442 11.951225\n", " 11.831805 11.887888 11.790144 11.88072 11.900057 11.973481\n", " 12.009822 12.008814 12.026197 12.104796 12.21555 12.343993\n", " 12.450144 12.496688 12.486538 12.355079 12.392918 12.255374\n", " 12.264963 12.253142 12.325458 12.4335985 12.548675 12.676334\n", " 12.809207 12.929347 12.961151 12.968834 12.995931 13.047281\n", " 13.058881 13.05738 12.999211 12.934022 12.874292 12.71653\n", " 12.48942 12.274784 12.261631 12.286319 12.31956 12.422907\n", " 12.514802 12.578516 12.647194 12.737626 12.800171 12.868728\n", " 12.966668 13.064786 13.159159 13.272843 13.310819 13.239043\n", " 12.879361 11.183102 ] float32\n", "encoder.embed.out.0.weight: (256, 4864) -> (4864, 256)\n", "encoder.encoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.0.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 
2048)\n", "encoder.encoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.0.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.0.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.0.conv_module.norm.running_mean -> encoder.encoders.0.conv_module.norm._mean\n", "encoder.encoders.0.conv_module.norm.running_var -> encoder.encoders.0.conv_module.norm._variance\n", "encoder.encoders.0.concat_linear.weight: (256, 512) -> (512, 256)\n", "encoder.encoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.1.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.1.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.1.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.1.conv_module.norm.running_mean -> encoder.encoders.1.conv_module.norm._mean\n", "encoder.encoders.1.conv_module.norm.running_var -> encoder.encoders.1.conv_module.norm._variance\n", "encoder.encoders.1.concat_linear.weight: (256, 512) -> (512, 256)\n", "encoder.encoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.2.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", 
"encoder.encoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.2.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.2.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.2.conv_module.norm.running_mean -> encoder.encoders.2.conv_module.norm._mean\n", "encoder.encoders.2.conv_module.norm.running_var -> encoder.encoders.2.conv_module.norm._variance\n", "encoder.encoders.2.concat_linear.weight: (256, 512) -> (512, 256)\n", "encoder.encoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.3.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.3.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.3.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.3.conv_module.norm.running_mean -> encoder.encoders.3.conv_module.norm._mean\n", "encoder.encoders.3.conv_module.norm.running_var -> encoder.encoders.3.conv_module.norm._variance\n", "encoder.encoders.3.concat_linear.weight: (256, 512) -> (512, 256)\n", "encoder.encoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.4.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.4.feed_forward.w_2.weight: 
(256, 2048) -> (2048, 256)\n", "encoder.encoders.4.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.4.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.4.conv_module.norm.running_mean -> encoder.encoders.4.conv_module.norm._mean\n", "encoder.encoders.4.conv_module.norm.running_var -> encoder.encoders.4.conv_module.norm._variance\n", "encoder.encoders.4.concat_linear.weight: (256, 512) -> (512, 256)\n", "encoder.encoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.5.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.5.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.5.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.5.conv_module.norm.running_mean -> encoder.encoders.5.conv_module.norm._mean\n", "encoder.encoders.5.conv_module.norm.running_var -> encoder.encoders.5.conv_module.norm._variance\n", "encoder.encoders.5.concat_linear.weight: (256, 512) -> (512, 256)\n", "encoder.encoders.6.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.6.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.6.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.6.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.6.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.6.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.6.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", 
"encoder.encoders.6.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.6.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.6.conv_module.norm.running_mean -> encoder.encoders.6.conv_module.norm._mean\n", "encoder.encoders.6.conv_module.norm.running_var -> encoder.encoders.6.conv_module.norm._variance\n", "encoder.encoders.6.concat_linear.weight: (256, 512) -> (512, 256)\n", "encoder.encoders.7.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.7.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.7.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.7.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.7.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.7.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.7.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.7.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.7.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.7.conv_module.norm.running_mean -> encoder.encoders.7.conv_module.norm._mean\n", "encoder.encoders.7.conv_module.norm.running_var -> encoder.encoders.7.conv_module.norm._variance\n", "encoder.encoders.7.concat_linear.weight: (256, 512) -> (512, 256)\n", "encoder.encoders.8.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.8.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.8.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.8.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.8.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.8.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.8.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", 
"encoder.encoders.8.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.8.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.8.conv_module.norm.running_mean -> encoder.encoders.8.conv_module.norm._mean\n", "encoder.encoders.8.conv_module.norm.running_var -> encoder.encoders.8.conv_module.norm._variance\n", "encoder.encoders.8.concat_linear.weight: (256, 512) -> (512, 256)\n", "encoder.encoders.9.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.9.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.9.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.9.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.9.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.9.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.9.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.9.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.9.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.9.conv_module.norm.running_mean -> encoder.encoders.9.conv_module.norm._mean\n", "encoder.encoders.9.conv_module.norm.running_var -> encoder.encoders.9.conv_module.norm._variance\n", "encoder.encoders.9.concat_linear.weight: (256, 512) -> (512, 256)\n", "encoder.encoders.10.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.10.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.10.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.10.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.10.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.10.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.10.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", 
"encoder.encoders.10.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.10.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.10.conv_module.norm.running_mean -> encoder.encoders.10.conv_module.norm._mean\n", "encoder.encoders.10.conv_module.norm.running_var -> encoder.encoders.10.conv_module.norm._variance\n", "encoder.encoders.10.concat_linear.weight: (256, 512) -> (512, 256)\n", "encoder.encoders.11.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.11.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.11.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.11.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.11.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", "encoder.encoders.11.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.11.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.11.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", "encoder.encoders.11.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", "encoder.encoders.11.conv_module.norm.running_mean -> encoder.encoders.11.conv_module.norm._mean\n", "encoder.encoders.11.conv_module.norm.running_var -> encoder.encoders.11.conv_module.norm._variance\n", "encoder.encoders.11.concat_linear.weight: (256, 512) -> (512, 256)\n", "decoder.output_layer.weight: (4233, 256) -> (256, 4233)\n", "decoder.decoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.0.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.0.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.0.src_attn.linear_v.weight: 
(256, 256) -> (256, 256)\n", "decoder.decoders.0.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", "decoder.decoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", "decoder.decoders.0.concat_linear1.weight: (256, 512) -> (512, 256)\n", "decoder.decoders.0.concat_linear2.weight: (256, 512) -> (512, 256)\n", "decoder.decoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.1.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.1.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.1.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.1.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", "decoder.decoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", "decoder.decoders.1.concat_linear1.weight: (256, 512) -> (512, 256)\n", "decoder.decoders.1.concat_linear2.weight: (256, 512) -> (512, 256)\n", "decoder.decoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.2.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.2.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.2.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.2.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", 
"decoder.decoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", "decoder.decoders.2.concat_linear1.weight: (256, 512) -> (512, 256)\n", "decoder.decoders.2.concat_linear2.weight: (256, 512) -> (512, 256)\n", "decoder.decoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.3.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.3.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.3.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.3.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", "decoder.decoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", "decoder.decoders.3.concat_linear1.weight: (256, 512) -> (512, 256)\n", "decoder.decoders.3.concat_linear2.weight: (256, 512) -> (512, 256)\n", "decoder.decoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.4.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.4.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.4.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.4.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", "decoder.decoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", "decoder.decoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", "decoder.decoders.4.concat_linear1.weight: (256, 512) -> (512, 256)\n", "decoder.decoders.4.concat_linear2.weight: 
# dump torch model to paddle
# Convert the torch state_dict into paddle naming/layout conventions:
#   * BatchNorm buffers: running_mean/running_var -> _mean/_variance
#   * 2-D weight matrices are transposed (paddle nn.Linear stores (in, out),
#     torch stores (out, in)); the token embedding table ('embed.0.weight')
#     is excluded because embedding layouts already agree.
import numpy as np
state_dict = model.state_dict()
paddle_state_dict = {}

for n, p in state_dict.items():
    name_change=True

    if 'norm.running_mean' in n:
        # 'norm.running_mean' -> 'norm._mean'
        new_n = n.replace('norm.running_', 'norm._')
    elif 'norm.running_var' in n:
        new_n = n.replace('norm.running_var', 'norm._variance')
    else:
        name_change=False
        new_n = n
    
    if name_change:
        print(f"{n} -> {new_n}")
    
    p = p.cpu().detach().numpy()
    # Transpose every 2-D '*.weight' except the embedding lookup table.
    if n.endswith('weight') and p.ndim == 2 and 'embed.0.weight' not in n:
        new_p = p.T
        print(f"{n}: {p.shape} -> {new_p.shape}")
    else:
        new_p = p
    
    # Debug: print the CMVN mean so it can be eyeballed against paddle's.
    if 'global_cmvn.mean' in n:
        print(p, p.dtype)
    
    paddle_state_dict[new_n] = new_p
    
np.savez('/workspace/DeepSpeech-2.x/.notebook/model',
         state=paddle_state_dict)
{}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 21, "id": "d99b29bc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(377.3326, device='cuda:0', grad_fn=)\n", "None\n", "[[ 3.16902351e+00 -1.51765049e-02 4.91097234e-02 ... -2.47973716e-03\n", " -5.93366381e-03 -7.26613170e-03]\n", " [-1.74185038e+00 7.75875803e-03 -4.49435972e-02 ... 9.92415240e-04\n", " 2.46338220e-03 2.31891591e-03]\n", " [-2.33343077e+00 1.30476682e-02 -2.66557615e-02 ... 2.27533933e-03\n", " 5.76929189e-03 7.48792710e-03]\n", " ...\n", " [-4.30356789e+00 2.46056803e-02 -9.00955945e-02 ... 4.43160534e-03\n", " 1.16123557e-02 1.44716976e-02]\n", " [-3.36919212e+00 1.73155665e-02 -6.36875406e-02 ... 3.28367390e-03\n", " 8.58021621e-03 1.07796099e-02]\n", " [-6.62039661e+00 3.49958315e-02 -1.23963736e-01 ... 6.36674836e-03\n", " 1.60815325e-02 2.03892551e-02]]\n", "[-4.3777566e+00 2.3245990e-02 -9.3339972e-02 ... 4.2569702e-03\n", " 1.0920014e-02 1.3787906e-02]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":6: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. 
def pad_list(xs, pad_value: int):
    """Pad a list of 1-D tensors into one (batch, max_len) tensor.

    Entries shorter than the longest tensor are right-padded with
    ``pad_value``; dtype and device are taken from the first tensor.
    """
    longest = max(x.size(0) for x in xs)
    padded = torch.full((len(xs), longest), pad_value,
                        dtype=xs[0].dtype, device=xs[0].device)
    for row, seq in enumerate(xs):
        padded[row, :seq.size(0)] = seq
    return padded

def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
                ignore_id: int):
    """Prepend <sos> / append <eos> to each label sequence.

    ``ys_pad`` is a (batch, max_len) tensor padded with ``ignore_id``.
    Returns ``(ys_in, ys_out)``: decoder inputs padded with ``eos`` and
    decoder targets padded with ``ignore_id``.
    """
    device = ys_pad.device
    sos_tok = torch.tensor([sos], dtype=torch.long, device=device)
    eos_tok = torch.tensor([eos], dtype=torch.long, device=device)
    # Strip the ignore_id padding to recover the true sequences.
    trimmed = [seq[seq != ignore_id] for seq in ys_pad]
    with_sos = [torch.cat((sos_tok, seq), dim=0) for seq in trimmed]
    with_eos = [torch.cat((seq, eos_tok), dim=0) for seq in trimmed]
    return pad_list(with_sos, eos), pad_list(with_eos, ignore_id)
# Cross-check the locally defined add_sos_eos against the values wenet
# computes internally (model.sos / model.eos / model.ignore_id).
# NOTE(review): `text`, `text_len`, `model`, `encoder_out`, `encoder_mask`
# all come from earlier cells.
ys_pad = text
ys_pad_lens = text_len
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,
                                    model.ignore_id)
ys_in_lens = ys_pad_lens + 1  # +1 for the prepended <sos>
print(ys_in_pad)
print(ys_out_pad)

# Run the attention decoder on the encoder output to get per-token logits.
decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,
                               ys_in_lens)
print(decoder_out.shape)
print(decoder_out[0])
"torch.int64\n", "tensor(112., device='cuda:0')\n", "tensor(830.9634, device='cuda:0', grad_fn=)\n", "tensor([False, False, False, False, False, True, True, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " True, False, False, False, False, False, True, True, False, False,\n", " False, False, False, False, True, False, False, False, False, False,\n", " False, False, False, False, False, False, True, True, True, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, True, True, False, False, False, False, False, False, True,\n", " False, False, False, False, False, False, True, False, False, False,\n", " False, False, True, True, False, False, False, False, False, False,\n", " False, False, False, False, False, False, True, True, False, False,\n", " False, False, False, True, True, False, False, False, False, False,\n", " True, True], device='cuda:0')\n", "tensor(669.4634, device='cuda:0', grad_fn=)\n", "tensor(41.8415, device='cuda:0', grad_fn=)\n", "tensor([[2995, 3116, 1209, 565, 4232, -1, -1],\n", " [ 236, 1176, 331, 66, 3925, 4077, 4232],\n", " [2693, 524, 234, 1145, 366, 4232, -1],\n", " [3875, 4211, 3062, 700, 4232, -1, -1],\n", " [ 272, 987, 1134, 494, 2959, 4232, -1],\n", " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", " [ 25, 1149, 3930, 4232, -1, -1, -1],\n", " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", " [3703, 2, 565, 3827, 4232, -1, -1],\n", " [1150, 2734, 10, 2478, 3490, 4232, -1],\n", " [ 426, 811, 95, 489, 144, 4232, -1],\n", " [2313, 2006, 489, 975, 4232, -1, -1],\n", " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", " [ 70, 1741, 702, 1666, 4232, -1, -1],\n", " [ 703, 1778, 1030, 849, 4232, -1, -1],\n", " [ 814, 1674, 115, 3827, 4232, -1, -1]], device='cuda:0')\n", "tensor([[-3.7639e-01, -8.2272e-01, 7.4276e-01, ..., 3.4201e-01,\n", " 1.5035e-02, 4.0337e-01],\n", " [-8.7386e-01, -3.1389e-01, 4.1988e-01, ..., 3.7724e-01,\n", " -1.4353e-01, 
import torch
from torch import nn


class LabelSmoothingLoss(nn.Module):
    """Label-smoothed KL-divergence loss (wenet-style), with debug prints.

    The smoothed target distribution places ``1 - smoothing`` on the gold
    label and spreads ``smoothing`` uniformly over the remaining
    ``size - 1`` classes. Positions whose target equals ``padding_idx``
    are excluded from the loss.
    """

    def __init__(self,
                 size: int,
                 padding_idx: int,
                 smoothing: float,
                 normalize_length: bool = False):
        """Construct an LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="none")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.normalize_length = normalize_length

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the smoothed KL loss between predictions and targets.

        Both tensors are flattened to (batch*seqlen, class) / (batch*seqlen,)
        and padding positions are masked out of the loss.

        Args:
            x (torch.Tensor): prediction (batch, seqlen, class)
            target (torch.Tensor): labels masked with self.padding_idx
                (batch, seqlen)
        Returns:
            loss (torch.Tensor): scalar KL loss, divided by the number of
            non-padding tokens if ``normalize_length`` else by batch size.
        """
        assert x.size(2) == self.size
        n_batch = x.size(0)
        logits = x.view(-1, self.size)
        labels = target.view(-1)
        smooth_val = self.smoothing / (self.size - 1)
        # zeros_like (rather than torch.no_grad()) keeps the module
        # exportable by JIT.
        smoothed = torch.zeros_like(logits).fill_(smooth_val)
        pad_mask = labels == self.padding_idx  # (batch*seqlen,)
        print(smooth_val)
        print(smoothed)
        n_valid = len(labels) - pad_mask.sum().item()
        safe_labels = labels.masked_fill(pad_mask, 0)  # avoid -1 index
        smoothed.scatter_(1, safe_labels.unsqueeze(1), self.confidence)
        print(smoothed.dtype)
        print(smoothed.square().sum())
        kl_div = self.criterion(torch.log_softmax(logits, dim=1), smoothed)
        print(kl_div.sum())
        denom = n_valid if self.normalize_length else n_batch
        print(pad_mask)
        numer = kl_div.masked_fill(pad_mask.unsqueeze(1), 0).sum()
        print(numer)
        return numer / denom
# Sanity-check the re-implemented LabelSmoothingLoss against the value
# model.criterion_att produced above (vocab=4233, padding_idx=-1,
# smoothing=0.1, normalized by batch size).
# NOTE(review): "criteron" is a typo for "criterion"; left unchanged since
# later cells may reference this name.
criteron = LabelSmoothingLoss(4233, -1, 0.1, False)
loss_att = criteron(decoder_out, ys_out_pad)
print(loss_att)
print(ys_out_pad.dtype)
Specify retain_graph=True when calling backward the first time.", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mloss_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdecoder_out\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/workspace/wenet/venv/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \"\"\"\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/workspace/wenet/venv/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 125\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 126\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 127\u001b[0m allow_unreachable=True) # allow_unreachable flag\n", "\u001b[0;31mRuntimeError\u001b[0m: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time." 
] } ], "source": [ "loss_att.backward()\n", "print(loss_att.grad)\n", "print(decoder_out.grad)" ] }, { "cell_type": "code", "execution_count": 32, "id": "219eb41f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ 0.0024, 0.0019, -0.1098, ..., 0.0028, 0.0020, -1.7978],\n", " device='cuda:0')\n", "tensor([[ 6.5052e-04, 6.4419e-05, -6.1955e-06, ..., 9.8220e-04,\n", " -2.5918e-05, 3.3754e-04],\n", " [ 3.9305e-04, 4.5799e-04, 1.4362e-04, ..., 4.6800e-04,\n", " 1.6911e-04, 2.7067e-04],\n", " [-1.3593e-01, 5.2201e-02, 3.2895e-02, ..., 2.4580e-02,\n", " 1.4590e-01, -4.6850e-02],\n", " ...,\n", " [ 1.0434e-03, 4.2251e-04, 6.5688e-04, ..., 1.2144e-03,\n", " 2.1159e-04, 6.6838e-04],\n", " [ 6.4997e-04, 4.4301e-04, 4.1550e-04, ..., 1.0420e-03,\n", " 2.4114e-04, 1.5338e-04],\n", " [-9.9337e-01, 5.4573e-01, -1.1371e-02, ..., -4.3175e-01,\n", " -2.7850e-01, -4.4679e-01]], device='cuda:0')\n" ] } ], "source": [ "print(model.decoder.output_layer.bias.grad)\n", "print(model.decoder.output_layer.weight.grad)" ] }, { "cell_type": "code", "execution_count": 42, "id": "40d00a54", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[[-5.3698e-01, -1.9911e-01, -3.4997e-01, ..., -8.2428e-01,\n", " -1.0265e+00, -9.6301e-01],\n", " [-4.4642e-02, 2.3176e-01, -3.2539e-01, ..., -9.0159e-01,\n", " -1.0325e+00, -7.5987e-01],\n", " [ 5.0035e-01, 2.2691e-01, -7.3052e-01, ..., -1.0055e+00,\n", " -8.7123e-01, -1.0306e+00],\n", " ...,\n", " [-4.0024e-01, -1.4325e-01, -5.7947e-01, ..., -1.0718e+00,\n", " -1.2806e+00, -1.0518e+00],\n", " [ 1.5755e-01, -1.8495e-03, -2.8703e-01, ..., -1.1090e+00,\n", " -9.4519e-01, -7.2506e-01],\n", " [-4.7520e-01, -1.3942e+00, -2.5754e-01, ..., -1.1365e+00,\n", " -1.1943e+00, -1.2290e+00]],\n", "\n", " [[ 9.5454e-01, 3.6428e-01, -1.3891e+00, ..., -1.1637e+00,\n", " -1.2845e+00, -1.2015e+00],\n", " [-8.5735e-02, -1.0579e+00, -8.9173e-01, ..., -9.6441e-01,\n", " -1.1255e+00, 
-1.2599e+00],\n", " [ 4.7654e-01, 3.2887e-01, -5.9201e-01, ..., -1.1942e+00,\n", " -1.1430e+00, -1.0242e+00],\n", " ...,\n", " [-4.7431e-01, -3.3559e-01, -7.2326e-01, ..., -1.4506e+00,\n", " -1.3957e+00, -1.0464e+00],\n", " [ 3.6113e-01, 1.0381e-01, -1.1599e+00, ..., -1.0439e+00,\n", " -1.0221e+00, -1.0208e+00],\n", " [-1.2717e+00, -2.1460e+00, -7.5677e-01, ..., -9.7822e-01,\n", " -9.3785e-01, -1.0371e+00]],\n", "\n", " [[-1.5465e+00, -1.0152e+00, -8.8901e-01, ..., -4.8522e-01,\n", " -7.5163e-01, -6.7765e-01],\n", " [-7.6101e-01, -7.3352e-01, -9.1588e-01, ..., -2.4836e-01,\n", " -5.8927e-01, -7.3723e-01],\n", " [-2.4714e-02, 1.7016e-01, -4.2326e-01, ..., -3.3204e-01,\n", " -7.6696e-01, -7.1652e-01],\n", " ...,\n", " [-1.7032e+00, -1.2591e+00, -1.1449e+00, ..., -1.1810e+00,\n", " -1.1163e+00, -9.3108e-01],\n", " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", " -3.9869e+00, -3.6797e+00],\n", " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", " -3.9869e+00, -3.6797e+00]],\n", "\n", " ...,\n", "\n", " [[ 6.4983e-01, 2.6117e-01, -8.4197e-01, ..., -8.7213e-01,\n", " -1.1073e+00, -1.3253e+00],\n", " [ 3.5391e-01, -1.5846e-02, -4.0425e-01, ..., -9.9173e-01,\n", " -1.0727e+00, -1.1924e+00],\n", " [ 3.7704e-01, -6.2785e-02, -1.1468e-01, ..., -1.1021e+00,\n", " -1.0952e+00, -1.1182e+00],\n", " ...,\n", " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", " -3.9869e+00, -3.6797e+00],\n", " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", " -3.9869e+00, -3.6797e+00],\n", " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", " -3.9869e+00, -3.6797e+00]],\n", "\n", " [[ 4.4458e-02, -1.7547e-01, -6.7475e-01, ..., -4.9801e-01,\n", " -5.6783e-01, -7.7852e-01],\n", " [-1.3428e+00, -8.0343e-01, -9.0457e-01, ..., -6.5902e-01,\n", " -7.2550e-01, -6.2796e-01],\n", " [-7.6253e-01, -1.3071e-01, -1.3280e-01, ..., -5.6133e-01,\n", " -6.0588e-01, -7.2115e-01],\n", " ...,\n", " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., 
-3.9949e+00,\n", " -3.9869e+00, -3.6797e+00],\n", " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", " -3.9869e+00, -3.6797e+00],\n", " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", " -3.9869e+00, -3.6797e+00]],\n", "\n", " [[-1.0798e+00, -1.0834e+00, -1.1797e+00, ..., -1.7757e-01,\n", " -4.3747e-01, -4.0007e-02],\n", " [ 9.2354e-01, 6.3771e-01, -5.2810e-01, ..., -1.2928e-01,\n", " -2.0342e-01, 1.6656e-01],\n", " [ 4.9337e-01, -9.1133e-03, -7.3302e-01, ..., 1.0074e-01,\n", " -9.8115e-02, -9.2357e-03],\n", " ...,\n", " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", " -3.9869e+00, -3.6797e+00],\n", " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", " -3.9869e+00, -3.6797e+00],\n", " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", " -3.9869e+00, -3.6797e+00]]], device='cuda:0')\n" ] } ], "source": [ "xs = model.encoder.global_cmvn(feat)\n", "print(xs)" ] }, { "cell_type": "code", "execution_count": 43, "id": "505ca294", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[[ True, True, True, ..., True, True, True]],\n", "\n", " [[ True, True, True, ..., True, True, True]],\n", "\n", " [[ True, True, True, ..., True, False, False]],\n", "\n", " ...,\n", "\n", " [[ True, True, True, ..., False, False, False]],\n", "\n", " [[ True, True, True, ..., False, False, False]],\n", "\n", " [[ True, True, True, ..., False, False, False]]], device='cuda:0')\n" ] } ], "source": [ "from wenet.utils.mask import make_pad_mask\n", "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", "print(masks)" ] }, { "cell_type": "code", "execution_count": 44, "id": "aa03c2b9", "metadata": {}, "outputs": [], "source": [ "xs, pos_emb, masks = model.encoder.embed(xs, masks)" ] }, { "cell_type": "code", "execution_count": 45, "id": "ebc0ea12", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[[-0.5482, 2.2866, -1.0750, ..., 1.4504, 
0.2895, -0.6945],\n", " [-0.8013, 1.7688, -1.6639, ..., 1.8332, 0.6791, -0.2000],\n", " [-1.7112, 2.7057, -1.3363, ..., 1.2336, 0.1870, -0.5735],\n", " ...,\n", " [-0.9697, 2.3129, -0.8752, ..., 0.8584, 0.4853, -0.4177],\n", " [-1.3609, 2.1779, -1.7813, ..., 2.0928, 0.2528, -0.3650],\n", " [-1.6967, 2.3544, -1.7417, ..., 1.3670, 0.5951, -0.7415]],\n", "\n", " [[-1.9828, 2.3178, -0.9079, ..., 0.4117, 0.5006, 0.0872],\n", " [-0.7640, 1.3558, -1.3613, ..., 0.7317, 0.6784, 0.1685],\n", " [-0.9504, 1.6038, -1.3030, ..., 0.5754, 0.2677, 0.3343],\n", " ...,\n", " [-1.4757, 2.5317, -1.2321, ..., 1.2997, 0.5019, -0.1034],\n", " [-1.1731, 2.3172, -1.2542, ..., 1.7391, 0.2171, -0.4445],\n", " [-1.2700, 3.2229, -0.8872, ..., 1.6461, 0.0973, -0.7679]],\n", "\n", " [[-0.5873, 1.4291, -1.3950, ..., 0.2102, 0.1027, 0.0918],\n", " [ 0.1743, 1.7834, -1.6422, ..., 0.8113, 0.3137, 0.5634],\n", " [-0.3492, 1.8310, -1.0685, ..., 0.6924, 0.1378, 0.4594],\n", " ...,\n", " [-1.0869, 2.3002, -1.2638, ..., 1.7998, 0.5134, -0.5223],\n", " [-1.2614, 2.7240, -1.3734, ..., 1.4445, 0.5742, -0.3320],\n", " [-2.2068, 4.3462, -3.8289, ..., 2.1426, 1.2034, -1.3795]],\n", "\n", " ...,\n", "\n", " [[-0.3914, 1.8553, -0.5747, ..., 1.0062, 0.4632, -1.0452],\n", " [-0.8605, 2.0172, -1.4437, ..., 1.4526, 0.1657, 0.5923],\n", " [-0.7307, 2.2841, -1.0699, ..., 1.5825, -0.0980, 0.5503],\n", " ...,\n", " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]],\n", "\n", " [[-0.1619, 0.6255, -1.1323, ..., 0.0724, -0.2204, 0.4636],\n", " [-0.0831, 0.5750, -1.0930, ..., 0.9110, -0.0650, 0.7299],\n", " [-0.2820, 0.0801, -0.9418, ..., 0.3379, -0.1166, 0.4451],\n", " ...,\n", " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]],\n", "\n", " [[-0.5458, 
-0.6909, -1.3597, ..., -0.7818, 0.6875, 0.9843],\n", " [ 0.0421, -1.1062, -1.4389, ..., -0.0239, 0.9115, 0.5287],\n", " [-0.2909, -0.1886, -1.5487, ..., -0.1392, 0.0580, 0.3066],\n", " ...,\n", " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]]],\n", " device='cuda:0', grad_fn=)\n", "tensor([[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, ..., 1.0000e+00,\n", " 0.0000e+00, 1.0000e+00],\n", " [ 8.4147e-01, 5.4030e-01, 8.0196e-01, ..., 1.0000e+00,\n", " 1.0746e-04, 1.0000e+00],\n", " [ 9.0930e-01, -4.1615e-01, 9.5814e-01, ..., 1.0000e+00,\n", " 2.1492e-04, 1.0000e+00],\n", " ...,\n", " [-7.6825e-01, -6.4014e-01, 6.3280e-01, ..., 9.9998e-01,\n", " 5.1581e-03, 9.9999e-01],\n", " [-9.5375e-01, 3.0059e-01, 9.9899e-01, ..., 9.9998e-01,\n", " 5.2656e-03, 9.9999e-01],\n", " [-2.6237e-01, 9.6497e-01, 5.6075e-01, ..., 9.9998e-01,\n", " 5.3730e-03, 9.9999e-01]]], device='cuda:0')\n", "tensor([[[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True]],\n", "\n", " [[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True]],\n", "\n", " [[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " 
True, True, True, True, True, True, True, True, True, True,\n", " True]],\n", "\n", " [[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True]],\n", "\n", " [[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True]],\n", "\n", " [[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True]],\n", "\n", " [[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " False]],\n", "\n", " [[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " False]],\n", "\n", " [[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, 
True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, False,\n", " False]],\n", "\n", " [[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, False, False, False,\n", " False]],\n", "\n", " [[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, False, False, False,\n", " False]],\n", "\n", " [[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, False, False, False,\n", " False]],\n", "\n", " [[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, False, False, False,\n", " False]],\n", "\n", " [[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, False, False, False, False, False,\n", " False]],\n", "\n", " [[ 
True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, False, False, False, False, False, False, False, False,\n", " False]],\n", "\n", " [[ True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, False, False, False, False, False, False, False, False, False,\n", " False]]], device='cuda:0')\n", "torch.Size([16, 1, 51])\n" ] } ], "source": [ "print(xs)\n", "print(pos_emb)\n", "print(masks)\n", "print(masks.shape)" ] }, { "cell_type": "code", "execution_count": 46, "id": "4289461b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[[-0.54822 2.2866027 -1.0750197 ... 1.4503604 0.28950194\n", " -0.6945408 ]\n", " [-0.8012542 1.7687558 -1.6638877 ... 1.833158 0.6791494\n", " -0.1999542 ]\n", " [-1.7112465 2.7057455 -1.3363413 ... 1.2336441 0.18697014\n", " -0.5735198 ]\n", " ...\n", " [-0.96968573 2.312949 -0.87524825 ... 0.85838526 0.4853347\n", " -0.41773027]\n", " [-1.3609431 2.1778803 -1.7812773 ... 2.0927877 0.25282228\n", " -0.36496443]\n", " [-1.6967483 2.3543842 -1.7416853 ... 1.366951 0.59511113\n", " -0.74147725]]\n", "\n", " [[-1.9828408 2.31777 -0.9078527 ... 0.41170627 0.5006162\n", " 0.08721463]\n", " [-0.76404583 1.3557773 -1.3612567 ... 0.7317046 0.678426\n", " 0.16851945]\n", " [-0.95044655 1.6037656 -1.3029968 ... 0.57544005 0.26769355\n", " 0.33433008]\n", " ...\n", " [-1.475677 2.531713 -1.2320715 ... 1.2996731 0.50191855\n", " -0.10343577]\n", " [-1.1730809 2.3172235 -1.2542105 ... 1.7391105 0.21709818\n", " -0.44447583]\n", " [-1.2699623 3.2228963 -0.8871915 ... 
1.6460502 0.09731755\n", " -0.7678688 ]]\n", "\n", " [[-0.5872559 1.4290544 -1.3950099 ... 0.21024795 0.10272825\n", " 0.09179455]\n", " [ 0.1742807 1.783423 -1.6421788 ... 0.8112701 0.31371105\n", " 0.56344515]\n", " [-0.34916472 1.8310343 -1.0685117 ... 0.69243336 0.13782299\n", " 0.45937473]\n", " ...\n", " [-1.0868638 2.300204 -1.2638408 ... 1.7998282 0.5133892\n", " -0.52227837]\n", " [-1.2614481 2.7239661 -1.3733778 ... 1.444533 0.57420933\n", " -0.33201432]\n", " [-2.2067683 4.346218 -3.828867 ... 2.1426017 1.2033664\n", " -1.3795122 ]]\n", "\n", " ...\n", "\n", " [[-0.39141566 1.8553346 -0.5747178 ... 1.0062351 0.46320182\n", " -1.045236 ]\n", " [-0.86054784 2.0171793 -1.4436853 ... 1.452623 0.16571884\n", " 0.5923172 ]\n", " [-0.73066384 2.2840502 -1.0698992 ... 1.5824941 -0.0979555\n", " 0.55030036]\n", " ...\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]]\n", "\n", " [[-0.16194311 0.6255052 -1.1323429 ... 0.07242929 -0.22042468\n", " 0.46362036]\n", " [-0.08306468 0.575043 -1.09298 ... 0.9109665 -0.06501988\n", " 0.72986233]\n", " [-0.28202093 0.08014385 -0.9417719 ... 0.3379485 -0.11664233\n", " 0.44514441]\n", " ...\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]]\n", "\n", " [[-0.5458492 -0.69092435 -1.3596548 ... -0.78182435 0.68747747\n", " 0.9842716 ]\n", " [ 0.04212743 -1.1061852 -1.438915 ... -0.02385022 0.91146135\n", " 0.52870303]\n", " [-0.2909345 -0.18858244 -1.5487324 ... -0.13923697 0.05795169\n", " 0.30663735]\n", " ...\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]\n", " [-5.08209 8.592033 -4.2136674 ... 
6.269257 0.05394945\n", " -2.9269917 ]\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]]]\n" ] } ], "source": [ "xs = model.encoder.global_cmvn(feat)\n", "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", "xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)\n", "print(xs.cpu().detach().numpy())" ] }, { "cell_type": "code", "execution_count": 47, "id": "67e10d73", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 2.0908e-03],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 1.1943e-02, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 4.6105e-02, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 9.6723e-03,\n", " 4.6135e-02, 0.0000e+00]],\n", "\n", " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 
0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " ...,\n", "\n", " [[2.2816e-01, 2.4615e-01, 2.5304e-01, ..., 2.0402e-01,\n", " 2.3248e-01, 3.1191e-01],\n", " [1.3587e-01, 2.8877e-01, 2.7991e-01, ..., 1.9210e-01,\n", " 2.0346e-01, 1.9934e-01],\n", " [2.5739e-01, 3.9348e-01, 2.7877e-01, ..., 2.7483e-01,\n", " 1.9302e-01, 2.3810e-01],\n", " ...,\n", " [1.1939e-01, 2.8473e-01, 3.3082e-01, ..., 2.3838e-01,\n", " 2.2104e-01, 2.3906e-01],\n", " [1.7388e-01, 2.0402e-01, 4.0263e-01, ..., 2.4782e-01,\n", " 2.6742e-01, 1.5427e-01],\n", " [0.0000e+00, 2.9081e-01, 2.7726e-01, ..., 1.7540e-01,\n", " 1.8479e-01, 2.2483e-01]],\n", "\n", " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[3.5447e-01, 3.8861e-01, 3.9724e-01, ..., 3.8680e-01,\n", " 3.3568e-01, 3.4552e-01],\n", " [4.1739e-01, 5.1039e-01, 4.1730e-01, ..., 3.3993e-01,\n", " 3.7082e-01, 3.5110e-01],\n", " [3.6117e-01, 4.0745e-01, 4.8491e-01, ..., 3.4849e-01,\n", " 3.2321e-01, 3.5189e-01],\n", " ...,\n", " [2.3144e-01, 3.8021e-01, 5.1526e-01, ..., 3.6499e-01,\n", " 3.7412e-01, 3.9986e-01],\n", " [3.4679e-01, 4.0238e-01, 5.0077e-01, ..., 3.6185e-01,\n", " 3.1597e-01, 3.6335e-01],\n", " [3.6498e-01, 3.7943e-01, 5.1719e-01, ..., 3.1798e-01,\n", " 3.3657e-01, 3.4130e-01]]],\n", "\n", "\n", " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 
0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[1.4560e-02, 9.4475e-02, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [1.5002e-02, 2.9632e-02, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [3.2952e-02, 0.0000e+00, 0.0000e+00, ..., 4.5850e-02,\n", " 2.0439e-02, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 4.4258e-02],\n", " [0.0000e+00, 0.0000e+00, 2.5565e-02, ..., 0.0000e+00,\n", " 9.0044e-03, 4.9084e-02]],\n", "\n", " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [1.1141e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " ...,\n", "\n", " [[3.3697e-01, 3.8527e-01, 3.2900e-01, ..., 2.8704e-01,\n", " 2.3351e-01, 1.9004e-01],\n", " [1.3575e-01, 3.5783e-01, 3.3573e-01, ..., 2.2082e-01,\n", " 1.5855e-01, 1.3587e-01],\n", " [2.1929e-01, 2.8900e-01, 2.8255e-01, ..., 2.0603e-01,\n", " 2.3927e-01, 2.1909e-01],\n", " ...,\n", " [2.3292e-01, 3.9097e-01, 3.6399e-01, ..., 2.0598e-01,\n", " 
2.5374e-01, 2.3137e-01],\n", " [1.8739e-01, 3.0794e-01, 3.0297e-01, ..., 2.7251e-01,\n", " 2.5192e-01, 2.0837e-01],\n", " [2.2454e-01, 4.1402e-01, 5.4083e-01, ..., 3.1875e-01,\n", " 2.5080e-01, 2.5939e-01]],\n", "\n", " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[2.6457e-01, 4.9519e-01, 5.6702e-01, ..., 3.0955e-01,\n", " 3.5292e-01, 3.2669e-01],\n", " [2.1577e-01, 5.1833e-01, 4.9183e-01, ..., 3.6043e-01,\n", " 3.8524e-01, 3.6155e-01],\n", " [2.0068e-01, 4.2784e-01, 5.2818e-01, ..., 3.1871e-01,\n", " 3.2452e-01, 3.1036e-01],\n", " ...,\n", " [4.9855e-01, 5.1001e-01, 5.2279e-01, ..., 3.6450e-01,\n", " 3.4338e-01, 3.3603e-01],\n", " [4.1233e-01, 5.5518e-01, 5.2828e-01, ..., 4.0676e-01,\n", " 3.3873e-01, 3.6724e-01],\n", " [4.0820e-01, 4.6187e-01, 4.7338e-01, ..., 3.8691e-01,\n", " 3.6039e-01, 3.8022e-01]]],\n", "\n", "\n", " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[0.0000e+00, 5.7852e-03, 0.0000e+00, ..., 7.4838e-03,\n", " 0.0000e+00, 
0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 3.0351e-02,\n", " 0.0000e+00, 2.6720e-04],\n", " [9.4807e-04, 0.0000e+00, 0.0000e+00, ..., 7.9551e-03,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [2.0326e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 1.0801e-02, 0.0000e+00],\n", " [1.8470e-01, 0.0000e+00, 0.0000e+00, ..., 5.0584e-02,\n", " 9.4758e-02, 5.9146e-02]],\n", "\n", " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " ...,\n", "\n", " [[3.8708e-01, 2.8022e-01, 3.5893e-01, ..., 1.6595e-01,\n", " 1.6031e-01, 2.1136e-01],\n", " [1.5595e-01, 3.0544e-01, 2.4666e-01, ..., 2.2675e-01,\n", " 2.5765e-01, 1.9682e-01],\n", " [2.9518e-01, 4.1210e-01, 2.0063e-01, ..., 1.7595e-01,\n", " 2.2537e-01, 2.2214e-01],\n", " ...,\n", " [2.4745e-01, 2.6259e-01, 3.8654e-01, ..., 2.3620e-01,\n", " 2.3157e-01, 1.8514e-01],\n", " [2.5715e-01, 2.9593e-01, 4.7745e-01, ..., 2.3546e-01,\n", " 2.5073e-01, 2.0976e-01],\n", " [1.2015e+00, 8.4644e-01, 7.3386e-01, ..., 1.0252e+00,\n", " 9.5310e-01, 1.0013e+00]],\n", "\n", " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 
0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[4.5013e-01, 4.7484e-01, 4.0540e-01, ..., 1.9346e-01,\n", " 1.7826e-01, 1.4777e-01],\n", " [4.7546e-01, 4.8187e-01, 3.6760e-01, ..., 2.7809e-01,\n", " 3.2997e-01, 3.2337e-01],\n", " [4.6160e-01, 4.0050e-01, 3.9061e-01, ..., 3.6613e-01,\n", " 3.5243e-01, 2.9739e-01],\n", " ...,\n", " [5.5148e-01, 5.1018e-01, 4.0132e-01, ..., 3.8948e-01,\n", " 3.5737e-01, 3.3088e-01],\n", " [4.1973e-01, 4.5475e-01, 4.5320e-01, ..., 3.8343e-01,\n", " 4.0126e-01, 3.6181e-01],\n", " [3.4280e-01, 3.1606e-01, 4.4701e-01, ..., 2.1665e-01,\n", " 2.3985e-01, 2.3903e-01]]],\n", "\n", "\n", " ...,\n", "\n", "\n", " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [4.1783e-02, 0.0000e+00, 1.5805e-02, ..., 0.0000e+00,\n", " 2.2508e-02, 0.0000e+00],\n", " [4.3234e-02, 7.7864e-02, 0.0000e+00, ..., 1.6347e-02,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[3.2092e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 
0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [1.3563e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " ...,\n", "\n", " [[0.0000e+00, 2.5187e-01, 2.4979e-01, ..., 2.4775e-01,\n", " 2.2354e-01, 1.9149e-01],\n", " [1.6541e-01, 1.9586e-01, 1.9813e-01, ..., 2.7344e-01,\n", " 2.0928e-01, 2.6150e-01],\n", " [1.0495e-01, 6.3299e-02, 3.3844e-01, ..., 2.5138e-01,\n", " 1.2470e-01, 2.3927e-01],\n", " ...,\n", " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", " 1.0094e+00, 1.0221e+00],\n", " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", " 1.0094e+00, 1.0221e+00],\n", " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", " 1.0094e+00, 1.0221e+00]],\n", "\n", " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[1.1428e-01, 4.5667e-01, 4.6821e-01, ..., 3.2058e-01,\n", " 3.3579e-01, 3.9013e-01],\n", " [1.0441e-01, 4.5739e-01, 4.6107e-01, ..., 3.8468e-01,\n", " 3.8291e-01, 3.6686e-01],\n", " [1.9868e-01, 3.5520e-01, 4.4313e-01, ..., 4.0679e-01,\n", " 3.8068e-01, 3.0646e-01],\n", " ...,\n", " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", " 
1.2189e+00, 1.2380e+00],\n", " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", " 1.2189e+00, 1.2380e+00],\n", " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", " 1.2189e+00, 1.2380e+00]]],\n", "\n", "\n", " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[2.4654e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 3.3902e-02],\n", " [0.0000e+00, 0.0000e+00, 1.8307e-02, ..., 5.1669e-02,\n", " 9.4838e-03, 7.4535e-02],\n", " [9.9215e-02, 0.0000e+00, 1.5872e-02, ..., 1.6203e-02,\n", " 5.1401e-02, 1.9239e-03],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " ...,\n", "\n", " [[4.0034e-01, 2.5306e-01, 2.0218e-01, ..., 9.8162e-02,\n", 
" 7.0643e-02, 4.9741e-02],\n", " [1.2568e-01, 2.1031e-01, 1.1182e-01, ..., 4.2781e-02,\n", " 1.1969e-01, 1.2005e-01],\n", " [2.8787e-01, 2.4031e-01, 2.2566e-01, ..., 0.0000e+00,\n", " 6.4181e-02, 5.8730e-02],\n", " ...,\n", " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", " 1.0094e+00, 1.0221e+00],\n", " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", " 1.0094e+00, 1.0221e+00],\n", " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", " 1.0094e+00, 1.0221e+00]],\n", "\n", " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[3.8405e-01, 3.0990e-01, 3.7156e-01, ..., 1.8125e-01,\n", " 1.5051e-01, 1.9620e-01],\n", " [4.7286e-01, 4.0529e-01, 3.9718e-01, ..., 2.4710e-01,\n", " 4.5657e-02, 1.1501e-01],\n", " [3.2621e-01, 3.0073e-01, 3.0477e-01, ..., 2.3529e-01,\n", " 2.1357e-01, 1.6986e-01],\n", " ...,\n", " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", " 1.2189e+00, 1.2380e+00],\n", " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", " 1.2189e+00, 1.2380e+00],\n", " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", " 1.2189e+00, 1.2380e+00]]],\n", "\n", "\n", " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 
0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[3.3438e-02, 1.2378e-03, 5.2972e-02, ..., 7.2712e-02,\n", " 8.6563e-02, 1.4494e-01],\n", " [1.1043e-01, 6.1431e-02, 6.3630e-02, ..., 8.1278e-02,\n", " 6.2590e-02, 8.3154e-02],\n", " [1.7677e-02, 2.0111e-03, 7.8750e-02, ..., 6.9633e-02,\n", " 8.9799e-02, 5.3263e-02],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[1.0034e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [1.5627e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [5.1447e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 4.3641e-03],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " ...,\n", "\n", " [[2.5142e-01, 4.5964e-01, 3.7346e-01, ..., 4.7631e-02,\n", " 0.0000e+00, 0.0000e+00],\n", " [1.9760e-01, 2.6627e-01, 1.1191e-01, ..., 3.0450e-02,\n", " 0.0000e+00, 0.0000e+00],\n", " [1.6341e-01, 3.2938e-01, 2.5690e-01, ..., 5.5694e-02,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", " 1.0094e+00, 1.0221e+00],\n", " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", " 1.0094e+00, 1.0221e+00],\n", " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", " 1.0094e+00, 1.0221e+00]],\n", "\n", " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 
2.2189e-02, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 2.8490e-02],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " ...,\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]],\n", "\n", " [[2.5810e-01, 6.3017e-01, 3.7038e-01, ..., 1.8704e-01,\n", " 8.2694e-02, 9.9127e-02],\n", " [1.7293e-01, 5.0679e-01, 4.0739e-01, ..., 1.6006e-01,\n", " 1.1725e-01, 9.9405e-02],\n", " [2.4175e-01, 4.1616e-01, 4.1257e-01, ..., 1.3520e-01,\n", " 7.9126e-02, 1.2846e-01],\n", " ...,\n", " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", " 1.2189e+00, 1.2380e+00],\n", " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", " 1.2189e+00, 1.2380e+00],\n", " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", " 1.2189e+00, 1.2380e+00]]]], device='cuda:0',\n", " grad_fn=)\n" ] } ], "source": [ "xs = model.encoder.global_cmvn(feat)\n", "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", "\n", "x = xs.unsqueeze(1)\n", "x = model.encoder.embed.conv(x)\n", "print(x)" ] }, { "cell_type": "code", "execution_count": 48, "id": "9a9478ad", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[[-0.03426375 0.14291267 -0.06718873 ... 0.09064753 0.01809387\n", " -0.0434088 ]\n", " [-0.05007839 0.11054724 -0.10399298 ... 0.11457238 0.04244684\n", " -0.01249714]\n", " [-0.10695291 0.16910909 -0.08352133 ... 0.07710276 0.01168563\n", " -0.03584499]\n", " ...\n", " [-0.06060536 0.14455931 -0.05470302 ... 0.05364908 0.03033342\n", " -0.02610814]\n", " [-0.08505894 0.13611752 -0.11132983 ... 0.13079923 0.01580139\n", " -0.02281028]\n", " [-0.10604677 0.14714901 -0.10885533 ... 
0.08543444 0.03719445\n", " -0.04634233]]\n", "\n", " [[-0.12392755 0.14486063 -0.05674079 ... 0.02573164 0.03128851\n", " 0.00545091]\n", " [-0.04775286 0.08473608 -0.08507854 ... 0.04573154 0.04240163\n", " 0.01053247]\n", " [-0.05940291 0.10023535 -0.0814373 ... 0.035965 0.01673085\n", " 0.02089563]\n", " ...\n", " [-0.09222981 0.15823206 -0.07700447 ... 0.08122957 0.03136991\n", " -0.00646474]\n", " [-0.07331756 0.14482647 -0.07838815 ... 0.1086944 0.01356864\n", " -0.02777974]\n", " [-0.07937264 0.20143102 -0.05544947 ... 0.10287814 0.00608235\n", " -0.0479918 ]]\n", "\n", " [[-0.03670349 0.0893159 -0.08718812 ... 0.0131405 0.00642052\n", " 0.00573716]\n", " [ 0.01089254 0.11146393 -0.10263617 ... 0.05070438 0.01960694\n", " 0.03521532]\n", " [-0.0218228 0.11443964 -0.06678198 ... 0.04327708 0.00861394\n", " 0.02871092]\n", " ...\n", " [-0.06792898 0.14376275 -0.07899005 ... 0.11248926 0.03208683\n", " -0.0326424 ]\n", " [-0.07884051 0.17024788 -0.08583611 ... 0.09028331 0.03588808\n", " -0.0207509 ]\n", " [-0.13792302 0.27163863 -0.23930418 ... 0.13391261 0.0752104\n", " -0.08621951]]\n", "\n", " ...\n", "\n", " [[-0.02446348 0.11595841 -0.03591986 ... 0.0628897 0.02895011\n", " -0.06532725]\n", " [-0.05378424 0.1260737 -0.09023033 ... 0.09078894 0.01035743\n", " 0.03701983]\n", " [-0.04566649 0.14275314 -0.0668687 ... 0.09890588 -0.00612222\n", " 0.03439377]\n", " ...\n", " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", " -0.18293698]\n", " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", " -0.18293698]\n", " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", " -0.18293698]]\n", "\n", " [[-0.01012144 0.03909408 -0.07077143 ... 0.00452683 -0.01377654\n", " 0.02897627]\n", " [-0.00519154 0.03594019 -0.06831125 ... 0.05693541 -0.00406374\n", " 0.0456164 ]\n", " [-0.01762631 0.00500899 -0.05886075 ... 0.02112178 -0.00729015\n", " 0.02782153]\n", " ...\n", " [-0.31763062 0.5370021 -0.2633542 ... 
0.39182857 0.00337184\n", " -0.18293698]\n", " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", " -0.18293698]\n", " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", " -0.18293698]]\n", "\n", " [[-0.03411558 -0.04318277 -0.08497842 ... -0.04886402 0.04296734\n", " 0.06151697]\n", " [ 0.00263296 -0.06913657 -0.08993219 ... -0.00149064 0.05696633\n", " 0.03304394]\n", " [-0.01818341 -0.0117864 -0.09679577 ... -0.00870231 0.00362198\n", " 0.01916483]\n", " ...\n", " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", " -0.18293698]\n", " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", " -0.18293698]\n", " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", " -0.18293698]]]\n", "torch.Size([16, 51, 256])\n" ] } ], "source": [ "b, c, t, f = x.size()\n", "x = model.encoder.embed.out(x.transpose(1, 2).contiguous().view(b, t, c * f))\n", "print(x.cpu().detach().numpy())\n", "print(x.shape)" ] }, { "cell_type": "code", "execution_count": 49, "id": "fd69003f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[[-0.54822 2.2866027 -1.0750197 ... 1.4503604 0.28950194\n", " -0.6945408 ]\n", " [-0.8012542 1.7687558 -1.6638877 ... 1.833158 0.6791494\n", " -0.1999542 ]\n", " [-1.7112465 2.7057455 -1.3363413 ... 1.2336441 0.18697014\n", " -0.5735198 ]\n", " ...\n", " [-0.96968573 2.312949 -0.87524825 ... 0.85838526 0.4853347\n", " -0.41773027]\n", " [-1.3609431 2.1778803 -1.7812773 ... 2.0927877 0.25282228\n", " -0.36496443]\n", " [-1.6967483 2.3543842 -1.7416853 ... 1.366951 0.59511113\n", " -0.74147725]]\n", "\n", " [[-1.9828408 2.31777 -0.9078527 ... 0.41170627 0.5006162\n", " 0.08721463]\n", " [-0.76404583 1.3557773 -1.3612567 ... 0.7317046 0.678426\n", " 0.16851945]\n", " [-0.95044655 1.6037656 -1.3029968 ... 0.57544005 0.26769355\n", " 0.33433008]\n", " ...\n", " [-1.475677 2.531713 -1.2320715 ... 
1.2996731 0.50191855\n", " -0.10343577]\n", " [-1.1730809 2.3172235 -1.2542105 ... 1.7391105 0.21709818\n", " -0.44447583]\n", " [-1.2699623 3.2228963 -0.8871915 ... 1.6460502 0.09731755\n", " -0.7678688 ]]\n", "\n", " [[-0.5872559 1.4290544 -1.3950099 ... 0.21024795 0.10272825\n", " 0.09179455]\n", " [ 0.1742807 1.783423 -1.6421788 ... 0.8112701 0.31371105\n", " 0.56344515]\n", " [-0.34916472 1.8310343 -1.0685117 ... 0.69243336 0.13782299\n", " 0.45937473]\n", " ...\n", " [-1.0868638 2.300204 -1.2638408 ... 1.7998282 0.5133892\n", " -0.52227837]\n", " [-1.2614481 2.7239661 -1.3733778 ... 1.444533 0.57420933\n", " -0.33201432]\n", " [-2.2067683 4.346218 -3.828867 ... 2.1426017 1.2033664\n", " -1.3795122 ]]\n", "\n", " ...\n", "\n", " [[-0.39141566 1.8553346 -0.5747178 ... 1.0062351 0.46320182\n", " -1.045236 ]\n", " [-0.86054784 2.0171793 -1.4436853 ... 1.452623 0.16571884\n", " 0.5923172 ]\n", " [-0.73066384 2.2840502 -1.0698992 ... 1.5824941 -0.0979555\n", " 0.55030036]\n", " ...\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]]\n", "\n", " [[-0.16194311 0.6255052 -1.1323429 ... 0.07242929 -0.22042468\n", " 0.46362036]\n", " [-0.08306468 0.575043 -1.09298 ... 0.9109665 -0.06501988\n", " 0.72986233]\n", " [-0.28202093 0.08014385 -0.9417719 ... 0.3379485 -0.11664233\n", " 0.44514441]\n", " ...\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]]\n", "\n", " [[-0.5458492 -0.69092435 -1.3596548 ... -0.78182435 0.68747747\n", " 0.9842716 ]\n", " [ 0.04212743 -1.1061852 -1.438915 ... -0.02385022 0.91146135\n", " 0.52870303]\n", " [-0.2909345 -0.18858244 -1.5487324 ... 
-0.13923697 0.05795169\n", " 0.30663735]\n", " ...\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]\n", " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", " -2.9269917 ]]]\n" ] } ], "source": [ "x, pos_emb = model.encoder.embed.pos_enc(x, 0)\n", "print(x.cpu().detach().numpy())" ] }, { "cell_type": "code", "execution_count": 50, "id": "8ed88489", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.float32\n", "[[[ 0.0000000e+00 1.0000000e+00 0.0000000e+00 ... 1.0000000e+00\n", " 0.0000000e+00 1.0000000e+00]\n", " [ 8.4147096e-01 5.4030234e-01 8.0196178e-01 ... 1.0000000e+00\n", " 1.0746076e-04 1.0000000e+00]\n", " [ 9.0929741e-01 -4.1614684e-01 9.5814437e-01 ... 1.0000000e+00\n", " 2.1492151e-04 1.0000000e+00]\n", " ...\n", " [-7.6825464e-01 -6.4014435e-01 6.3279724e-01 ... 9.9998462e-01\n", " 5.1580933e-03 9.9998671e-01]\n", " [-9.5375264e-01 3.0059254e-01 9.9899054e-01 ... 9.9998397e-01\n", " 5.2655530e-03 9.9998611e-01]\n", " [-2.6237485e-01 9.6496606e-01 5.6074661e-01 ... 
9.9998331e-01\n", " 5.3730118e-03 9.9998558e-01]]]\n" ] } ], "source": [ "print(pos_emb.dtype)\n", "print(pos_emb.cpu().detach().numpy())" ] }, { "cell_type": "code", "execution_count": 54, "id": "5e277881", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([16, 51, 256])\n" ] }, { "ename": "NameError", "evalue": "name 'mask' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0mpos_emb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpos_emb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 143\u001b[0;31m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 
144\u001b[0m \u001b[0mx_att\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 145\u001b[0m )\n", "\u001b[0;31mNameError\u001b[0m: name 'mask' is not defined" ] } ], "source": [ "def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,\n", " use_dynamic_chunk: bool,\n", " use_dynamic_left_chunk: bool,\n", " decoding_chunk_size: int, static_chunk_size: int,\n", " num_decoding_left_chunks: int):\n", " \"\"\" Apply optional mask for encoder.\n", " Args:\n", " xs (torch.Tensor): padded input, (B, L, D), L for max length\n", " mask (torch.Tensor): mask for xs, (B, 1, L)\n", " use_dynamic_chunk (bool): whether to use dynamic chunk or not\n", " use_dynamic_left_chunk (bool): whether to use dynamic left chunk for\n", " training.\n", " decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's\n", " 0: default for training, use random dynamic chunk.\n", " <0: for decoding, use full chunk.\n", " >0: for decoding, use fixed chunk size as set.\n", " static_chunk_size (int): chunk size for static chunk training/decoding\n", " if it's greater than 0, if use_dynamic_chunk is true,\n", " this parameter will be ignored\n", " num_decoding_left_chunks: number of left chunks, this is for decoding,\n", " the chunk size is decoding_chunk_size.\n", " >=0: use num_decoding_left_chunks\n", " <0: use all left chunks\n", " Returns:\n", " torch.Tensor: chunk mask of the input xs.\n", " \"\"\"\n", " # Whether to use chunk mask or not\n", " if use_dynamic_chunk:\n", " max_len = xs.size(1)\n", " if decoding_chunk_size < 0:\n", " chunk_size = max_len\n", " num_left_chunks = -1\n", " elif decoding_chunk_size > 0:\n", " 
chunk_size = decoding_chunk_size\n", " num_left_chunks = num_decoding_left_chunks\n", " else:\n", " # chunk size is either [1, 25] or full context(max_len).\n", " # Since we use 4 times subsampling and allow up to 1s(100 frames)\n", " # delay, the maximum frame is 100 / 4 = 25.\n", " chunk_size = torch.randint(1, max_len, (1, )).item()\n", " num_left_chunks = -1\n", " if chunk_size > max_len // 2:\n", " chunk_size = max_len\n", " else:\n", " chunk_size = chunk_size % 25 + 1\n", " if use_dynamic_left_chunk:\n", " max_left_chunks = (max_len - 1) // chunk_size\n", " num_left_chunks = torch.randint(0, max_left_chunks,\n", " (1, )).item()\n", " chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,\n", " num_left_chunks,\n", " xs.device) # (L, L)\n", " chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)\n", " chunk_masks = masks & chunk_masks # (B, L, L)\n", " elif static_chunk_size > 0:\n", " num_left_chunks = num_decoding_left_chunks\n", " chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,\n", " num_left_chunks,\n", " xs.device) # (L, L)\n", " chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)\n", " chunk_masks = masks & chunk_masks # (B, L, L)\n", " else:\n", " chunk_masks = masks\n", " return chunk_masks\n", "\n", "from wenet.utils.mask import make_pad_mask\n", "\n", "\n", "masks = ~make_pad_mask(feat_len).unsqueeze(1)\n", "xs = model.encoder.global_cmvn(feat)\n", "xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)\n", "\n", "mask_pad = masks\n", "decoding_chunk_size=0\n", "num_decoding_left_chunks=-1\n", "use_dynamic_left_chunk=-1\n", "use_dynamic_chunk=False\n", "static_chunk_size=-1\n", "chunk_masks = add_optional_chunk_mask(\n", " xs, \n", " masks, \n", " use_dynamic_chunk,\n", " use_dynamic_left_chunk,\n", " decoding_chunk_size, \n", " static_chunk_size,\n", " num_decoding_left_chunks)\n", "\n", "np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_embed', \n", " embed_out=xs.cpu().detach().numpy(), \n", " 
pos_emb=pos_emb.cpu().detach().numpy(),\n", " chunk_masks=chunk_masks.cpu().detach().numpy(),\n", " mask_pad=mask_pad.cpu().detach().numpy())\n", "\n", "model.eval()\n", "# print(chunk_masks)\n", "print(xs.shape)\n", "for layer in model.encoder.encoders:\n", " #xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", " #np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0', enc_0=xs.cpu().detach().numpy())\n", " \n", " x = xs\n", " residual = x\n", " x_norm = layer.norm_ff_macaron(x)\n", " !rm /workspace/DeepSpeech-2.x/.notebook/enc_0_norm_ff.npz\n", " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_norm_ff', \n", " norm_ff=x_norm.cpu().detach().numpy(),\n", " xs=xs.cpu().detach().numpy())\n", " #print(x.cpu().detach().numpy())\n", " for p in layer.norm_ff_macaron.parameters():\n", " #print(p, p.sum())\n", " pass\n", " \n", " x = residual + layer.ff_scale * layer.feed_forward_macaron(x_norm)\n", " \n", " ps = []\n", " for n, p in layer.feed_forward_macaron.state_dict().items():\n", " #print(n, p.cpu().data.numpy())\n", " ps.append(p.cpu().data.numpy())\n", " pass\n", "\n", " ff_l_x = layer.feed_forward_macaron.w_1(x_norm)\n", " ff_l_a_x = layer.feed_forward_macaron.activation(ff_l_x)\n", " ff_l_a_l_x = layer.feed_forward_macaron.w_2(ff_l_a_x)\n", " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_ff_out', \n", " norm_ff=x_norm.cpu().detach().numpy(),\n", " ff_out=x.cpu().detach().numpy(),\n", " ff_l_x = ff_l_x.cpu().detach().numpy(),\n", " ff_l_a_x=ff_l_a_x.cpu().detach().numpy(),\n", " ff_l_a_l_x=ff_l_a_l_x.cpu().detach().numpy(),\n", " ps=ps,\n", " )\n", " \n", " \n", " residual = x\n", " x = layer.norm_mha(x)\n", " x_q = x\n", " \n", " x_att = layer.self_attn(x_q, x, x, pos_emb, masks)\n", " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_selattn_out', \n", " x_q=x_q.cpu().detach().numpy(),\n", " x=x.cpu().detach().numpy(),\n", " pos_emb = pos_emb.cpu().detach().numpy(),\n", " mask=mask.cpu().detach().numpy(),\n", " 
x_att=x_att.cpu().detach().numpy(),\n", " )\n", " \n", " break\n", "#print(xs.cpu().detach().numpy())\n", "\n", "\n", "i = 0\n", "for layer in model.encoder.encoders:\n", " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", " i += 1\n", " if i == 2:\n", " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_2', enc_2=xs.cpu().detach().numpy())\n", " \n", "np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_all', enc_all=xs.cpu().detach().numpy())" ] }, { "cell_type": "code", "execution_count": null, "id": "c43fd4f1", "metadata": {}, "outputs": [], "source": [ "out, mask = model.encoder(feat, feat_len)\n", "#print(out.cpu().detach().numpy())" ] }, { "cell_type": "code", "execution_count": null, "id": "0e73db22", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "8f506114", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 5 }