You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/.notebook/train_test.ipynb

1887 lines
98 KiB

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "cloudy-glass",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['CUDA_VISISBLE_DEVICES'] = '0'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "grand-stephen",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.0.0\n"
]
}
],
"source": [
"import paddle\n",
"print(paddle.__version__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "isolated-prize",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 3,
"id": "romance-samuel",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
" \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('infer_manifest', str,\n",
" 'examples/aishell/data/manifest.dev',\n",
" \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
" 'examples/aishell/data/mean_std.npz',\n",
" \"Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
" 'examples/aishell/data/vocab.txt',\n",
" \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
" 'models/lm/common_crawl_00.prune01111.trie.klm',\n",
" \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
" 'examples/aishell/checkpoints/step_final',\n",
" \"If None, the training starts from scratch, \"\n",
" \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
" 'ctc_beam_search',\n",
" \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
" choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
" 'wer',\n",
" \"Error rate type for evaluation.\",\n",
" choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
" 'linear',\n",
" \"Audio feature type. Options: linear, mfcc.\",\n",
" choices=['linear', 'mfcc'])\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "timely-bikini",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" long_ = _make_signed(np.long)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" ulong = _make_unsigned(np.long)\n"
]
}
],
"source": [
"from data_utils.dataset import create_dataloader\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" augmentation_config='{}',\n",
" #max_duration=float('inf'),\n",
" max_duration=27.0,\n",
" min_duration=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=False,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "organized-warrior",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" if arr.dtype == np.object:\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[14 , 34 , 322 , 233 , 0 , 0 ],\n",
" [238 , 38 , 122 , 164 , 0 , 0 ],\n",
" [8 , 52 , 49 , 42 , 0 , 0 ],\n",
" [109 , 47 , 146 , 193 , 210 , 479 ],\n",
" [3330, 1751, 208 , 1923, 0 , 0 ]])\n",
"test raw 大时代里的的\n",
"test raw 煲汤受宠的的\n",
"audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 167, 180, 186, 186])\n",
"test len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [4, 4, 4, 6, 4])\n",
"audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n",
" [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n",
" [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n",
" [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n",
" [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n",
" [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n",
" [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n",
" [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n",
" [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n",
" [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n",
" [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n",
" [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n",
" [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n",
" [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n",
" [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n",
" ...,\n",
" [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n",
" [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n",
" [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n",
"\n",
" [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n",
" [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n",
" [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n",
" ...,\n",
" [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n",
" [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n",
" [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n"
]
}
],
"source": [
" for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test', text)\n",
" print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[0]))\n",
" print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[-1]))\n",
" print('audio len', audio_len)\n",
" print('test len', text_len)\n",
" print('audio', audio)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "confidential-radius",
"metadata": {},
"outputs": [],
"source": [
"# reader = batch_reader()\n",
"# audio, test , audio_len, text_len = reader.next()\n",
"# print('test', text)\n",
"# print('t len', text_len) #[B, T]\n",
"# print('audio len', audio_len)\n",
"# print(audio)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "future-vermont",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"煲汤受宠\n"
]
}
],
"source": [
"print(u'\\u7172\\u6c64\\u53d7\\u5ba0')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dental-sweden",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "sunrise-contact",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "hispanic-asthma",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "hearing-leadership",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "skilled-friday",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "copyrighted-measure",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 21,
"id": "employed-lightweight",
"metadata": {},
"outputs": [],
"source": [
"from model_utils.network import DeepSpeech2, DeepSpeech2Loss\n",
"\n",
"from data_utils.dataset import create_dataloader\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" augmentation_config='{}',\n",
" #max_duration=float('inf'),\n",
" max_duration=27.0,\n",
" min_duration=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=False,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)\n",
"\n",
"\n",
"import paddle\n",
"from paddle import nn\n",
"from paddle.nn import functional as F\n",
"from paddle.nn import initializer as I\n",
"\n",
"import math\n",
"\n",
"def brelu(x, t_min=0.0, t_max=24.0, name=None):\n",
" t_min = paddle.to_tensor(t_min)\n",
" t_max = paddle.to_tensor(t_max)\n",
" return x.maximum(t_min).minimum(t_max)\n",
"\n",
"def sequence_mask(x_len, max_len=None, dtype='float32'):\n",
" max_len = max_len or x_len.max()\n",
" x_len = paddle.unsqueeze(x_len, -1)\n",
" row_vector = paddle.arange(max_len)\n",
" mask = row_vector > x_len # maybe a bug\n",
" mask = paddle.cast(mask, dtype)\n",
" print(f'seq mask: {mask}')\n",
" return mask\n",
"\n",
"\n",
"class ConvBn(nn.Layer):\n",
" def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,\n",
" padding, act):\n",
"\n",
" super().__init__()\n",
" self.kernel_size = kernel_size\n",
" self.stride = stride\n",
" self.padding = padding\n",
"\n",
" self.conv = nn.Conv2D(\n",
" num_channels_in,\n",
" num_channels_out,\n",
" kernel_size=kernel_size,\n",
" stride=stride,\n",
" padding=padding,\n",
" weight_attr=None,\n",
" bias_attr=None,\n",
" data_format='NCHW')\n",
"\n",
" self.bn = nn.BatchNorm2D(\n",
" num_channels_out,\n",
" weight_attr=None,\n",
" bias_attr=None,\n",
" data_format='NCHW')\n",
" self.act = F.relu if act == 'relu' else brelu\n",
"\n",
" def forward(self, x, x_len):\n",
" \"\"\"\n",
" x(Tensor): audio, shape [B, C, D, T]\n",
" \"\"\"\n",
" x = self.conv(x)\n",
" x = self.bn(x)\n",
" x = self.act(x)\n",
"\n",
" x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]\n",
" ) // self.stride[1] + 1\n",
"\n",
" # reset padding part to 0\n",
" masks = sequence_mask(x_len) #[B, T]\n",
" masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]\n",
" x = x.multiply(masks)\n",
"\n",
" return x, x_len\n",
"\n",
"\n",
"class ConvStack(nn.Layer):\n",
" def __init__(self, feat_size, num_stacks):\n",
" super().__init__()\n",
" self.feat_size = feat_size # D\n",
" self.num_stacks = num_stacks\n",
"\n",
" self.conv_in = ConvBn(\n",
" num_channels_in=1,\n",
" num_channels_out=32,\n",
" kernel_size=(41, 11), #[D, T]\n",
" stride=(2, 3),\n",
" padding=(20, 5),\n",
" act='brelu')\n",
"\n",
" out_channel = 32\n",
" self.conv_stack = nn.Sequential([\n",
" ConvBn(\n",
" num_channels_in=32,\n",
" num_channels_out=out_channel,\n",
" kernel_size=(21, 11),\n",
" stride=(2, 1),\n",
" padding=(10, 5),\n",
" act='brelu') for i in range(num_stacks - 1)\n",
" ])\n",
"\n",
" # conv output feat_dim\n",
" output_height = (feat_size - 1) // 2 + 1\n",
" for i in range(self.num_stacks - 1):\n",
" output_height = (output_height - 1) // 2 + 1\n",
" self.output_height = out_channel * output_height\n",
"\n",
" def forward(self, x, x_len):\n",
" \"\"\"\n",
" x: shape [B, C, D, T]\n",
" x_len : shape [B]\n",
" \"\"\"\n",
" print(f\"conv in: {x_len}\")\n",
" x, x_len = self.conv_in(x, x_len)\n",
" for i, conv in enumerate(self.conv_stack):\n",
" print(f\"conv in: {x_len}\")\n",
" x, x_len = conv(x, x_len)\n",
" print(f\"conv out: {x_len}\")\n",
" return x, x_len\n",
" \n",
" \n",
"\n",
"class RNNCell(nn.RNNCellBase):\n",
" r\"\"\"\n",
" Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it \n",
" computes the outputs and updates states.\n",
" The formula used is as follows:\n",
" .. math::\n",
" h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})\n",
" y_{t} & = h_{t}\n",
" \n",
" where :math:`act` is for :attr:`activation`.\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
" hidden_size,\n",
" activation=\"tanh\",\n",
" weight_ih_attr=None,\n",
" weight_hh_attr=None,\n",
" bias_ih_attr=None,\n",
" bias_hh_attr=None,\n",
" name=None):\n",
" super().__init__()\n",
" std = 1.0 / math.sqrt(hidden_size)\n",
" self.weight_hh = self.create_parameter(\n",
" (hidden_size, hidden_size),\n",
" weight_hh_attr,\n",
" default_initializer=I.Uniform(-std, std))\n",
" # self.bias_ih = self.create_parameter(\n",
" # (hidden_size, ),\n",
" # bias_ih_attr,\n",
" # is_bias=True,\n",
" # default_initializer=I.Uniform(-std, std))\n",
" self.bias_ih = None\n",
" self.bias_hh = self.create_parameter(\n",
" (hidden_size, ),\n",
" bias_hh_attr,\n",
" is_bias=True,\n",
" default_initializer=I.Uniform(-std, std))\n",
"\n",
" self.hidden_size = hidden_size\n",
" if activation not in [\"tanh\", \"relu\", \"brelu\"]:\n",
" raise ValueError(\n",
" \"activation for SimpleRNNCell should be tanh or relu, \"\n",
" \"but get {}\".format(activation))\n",
" self.activation = activation\n",
" self._activation_fn = paddle.tanh \\\n",
" if activation == \"tanh\" \\\n",
" else F.relu\n",
" if activation == 'brelu':\n",
" self._activation_fn = brelu\n",
"\n",
" def forward(self, inputs, states=None):\n",
" if states is None:\n",
" states = self.get_initial_states(inputs, self.state_shape)\n",
" pre_h = states\n",
" i2h = inputs\n",
" if self.bias_ih is not None:\n",
" i2h += self.bias_ih\n",
" h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)\n",
" if self.bias_hh is not None:\n",
" h2h += self.bias_hh\n",
" h = self._activation_fn(i2h + h2h)\n",
" return h, h\n",
"\n",
" @property\n",
" def state_shape(self):\n",
" return (self.hidden_size, )\n",
"\n",
"\n",
"class GRUCellShare(nn.RNNCellBase):\n",
" r\"\"\"\n",
" Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, \n",
" it computes the outputs and updates states.\n",
" The formula for GRU used is as follows:\n",
" .. math::\n",
" r_{t} & = \\sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})\n",
" z_{t} & = \\sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})\n",
" \\widetilde{h}_{t} & = \\tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))\n",
" h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \\widetilde{h}_{t}\n",
" y_{t} & = h_{t}\n",
" \n",
" where :math:`\\sigma` is the sigmoid fucntion, and * is the elemetwise \n",
" multiplication operator.\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
" input_size,\n",
" hidden_size,\n",
" weight_ih_attr=None,\n",
" weight_hh_attr=None,\n",
" bias_ih_attr=None,\n",
" bias_hh_attr=None,\n",
" name=None):\n",
" super().__init__()\n",
" std = 1.0 / math.sqrt(hidden_size)\n",
" self.weight_hh = self.create_parameter(\n",
" (3 * hidden_size, hidden_size),\n",
" weight_hh_attr,\n",
" default_initializer=I.Uniform(-std, std))\n",
" # self.bias_ih = self.create_parameter(\n",
" # (3 * hidden_size, ),\n",
" # bias_ih_attr,\n",
" # is_bias=True,\n",
" # default_initializer=I.Uniform(-std, std))\n",
" self.bias_ih = None\n",
" self.bias_hh = self.create_parameter(\n",
" (3 * hidden_size, ),\n",
" bias_hh_attr,\n",
" is_bias=True,\n",
" default_initializer=I.Uniform(-std, std))\n",
"\n",
" self.hidden_size = hidden_size\n",
" self.input_size = input_size\n",
" self._gate_activation = F.sigmoid\n",
" #self._activation = paddle.tanh\n",
" self._activation = F.relu\n",
"\n",
" def forward(self, inputs, states=None):\n",
" if states is None:\n",
" states = self.get_initial_states(inputs, self.state_shape)\n",
"\n",
" pre_hidden = states\n",
" x_gates = inputs\n",
" if self.bias_ih is not None:\n",
" x_gates = x_gates + self.bias_ih\n",
" h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)\n",
" if self.bias_hh is not None:\n",
" h_gates = h_gates + self.bias_hh\n",
"\n",
" x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)\n",
" h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)\n",
"\n",
" r = self._gate_activation(x_r + h_r)\n",
" z = self._gate_activation(x_z + h_z)\n",
" c = self._activation(x_c + r * h_c) # apply reset gate after mm\n",
" h = (pre_hidden - c) * z + c\n",
" # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru\n",
" #h = (1-z) * pre_hidden + z * c\n",
"\n",
" return h, h\n",
"\n",
" @property\n",
" def state_shape(self):\n",
" r\"\"\"\n",
" The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch\n",
" size would be automatically inserted into shape). The shape corresponds\n",
" to the shape of :math:`h_{t-1}`.\n",
" \"\"\"\n",
" return (self.hidden_size, )\n",
"\n",
"\n",
"class BiRNNWithBN(nn.Layer):\n",
" \"\"\"Bidirectonal simple rnn layer with sequence-wise batch normalization.\n",
" The batch normalization is only performed on input-state weights.\n",
"\n",
" :param name: Name of the layer parameters.\n",
" :type name: string\n",
" :param size: Dimension of RNN cells.\n",
" :type size: int\n",
" :param share_weights: Whether to share input-hidden weights between\n",
" forward and backward directional RNNs.\n",
" :type share_weights: bool\n",
" :return: Bidirectional simple rnn layer.\n",
" :rtype: Variable\n",
" \"\"\"\n",
"\n",
" def __init__(self, i_size, h_size, share_weights):\n",
" super().__init__()\n",
" self.share_weights = share_weights\n",
" if self.share_weights:\n",
" #input-hidden weights shared between bi-directional rnn.\n",
" self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
" # batch norm is only performed on input-state projection\n",
" self.fw_bn = nn.BatchNorm1D(\n",
" h_size, bias_attr=None, data_format='NLC')\n",
" self.bw_fc = self.fw_fc\n",
" self.bw_bn = self.fw_bn\n",
" else:\n",
" self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
" self.fw_bn = nn.BatchNorm1D(\n",
" h_size, bias_attr=None, data_format='NLC')\n",
" self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n",
" self.bw_bn = nn.BatchNorm1D(\n",
" h_size, bias_attr=None, data_format='NLC')\n",
"\n",
" self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n",
" self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n",
" self.fw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]\n",
" self.bw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]\n",
"\n",
" def forward(self, x, x_len):\n",
" # x, shape [B, T, D]\n",
" fw_x = self.fw_bn(self.fw_fc(x))\n",
" bw_x = self.bw_bn(self.bw_fc(x))\n",
" fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n",
" bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n",
" x = paddle.concat([fw_x, bw_x], axis=-1)\n",
" return x, x_len\n",
"\n",
"\n",
"class BiGRUWithBN(nn.Layer):\n",
" \"\"\"Bidirectonal gru layer with sequence-wise batch normalization.\n",
" The batch normalization is only performed on input-state weights.\n",
"\n",
" :param name: Name of the layer.\n",
" :type name: string\n",
" :param input: Input layer.\n",
" :type input: Variable\n",
" :param size: Dimension of GRU cells.\n",
" :type size: int\n",
" :param act: Activation type.\n",
" :type act: string\n",
" :return: Bidirectional GRU layer.\n",
" :rtype: Variable\n",
" \"\"\"\n",
"\n",
" def __init__(self, i_size, h_size, act):\n",
" super().__init__()\n",
" hidden_size = h_size * 3\n",
" self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n",
" self.fw_bn = nn.BatchNorm1D(\n",
" hidden_size, bias_attr=None, data_format='NLC')\n",
" self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n",
" self.bw_bn = nn.BatchNorm1D(\n",
" hidden_size, bias_attr=None, data_format='NLC')\n",
"\n",
" self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n",
" self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n",
" self.fw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]\n",
" self.bw_rnn = nn.RNN(\n",
" self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]\n",
"\n",
" def forward(self, x, x_len):\n",
" # x, shape [B, T, D]\n",
" fw_x = self.fw_bn(self.fw_fc(x))\n",
" bw_x = self.bw_bn(self.bw_fc(x))\n",
" fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n",
" bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n",
" x = paddle.concat([fw_x, bw_x], axis=-1)\n",
" return x, x_len\n",
"\n",
"\n",
"class RNNStack(nn.Layer):\n",
" \"\"\"RNN group with stacked bidirectional simple RNN or GRU layers.\n",
"\n",
" :param input: Input layer.\n",
" :type input: Variable\n",
" :param size: Dimension of RNN cells in each layer.\n",
" :type size: int\n",
" :param num_stacks: Number of stacked rnn layers.\n",
" :type num_stacks: int\n",
" :param use_gru: Use gru if set True. Use simple rnn if set False.\n",
" :type use_gru: bool\n",
" :param share_rnn_weights: Whether to share input-hidden weights between\n",
" forward and backward directional RNNs.\n",
" It is only available when use_gru=False.\n",
" :type share_weights: bool\n",
" :return: Output layer of the RNN group.\n",
" :rtype: Variable\n",
" \"\"\"\n",
"\n",
" def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights):\n",
" super().__init__()\n",
" self.rnn_stacks = nn.LayerList()\n",
" for i in range(num_stacks):\n",
" if use_gru:\n",
" #default:GRU using tanh\n",
" self.rnn_stacks.append(\n",
" BiGRUWithBN(i_size=i_size, h_size=h_size, act=\"relu\"))\n",
" else:\n",
" self.rnn_stacks.append(\n",
" BiRNNWithBN(\n",
" i_size=i_size,\n",
" h_size=h_size,\n",
" share_weights=share_rnn_weights))\n",
" i_size = h_size * 2\n",
"\n",
" def forward(self, x, x_len):\n",
" \"\"\"\n",
" x: shape [B, T, D]\n",
" x_len: shpae [B]\n",
" \"\"\"\n",
" for i, rnn in enumerate(self.rnn_stacks):\n",
" x, x_len = rnn(x, x_len)\n",
" masks = sequence_mask(x_len) #[B, T]\n",
" masks = masks.unsqueeze(-1) # [B, T, 1]\n",
" x = x.multiply(masks)\n",
" return x, x_len\n",
"\n",
" \n",
"class DeepSpeech2Test(DeepSpeech2):\n",
" def __init__(self,\n",
" feat_size,\n",
" dict_size,\n",
" num_conv_layers=2,\n",
" num_rnn_layers=3,\n",
" rnn_size=256,\n",
" use_gru=False,\n",
" share_rnn_weights=True):\n",
" super().__init__(feat_size,\n",
" dict_size,\n",
" num_conv_layers=2,\n",
" num_rnn_layers=3,\n",
" rnn_size=256,\n",
" use_gru=False,\n",
" share_rnn_weights=True)\n",
" self.feat_size = feat_size # 161 for linear\n",
" self.dict_size = dict_size\n",
"\n",
" self.conv = ConvStack(feat_size, num_conv_layers)\n",
" \n",
"# self.fc = nn.Linear(1312, dict_size + 1)\n",
"\n",
" i_size = self.conv.output_height # H after conv stack\n",
" self.rnn = RNNStack(\n",
" i_size=i_size,\n",
" h_size=rnn_size,\n",
" num_stacks=num_rnn_layers,\n",
" use_gru=use_gru,\n",
" share_rnn_weights=share_rnn_weights)\n",
" \n",
" self.fc = nn.Linear(rnn_size * 2, dict_size + 1)\n",
" \n",
" def infer(self, audio, audio_len):\n",
" # [B, D, T] -> [B, C=1, D, T]\n",
" audio = audio.unsqueeze(1)\n",
"\n",
" # convolution group\n",
" x, audio_len = self.conv(audio, audio_len)\n",
" print('conv out', x.shape)\n",
"\n",
" # convert data from convolution feature map to sequence of vectors\n",
" B, C, D, T = paddle.shape(x)\n",
" x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]\n",
" x = x.reshape([B, T, C * D]) #[B, T, C*D]\n",
" print('rnn input', x.shape)\n",
"\n",
" # remove padding part\n",
" x, audio_len = self.rnn(x, audio_len) #[B, T, D]\n",
" print('rnn output', x.shape)\n",
"\n",
" logits = self.fc(x) #[B, T, V + 1]\n",
"\n",
" #ctcdecoder need probs, not log_probs\n",
" probs = F.softmax(logits)\n",
"\n",
" return logits, probs, audio_len\n",
"\n",
" def forward(self, audio, audio_len, text, text_len):\n",
" \"\"\"\n",
" audio: shape [B, D, T]\n",
" text: shape [B, T]\n",
" audio_len: shape [B]\n",
" text_len: shape [B]\n",
" \"\"\"\n",
" return self.infer(audio, audio_len)\n",
" \n",
"\n",
"feat_dim=161\n",
"\n",
"model = DeepSpeech2Test(\n",
" feat_size=feat_dim,\n",
" dict_size=batch_reader.dataset.vocab_size,\n",
" num_conv_layers=args.num_conv_layers,\n",
" num_rnn_layers=args.num_rnn_layers,\n",
" rnn_size=1024,\n",
" use_gru=args.use_gru,\n",
" share_rnn_weights=args.share_rnn_weights,\n",
" )\n",
"dp_model = model\n",
"#dp_model = paddle.DataParallel(model)\n",
"\n",
"loss_fn = DeepSpeech2Loss(batch_reader.dataset.vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "divided-incentive",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 22,
"id": "discrete-conjunction",
"metadata": {},
"outputs": [],
"source": [
"audio, audio_len, text, text_len = None, None, None, None\n",
"\n",
"for idx, inputs in enumerate(batch_reader):\n",
" audio, audio_len, text, text_len = inputs\n",
"# print(idx)\n",
"# print('a', audio.shape, audio.place)\n",
"# print('t', text)\n",
"# print('al', audio_len)\n",
"# print('tl', text_len)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "protected-announcement",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"conv in: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 167, 180, 186, 186])\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"conv in: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [55, 56, 60, 62, 62])\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"conv out: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [55, 56, 60, 62, 62])\n",
"conv out [5, 32, 41, 62]\n",
"rnn input [5, 62, 1312]\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n",
" return (isinstance(seq, collections.Sequence) and\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
" [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
" [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
"rnn output [5, 62, 2048]\n",
"logits len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [55, 56, 60, 62, 62])\n",
"loss Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n",
" [2316.82153320])\n"
]
}
],
"source": [
"outputs = dp_model(audio, audio_len, text, text_len)\n",
"logits, _, logits_len = outputs\n",
"print('logits len', logits_len)\n",
"loss = loss_fn.forward(logits, text, logits_len, text_len)\n",
"print('loss', loss)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "universal-myrtle",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: None\n",
"param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: None\n",
"param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: None\n",
"param grad: fc.bias: shape: [4299] stop_grad: False grad: None\n"
]
}
],
"source": [
"for n, p in dp_model.named_parameters():\n",
" print(\n",
" f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "referenced-double",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: [[[[ 2.1243238 1.696022 3.770659 ... 5.234652 5.4865217\n",
" 4.757795 ]\n",
" [ 2.651376 2.3109848 4.428488 ... 5.353201 8.703288\n",
" 5.1787405 ]\n",
" [ 2.7511077 1.8823049 2.1875212 ... 3.4821286 6.386543\n",
" 3.5026932 ]\n",
" ...\n",
" [ 1.9173846 1.8623551 0.5601456 ... 2.8375719 3.8496673\n",
" 2.359191 ]\n",
" [ 2.3827765 2.497965 1.5914664 ... 2.220721 3.4617734\n",
" 4.829253 ]\n",
" [ 1.6855702 1.5040786 1.8793598 ... 4.0773935 3.176893\n",
" 3.7477999 ]]]\n",
"\n",
"\n",
" [[[ 1.8451455 2.0091445 1.5225713 ... 1.524528 0.17764974\n",
" 1.0245132 ]\n",
" [ 1.9388857 1.3873467 2.044691 ... 0.92544 -0.9746763\n",
" -0.41603735]\n",
" [ 2.6814485 2.6096234 1.6802506 ... 1.902397 1.6837387\n",
" -0.96788657]\n",
" ...\n",
" [ 4.3675485 1.9822174 1.1695029 ... 1.4672399 3.2029557\n",
" 2.6364415 ]\n",
" [ 3.2536 1.1792442 -0.5618002 ... 2.101127 1.904225\n",
" 3.3839993 ]\n",
" [ 1.9118482 1.0651072 0.5409893 ... 2.6783593 1.6871439\n",
" 4.1078367 ]]]\n",
"\n",
"\n",
" [[[-4.412424 -1.7111907 -1.7722387 ... -4.3383503 -6.2393785\n",
" -6.139402 ]\n",
" [-2.260428 -1.0250616 -2.0550888 ... -5.353946 -4.29947\n",
" -6.158736 ]\n",
" [-1.4927872 0.7552787 -0.0702923 ... -4.485656 -4.0794134\n",
" -5.416684 ]\n",
" ...\n",
" [ 2.9100134 4.156195 4.357041 ... -3.569804 -1.8634341\n",
" -0.8772557 ]\n",
" [ 1.6895763 3.4314504 4.1192107 ... -1.380024 -2.3234155\n",
" -3.6650617 ]\n",
" [ 2.4190075 1.007498 3.1173465 ... -0.96318084 -3.6175003\n",
" -2.5240796 ]]]\n",
"\n",
"\n",
" ...\n",
"\n",
"\n",
" [[[-0.6865506 -0.60106415 -1.5555015 ... 2.0853553 1.900961\n",
" 2.101063 ]\n",
" [-0.31686288 -1.4362946 -1.4929098 ... 0.15085456 1.4540495\n",
" 1.4128599 ]\n",
" [-0.57852304 -0.8204216 -2.3264258 ... 1.4970423 0.54599845\n",
" 1.6222539 ]\n",
" ...\n",
" [ 0.32624918 0.96004546 -0.7476514 ... 2.2786083 2.1000178\n",
" 2.7494807 ]\n",
" [-1.6967826 -0.78979015 -1.8424999 ... 1.0620685 2.0544293\n",
" 2.2483966 ]\n",
" [ 0.8192332 2.601636 -2.6636481 ... 0.26625186 1.7610842\n",
" 1.7467536 ]]]\n",
"\n",
"\n",
" [[[ 0.9140297 0.42424175 1.4352363 ... -2.3022954 -3.001058\n",
" -2.6987422 ]\n",
" [ 0.4491998 -0.10698095 1.5089144 ... -3.2831016 -3.6055021\n",
" -3.6595795 ]\n",
" [ 2.6818252 -1.5750014 -0.34812498 ... -4.4137015 -4.250422\n",
" -3.481941 ]\n",
" ...\n",
" [ 1.4232106 2.9689102 3.9547806 ... -0.481165 0.28190404\n",
" -1.2167063 ]\n",
" [ 2.2297084 4.8198485 4.2857304 ... 0.57483846 1.4093391\n",
" 0.0715822 ]\n",
" [ 1.679745 4.768068 5.416195 ... 0.17254728 0.4623217\n",
" 1.4772662 ]]]\n",
"\n",
"\n",
" [[[-2.0860114 -2.9508173 -1.4945896 ... -4.067145 -2.5652342\n",
" -3.5771027 ]\n",
" [-2.697845 -1.9273603 -2.3885014 ... -2.196533 -2.8573706\n",
" -2.0113711 ]\n",
" [-2.413383 -2.7204053 -1.0502659 ... -3.001385 -3.36447\n",
" -4.3225455 ]\n",
" ...\n",
" [ 1.2754489 0.9560999 1.5239805 ... -0.0105865 -1.00876\n",
" 2.6247358 ]\n",
" [ 1.1965859 1.0378222 1.1025598 ... -0.5394704 0.49838027\n",
" -0.9618193 ]\n",
" [ 1.1361816 1.3232857 0.687318 ... -0.23925456 -0.43679112\n",
" -0.79297894]]]]\n",
"param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: [ 5.9604645e-07 -3.9339066e-06 -1.0728836e-06 -1.6689301e-06\n",
" 1.1920929e-06 -2.5033951e-06 -2.3841858e-07 4.7683716e-07\n",
" 4.2915344e-06 -1.9073486e-06 -1.9073486e-06 3.0994415e-06\n",
" -2.6822090e-06 3.3378601e-06 -4.2915344e-06 5.2452087e-06\n",
" 3.8146973e-06 2.3841858e-07 7.1525574e-07 -3.6954880e-06\n",
" 2.0563602e-06 -2.6226044e-06 3.0994415e-06 -3.5762787e-07\n",
" -4.7683716e-06 1.2218952e-06 3.3378601e-06 -2.5629997e-06\n",
" 2.3841858e-07 -1.7881393e-06 4.7683716e-07 -2.7418137e-06]\n",
"param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: [ 2.363316 3.286464 1.9607866 -1.6367784 -1.6325372 -1.7729434\n",
" -0.9261875 2.0950415 0.1155543 -0.8857083 0.70079553 0.33920464\n",
" 2.6953902 -0.64524114 0.8845749 -1.2271115 0.6578167 -2.939814\n",
" 5.5728893 -1.0917969 0.01470797 1.395206 4.8009634 -0.744532\n",
" 0.944651 -1.092311 1.4877632 -3.042566 0.51686054 -5.4768667\n",
" -5.628145 -1.0894046 ]\n",
"param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: [ 1.5193373 1.8838218 3.7722278 0.28052303 0.5386534 -0.44620085\n",
" -1.6977876 3.115642 0.03312349 -2.9121587 3.8925257 0.2288351\n",
" -2.273387 -1.3597974 4.3708124 -0.23374033 0.116272 -0.7064927\n",
" 6.5267463 -1.5318865 1.0288429 0.7928574 -0.24655592 -2.1116853\n",
" 2.922772 -3.3462617 1.7016437 -3.5471547 0.29777628 -3.2820854\n",
" -4.116946 -0.9909375 ]\n",
"param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: [[[[ 6.20494843e-01 5.95983505e-01 -1.48909020e+00 ... -6.86620831e-01\n",
" 6.71104014e-01 -1.95339048e+00]\n",
" [-3.91837955e-03 1.27062631e+00 -1.63248098e+00 ... 1.07290137e+00\n",
" -9.42245364e-01 -3.34277248e+00]\n",
" [ 2.41821265e+00 2.36212373e-01 -1.84433365e+00 ... 1.23182368e+00\n",
" 1.36039746e+00 -2.94621849e+00]\n",
" ...\n",
" [ 1.55153418e+00 7.25861669e-01 2.08785534e+00 ... -6.40172660e-01\n",
" -3.23889256e-02 -2.30832791e+00]\n",
" [ 3.69824195e+00 1.27163112e-01 4.09263194e-01 ... -8.60729575e-01\n",
" -3.51897454e+00 -2.10093403e+00]\n",
" [-4.94779050e-01 -3.74262631e-01 -1.19801068e+00 ... -2.05930543e+00\n",
" -7.38576293e-01 -9.44581270e-01]]\n",
"\n",
" [[-2.04341412e+00 -3.70606273e-01 -1.40429378e+00 ... -1.71711946e+00\n",
" -4.09437418e-01 -1.74107194e+00]\n",
" [-8.72247815e-01 -1.06301677e+00 -9.19306517e-01 ... -2.98976970e+00\n",
" -3.03250861e+00 -2.37099743e+00]\n",
" [-5.00457406e-01 -1.11882675e+00 -5.91526508e-01 ... 4.23921436e-01\n",
" -2.08650708e+00 -1.82109618e+00]\n",
" ...\n",
" [ 2.07773042e+00 1.40735030e-01 -2.60543615e-01 ... -1.55956164e-01\n",
" -1.31862307e+00 -2.07174897e+00]\n",
" [ 7.95007765e-01 1.14988625e-01 -1.43308258e+00 ... 8.29253554e-01\n",
" -9.57888126e-01 -3.82121086e-01]\n",
" [ 8.34397674e-02 1.38636863e+00 -1.21593380e+00 ... -2.65783578e-01\n",
" 1.78124309e-02 -3.40287232e+00]]\n",
"\n",
" [[ 6.27344131e-01 5.71699142e-02 -3.58010936e+00 ... -4.53077674e-01\n",
" 1.65331578e+00 2.58466601e-02]\n",
" [ 2.66681361e+00 2.02069378e+00 -1.52052927e+00 ... 2.94914508e+00\n",
" 1.94632411e+00 -1.06698799e+00]\n",
" [ 1.57839453e+00 -1.03649735e-01 -4.22528505e+00 ... 2.28863955e+00\n",
" 4.27859402e+00 3.66381669e+00]\n",
" ...\n",
" [-2.44603205e+00 -2.09621000e+00 -2.57623529e+00 ... 9.00211930e-01\n",
" 4.30536079e+00 -2.49779320e+00]\n",
" [-2.52187514e+00 -3.36546659e+00 -1.26748765e+00 ... 8.11533451e-01\n",
" 2.55930424e-01 4.50821817e-02]\n",
" [-3.40082574e+00 -3.26924801e+00 -5.86932135e+00 ... -1.18203712e+00\n",
" 1.09565187e+00 -4.96661961e-01]]\n",
"\n",
" ...\n",
"\n",
" [[ 8.20469666e+00 6.96195841e+00 2.73753977e+00 ... 8.34498823e-01\n",
" 2.56748104e+00 1.67592216e+00]\n",
" [ 9.85801792e+00 8.81465149e+00 6.09280396e+00 ... 1.42389655e+00\n",
" 2.92086434e+00 2.08308399e-01]\n",
" [ 8.00702763e+00 7.97301006e+00 4.64527416e+00 ... 8.61916900e-01\n",
" 3.55370259e+00 4.75085378e-01]\n",
" ...\n",
" [ 5.61662769e+00 -4.72857296e-01 -1.04519971e-01 ... -4.03000236e-01\n",
" -1.66419971e+00 -1.70375630e-01]\n",
" [ 4.52409792e+00 -3.70670676e-01 4.54190969e-02 ... -8.20453286e-01\n",
" 9.49141383e-02 8.88008535e-01]\n",
" [ 3.27219462e+00 8.93201411e-01 1.94810414e+00 ... -2.86915004e-02\n",
" 1.93200278e+00 8.19505215e-01]]\n",
"\n",
" [[ 5.84066296e+00 6.72855520e+00 5.21399307e+00 ... 4.55058670e+00\n",
" 3.19132543e+00 3.17435169e+00]\n",
" [ 6.04594421e+00 6.88997173e+00 5.00542831e+00 ... 2.23561144e+00\n",
" 2.76059532e+00 4.83479440e-01]\n",
" [ 5.36118126e+00 4.13896275e+00 3.68701124e+00 ... 3.64462805e+00\n",
" 2.80596399e+00 1.52781498e+00]\n",
" ...\n",
" [ 2.87856674e+00 5.84320784e-01 1.74297714e+00 ... 2.83938944e-01\n",
" -2.26546407e-01 -1.18434143e+00]\n",
" [ 2.08510804e+00 1.74915957e+00 1.58637917e+00 ... 6.41967297e-01\n",
" -1.31319761e-01 -3.85830402e-01]\n",
" [ 4.41666174e+00 2.58244562e+00 2.97712159e+00 ... 1.42317235e-01\n",
" 1.68037796e+00 -6.50003672e-01]]\n",
"\n",
" [[ 1.05511594e+00 6.74880028e-01 -7.64639139e-01 ... -2.15282440e-01\n",
" 2.07197094e+00 4.48752761e-01]\n",
" [ 2.12095881e+00 3.44118834e+00 1.61375272e+00 ... -1.18487728e+00\n",
" 1.88659012e+00 1.48252523e+00]\n",
" [ 8.33427787e-01 4.35035896e+00 -3.59877385e-02 ... 8.70242774e-01\n",
" 3.75945044e+00 -3.09408635e-01]\n",
" ...\n",
" [ 5.08510351e+00 4.73114061e+00 1.97346115e+00 ... -2.25924397e+00\n",
" -1.26373076e+00 -1.37826729e+00]\n",
" [ 6.17275095e+00 4.16016817e+00 3.15675950e+00 ... -2.02416754e+00\n",
" 1.50002241e-02 1.84633851e+00]\n",
" [ 7.32995272e+00 5.34601831e+00 4.58857203e+00 ... -1.88874304e+00\n",
" 1.53240371e+00 7.47349262e-02]]]\n",
"\n",
"\n",
" [[[-1.80918843e-01 -2.52616453e+00 -2.78145695e+00 ... 1.44283652e+00\n",
" -1.08945215e+00 4.19084758e-01]\n",
" [-9.66833949e-01 -2.41106153e+00 -3.48886085e+00 ... -1.87193304e-01\n",
" 8.21905077e-01 1.89097953e+00]\n",
" [-1.59118319e+00 -2.56997013e+00 -3.10426521e+00 ... 2.05900550e+00\n",
" -2.78253704e-01 6.96343541e-01]\n",
" ...\n",
" [ 6.66302443e-02 -2.00887346e+00 -3.17550874e+00 ... 7.97579706e-01\n",
" -9.71581042e-02 1.71877682e+00]\n",
" [-8.01679730e-01 -2.02678037e+00 -3.21915555e+00 ... 8.35528374e-01\n",
" -1.15296638e+00 4.35728967e-01]\n",
" [ 1.45292446e-01 -2.15479851e+00 -1.51839817e+00 ... -3.07936192e-01\n",
" -5.39051890e-01 1.13107657e+00]]\n",
"\n",
" [[-2.43341160e+00 -3.35346818e+00 -9.87014294e-01 ... 1.34049034e+00\n",
" 2.95773447e-02 1.27177119e+00]\n",
" [-2.61602497e+00 -9.76761580e-01 -2.52060473e-01 ... -1.38134825e+00\n",
" 3.85564029e-01 4.57195908e-01]\n",
" [-2.23676014e+00 -4.00404739e+00 -2.23409963e+00 ... -1.41846514e+00\n",
" -6.58698231e-02 -3.61778140e-01]\n",
" ...\n",
" [-1.13604403e+00 -6.03917837e-02 -4.95491922e-01 ... 2.14673686e+00\n",
" 1.21484184e+00 2.22764325e+00]\n",
" [-1.05162430e+00 -1.59828448e+00 3.15489501e-01 ... 2.28046751e+00\n",
" 2.39702511e+00 2.43942714e+00]\n",
" [-1.27370405e+00 -2.05736399e-01 -1.12124372e+00 ... 2.21597219e+00\n",
" 2.50086927e+00 1.91134131e+00]]\n",
"\n",
" [[-4.53170598e-01 -1.59644139e+00 -3.63470483e+00 ... -4.35066032e+00\n",
" -3.79540777e+00 -1.09796596e+00]\n",
" [-2.21036464e-01 -2.53353834e+00 -1.28269875e+00 ... -3.38615727e+00\n",
" -2.59143281e+00 7.74220943e-01]\n",
" [-6.89323783e-01 -1.44375205e+00 6.66438341e-02 ... -1.30736077e+00\n",
" -1.23293114e+00 1.58148706e+00]\n",
" ...\n",
" [ 1.63751483e+00 -4.08427984e-01 -8.15176964e-01 ... 3.70807743e+00\n",
" 2.04232907e+00 1.97716308e+00]\n",
" [ 2.13261342e+00 1.85947633e+00 -8.06532025e-01 ... 1.98311245e+00\n",
" 2.27003932e+00 -1.11734614e-01]\n",
" [ 1.28702402e+00 3.98628891e-01 -1.63712263e+00 ... 8.00528765e-01\n",
" 5.78273535e-01 -2.59924948e-01]]\n",
"\n",
" ...\n",
"\n",
" [[ 3.96233416e+00 4.66794682e+00 1.39437711e+00 ... 7.52061129e-01\n",
" -1.53534544e+00 -6.67162359e-01]\n",
" [ 2.33841681e+00 3.35811281e+00 9.80114818e-01 ... 1.48806703e+00\n",
" 2.68609226e-01 -1.35124445e+00]\n",
" [ 2.08177710e+00 4.28519583e+00 1.52450514e+00 ... 7.45321214e-01\n",
" -5.04359961e-01 -1.81241560e+00]\n",
" ...\n",
" [ 2.95398951e-01 4.30877179e-01 -2.03731894e+00 ... -4.20221925e-01\n",
" 3.29260826e-01 5.83679557e-01]\n",
" [ 1.30742240e+00 -6.32183790e-01 -3.13741422e+00 ... 9.63868052e-02\n",
" 2.91730791e-01 1.33400351e-01]\n",
" [ 5.43292165e-01 -2.83665359e-01 -1.88138187e+00 ... 2.15468198e-01\n",
" 4.90157723e-01 2.40562439e+00]]\n",
"\n",
" [[ 1.57632053e+00 6.27885723e+00 2.87853765e+00 ... 3.07016110e+00\n",
" 1.91490650e+00 1.76274943e+00]\n",
" [ 2.57776356e+00 4.07256317e+00 2.52231169e+00 ... 4.09494352e+00\n",
" 2.53548074e+00 2.44395185e+00]\n",
" [ 2.43037057e+00 4.35728836e+00 1.96233964e+00 ... 2.26702976e+00\n",
" 2.94634581e+00 2.21452284e+00]\n",
" ...\n",
" [-2.72509992e-01 -8.41220498e-01 -1.89133918e+00 ... -1.80079627e+00\n",
" -2.00367713e+00 -7.09145784e-01]\n",
" [ 8.21575999e-01 -1.13323164e+00 -2.62418866e+00 ... -2.38889670e+00\n",
" -7.83945560e-01 -1.01922750e-01]\n",
" [-1.14730227e+00 -1.42182577e+00 -2.00993991e+00 ... -2.11025667e+00\n",
" 1.60286129e-02 -7.26446986e-01]]\n",
"\n",
" [[ 4.20389509e+00 3.75917768e+00 4.97653627e+00 ... 1.23642838e+00\n",
" 8.52760911e-01 1.27920091e-01]\n",
" [ 5.29409122e+00 5.29002380e+00 3.96404648e+00 ... 1.91227329e+00\n",
" 3.97556186e-01 1.69182217e+00]\n",
" [ 4.60112572e+00 4.12772799e+00 2.10280085e+00 ... 3.24303842e+00\n",
" -1.07720590e+00 -3.81854475e-01]\n",
" ...\n",
" [ 1.81884170e-02 -3.11472058e+00 -8.23525012e-01 ... -2.40161085e+00\n",
" -4.48192549e+00 -6.14600539e-01]\n",
" [ 1.16305006e+00 -1.15409636e+00 -3.48765063e+00 ... -1.97504926e+00\n",
" -4.44984436e+00 -2.28429958e-01]\n",
" [ 1.29197860e+00 6.17720246e-01 -5.87171853e-01 ... -1.35258228e-01\n",
" -1.29259872e+00 1.30360842e-01]]]\n",
"\n",
"\n",
" [[[-1.26687372e+00 -2.33633637e+00 -1.49625254e+00 ... 2.52396107e+00\n",
" -6.68072224e-01 -1.13282454e+00]\n",
" [-1.34229445e+00 -2.87080932e+00 -2.57388353e+00 ... -8.75385761e-01\n",
" -1.00205469e+00 -3.58956242e+00]\n",
" [-9.49853599e-01 -5.78684711e+00 -3.52962446e+00 ... 8.88233304e-01\n",
" 2.25133196e-01 -1.02802217e+00]\n",
" ...\n",
" [-7.38113701e-01 -3.47510982e+00 -3.23011065e+00 ... -1.25624001e+00\n",
" -1.63268471e+00 6.00247443e-01]\n",
" [-2.29733467e+00 -5.72547615e-01 -1.98301303e+00 ... -1.90137398e+00\n",
" -1.47013855e+00 -1.45779204e+00]\n",
" [-2.24628520e+00 -3.36337948e+00 -3.91878939e+00 ... -1.53652275e+00\n",
" -1.36285520e+00 -1.68160331e+00]]\n",
"\n",
" [[-8.11348319e-01 -7.17824280e-01 -1.02243233e+00 ... -2.69050407e+00\n",
" -2.32403350e+00 -4.25943947e+00]\n",
" [-2.35056520e+00 -2.35941172e+00 -1.24398732e+00 ... -2.08313870e+00\n",
" -1.16508257e+00 -1.30353463e+00]\n",
" [-2.25146723e+00 -1.94972813e+00 -1.13295293e+00 ... -2.61496377e+00\n",
" -1.91106403e+00 -1.07801402e+00]\n",
" ...\n",
" [-2.67012739e+00 -3.20916414e+00 -2.41768575e+00 ... 2.65138328e-01\n",
" -5.27612507e-01 1.44604075e+00]\n",
" [-3.54237866e+00 -3.62832785e+00 -2.40270257e+00 ... -9.76106226e-02\n",
" 4.67946082e-01 -7.24248111e-01]\n",
" [-2.49844384e+00 -3.42463255e+00 -2.99040008e+00 ... 4.28889185e-01\n",
" -7.51657963e-01 -1.00530767e+00]]\n",
"\n",
" [[-8.42589438e-02 1.42022014e-01 -8.51281703e-01 ... 4.21745628e-01\n",
" -2.35717297e-02 -1.71374834e+00]\n",
" [-1.05496287e+00 3.82416457e-01 -4.40595537e-01 ... 1.03381336e-01\n",
" -1.41204190e+00 -7.58325040e-01]\n",
" [-2.28930283e+00 -2.03857040e+00 -9.16261196e-01 ... -3.94939929e-01\n",
" -1.07798588e+00 -1.48433352e+00]\n",
" ...\n",
" [-3.11473966e-01 -1.40877593e+00 -2.42908645e+00 ... 7.88682699e-01\n",
" 1.24199319e+00 1.89949930e-01]\n",
" [ 5.44084549e-01 -1.02425671e+00 -1.53991556e+00 ... -4.36764538e-01\n",
" -5.78772545e-01 2.62665659e-01]\n",
" [ 1.26812792e+00 -9.89493608e-01 -1.47972977e+00 ... 2.21440494e-02\n",
" 2.79776216e-01 7.63269484e-01]]\n",
"\n",
" ...\n",
"\n",
" [[ 6.02095068e-01 5.93243122e-01 -1.06838238e+00 ... 3.56546330e+00\n",
" 1.16390383e+00 -1.47593319e-01]\n",
" [ 1.80458140e+00 1.68401957e+00 4.17516947e-01 ... 3.33444500e+00\n",
" 1.89411759e+00 1.03220642e-01]\n",
" [ 2.74264169e+00 2.92038846e+00 1.00775683e+00 ... 3.53285050e+00\n",
" 2.07282662e+00 -2.56800652e-01]\n",
" ...\n",
" [ 4.88933468e+00 3.72433925e+00 3.58677816e+00 ... 1.98363388e+00\n",
" 1.80851030e+00 8.32634747e-01]\n",
" [ 4.01546288e+00 4.78934765e+00 2.94778132e+00 ... 2.99637699e+00\n",
" 1.30439472e+00 3.61029744e-01]\n",
" [ 3.13628030e+00 2.01894832e+00 2.82585931e+00 ... 2.54264188e+00\n",
" -9.16651785e-02 9.93353873e-02]]\n",
"\n",
" [[ 2.35585642e+00 8.42678428e-01 1.57331872e+00 ... 3.65935063e+00\n",
" 3.94066262e+00 4.89832020e+00]\n",
" [ 1.85791731e+00 1.34373701e+00 1.30812299e+00 ... 2.71434736e+00\n",
" 3.22004294e+00 2.99872303e+00]\n",
" [ 1.67675853e+00 -4.05569375e-02 1.85539150e+00 ... 3.73934364e+00\n",
" 2.98195982e+00 3.37315011e+00]\n",
" ...\n",
" [ 2.14539170e+00 2.86586595e+00 2.20222116e+00 ... 1.20492995e+00\n",
" 2.13971066e+00 1.94932449e+00]\n",
" [ 4.68422651e+00 3.80044746e+00 4.23209000e+00 ... 2.40658951e+00\n",
" 2.29117441e+00 2.52368808e+00]\n",
" [ 3.10694575e+00 2.49402595e+00 4.53786707e+00 ... 9.08902645e-01\n",
" 1.86903965e+00 2.27776885e+00]]\n",
"\n",
" [[ 1.45200038e+00 5.17961740e-01 -1.58403587e+00 ... 5.07019472e+00\n",
" 7.87163258e-01 1.20610237e+00]\n",
" [ 3.39321136e+00 2.21043849e+00 -6.31202877e-01 ... 4.97822762e+00\n",
" 9.66498017e-01 1.18883348e+00]\n",
" [ 1.20627856e+00 1.82759428e+00 5.91053367e-01 ... 4.14318657e+00\n",
" 5.25399208e-01 -1.16850233e+00]\n",
" ...\n",
" [ 1.05183899e+00 5.80030501e-01 1.89724147e+00 ... 2.54626465e+00\n",
" -1.49128008e+00 -1.85064209e+00]\n",
" [ 1.50983357e+00 2.85973406e+00 2.61224055e+00 ... 4.83481932e+00\n",
" 9.67048705e-02 -4.37043965e-01]\n",
" [ 2.57720876e+00 2.09961963e+00 4.11754288e-02 ... 3.80421424e+00\n",
" -7.83308804e-01 -1.64871216e+00]]]\n",
"\n",
"\n",
" ...\n",
"\n",
"\n",
" [[[-1.16345096e+00 -2.53971386e+00 -8.99101734e-01 ... -4.35583591e-01\n",
" -1.29671764e+00 -1.61429560e+00]\n",
" [ 3.72841507e-01 3.45808208e-01 -1.82167351e+00 ... -2.14515448e+00\n",
" -1.26383066e+00 -2.27464601e-01]\n",
" [ 1.58568513e+00 2.58181524e+00 1.86554670e+00 ... -1.10401320e+00\n",
" -3.68550658e-01 -2.58849680e-01]\n",
" ...\n",
" [-9.15827155e-01 -1.25424683e+00 -4.04716206e+00 ... 2.13138080e+00\n",
" 2.67662477e+00 2.31014514e+00]\n",
" [-3.19453120e-01 -6.71132684e-01 -1.51378751e+00 ... 1.86080432e+00\n",
" 2.77418542e+00 1.22875953e+00]\n",
" [-1.20453942e+00 -3.93669218e-01 -1.51751983e+00 ... 1.17620552e+00\n",
" 1.95602298e+00 7.64306366e-01]]\n",
"\n",
" [[-8.73186827e-01 -2.12537169e+00 -1.91664994e+00 ... -2.90821463e-01\n",
" 1.90896463e+00 8.02283168e-01]\n",
" [-1.06389821e+00 -2.15300727e+00 -1.82113051e+00 ... -4.34280694e-01\n",
" 1.53455496e+00 1.94702053e+00]\n",
" [-2.08403468e+00 -4.72900331e-01 -1.10610819e+00 ... -8.79420400e-01\n",
" 7.79394627e-01 2.02670670e+00]\n",
" ...\n",
" [-4.28208113e-01 -7.90894389e-01 -1.06713009e+00 ... 1.12579381e+00\n",
" 9.61961091e-01 1.40342009e+00]\n",
" [ 4.40416574e-01 -1.65901780e-02 -1.05338669e+00 ... 1.40698349e+00\n",
" 9.43485856e-01 2.34856772e+00]\n",
" [-1.20572495e+00 -2.03134632e+00 4.88817632e-01 ... 2.20770907e+00\n",
" 1.38143206e+00 2.00714707e+00]]\n",
"\n",
" [[ 9.00486887e-01 -9.50459957e-01 -1.42935121e+00 ... -1.30648065e+00\n",
" -2.52133775e+00 -8.87715697e-01]\n",
" [ 3.73431134e+00 1.69571114e+00 5.99429727e-01 ... 6.64332986e-01\n",
" -6.10453069e-01 2.06534386e+00]\n",
" [ 1.59800696e+00 -4.59622175e-01 -6.73136234e-01 ... 2.18770742e-01\n",
" -1.12928271e+00 4.87097502e-02]\n",
" ...\n",
" [ 1.92336845e+00 1.37130380e-01 -3.51048648e-01 ... 5.41638851e-01\n",
" 1.06069386e+00 1.36404145e+00]\n",
" [ 1.29641414e+00 -2.79530913e-01 -2.63607264e-01 ... -8.62445176e-01\n",
" 1.48393130e+00 2.69196725e+00]\n",
" [ 1.14442182e+00 -1.24098969e+00 3.70959163e-01 ... -1.12241995e+00\n",
" 3.67927134e-01 2.55976987e+00]]\n",
"\n",
" ...\n",
"\n",
" [[ 5.32017851e+00 3.64207411e+00 3.84571218e+00 ... 3.60754800e+00\n",
" 2.57500267e+00 -1.38083458e-01]\n",
" [ 5.69058084e+00 3.93056583e+00 2.93337941e+00 ... 3.17091584e+00\n",
" 2.34770632e+00 6.48133337e-01]\n",
" [ 5.98239613e+00 6.16548634e+00 3.04750896e+00 ... 5.51510525e+00\n",
" 4.34810448e+00 1.31588542e+00]\n",
" ...\n",
" [ 5.09930992e+00 3.32360983e+00 2.29228449e+00 ... 3.45123887e-01\n",
" 1.06280947e+00 -5.93325794e-02]\n",
" [ 4.19760656e+00 3.97779059e+00 1.66905916e+00 ... 3.68937254e-01\n",
" 8.06131065e-02 8.08142900e-01]\n",
" [ 4.52498960e+00 3.45109749e+00 1.01074433e+00 ... -2.54036248e-01\n",
" 3.13675582e-01 2.13851762e+00]]\n",
"\n",
" [[ 6.93927193e+00 6.05758238e+00 4.60648441e+00 ... 4.32221603e+00\n",
" 3.17874146e+00 1.47012353e+00]\n",
" [ 7.88523865e+00 6.62228966e+00 4.77496338e+00 ... 4.45868683e+00\n",
" 2.73698759e+00 2.17057824e+00]\n",
" [ 7.12061214e+00 6.01714134e+00 4.52996492e+00 ... 3.97184372e+00\n",
" 3.43153954e+00 1.21802723e+00]\n",
" ...\n",
" [ 2.85720730e+00 1.89639473e+00 1.96340394e+00 ... 1.89643729e+00\n",
" 1.64856291e+00 1.15853786e+00]\n",
" [ 3.88248491e+00 2.16386199e+00 1.53069091e+00 ... 2.71704245e+00\n",
" 2.24890351e+00 2.22156644e+00]\n",
" [ 5.27136230e+00 1.68400204e+00 2.09500480e+00 ... 2.75956345e+00\n",
" 3.71970820e+00 1.69852686e+00]]\n",
"\n",
" [[ 2.55598164e+00 1.64588141e+00 6.70431674e-01 ... 3.24091220e+00\n",
" 1.48759770e+00 -1.72001183e+00]\n",
" [ 4.33942318e+00 8.40826690e-01 -7.40000725e-01 ... 7.24577069e-01\n",
" 1.74327165e-01 -1.83029580e+00]\n",
" [ 4.39864540e+00 2.28395438e+00 -1.90353513e-01 ... 5.58019161e+00\n",
" 1.05627227e+00 -8.02519619e-01]\n",
" ...\n",
" [ 1.97654784e+00 3.26888156e+00 1.52879453e+00 ... 3.15013933e+00\n",
" 4.66731453e+00 4.98701715e+00]\n",
" [ 1.40016854e+00 3.45761251e+00 3.68359756e+00 ... 1.14207900e+00\n",
" 3.32219076e+00 3.83035636e+00]\n",
" [ 1.99269783e+00 2.15428829e+00 3.35396528e-01 ... 2.45916694e-01\n",
" 2.13785577e+00 4.33214951e+00]]]\n",
"\n",
"\n",
" [[[ 1.35320330e+00 5.05850911e-02 1.04915988e+00 ... 1.82023585e-01\n",
" 2.72914767e-01 3.92112255e-01]\n",
" [ 1.04646444e+00 7.60913491e-01 1.93323612e+00 ... 1.19493449e+00\n",
" -1.44200325e-01 4.07531261e-02]\n",
" [-9.88207340e-01 -1.46165287e+00 1.05884135e-01 ... -3.23057353e-01\n",
" -2.28934169e+00 -7.38609374e-01]\n",
" ...\n",
" [ 1.01198792e+00 2.34331083e+00 1.04566610e+00 ... 1.29697472e-01\n",
" -1.23878837e+00 2.21006930e-01]\n",
" [-3.75360101e-01 1.53673506e+00 -1.32206869e+00 ... -2.55255580e-01\n",
" -6.22699618e-01 -1.73162484e+00]\n",
" [ 4.34735864e-01 5.08327007e-01 -3.49233925e-01 ... -1.04749084e+00\n",
" -1.15777385e+00 -1.13671994e+00]]\n",
"\n",
" [[ 1.67839336e+00 -1.80224836e-01 1.02194118e+00 ... 8.44027162e-01\n",
" 8.81283879e-02 -1.37762165e+00]\n",
" [ 8.39694083e-01 1.32322550e+00 4.02442753e-01 ... -4.21785116e-01\n",
" -9.98012185e-01 -1.11348581e+00]\n",
" [ 7.64424682e-01 8.58965695e-01 2.94626594e-01 ... -6.65519595e-01\n",
" -3.65677416e-01 -2.25250268e+00]\n",
" ...\n",
" [-1.10193872e+00 1.18070498e-01 1.04604781e-01 ... -1.44486964e+00\n",
" -2.52748466e+00 -2.16131711e+00]\n",
" [-1.06079710e+00 -1.48379254e+00 3.80138367e-01 ... -1.62288392e+00\n",
" -2.44736362e+00 -8.78590107e-01]\n",
" [ 3.44401300e-02 -2.60935068e+00 -2.35597759e-01 ... -2.41114974e+00\n",
" -2.45255780e+00 -1.82384634e+00]]\n",
"\n",
" [[ 1.37670958e+00 1.58661580e+00 -2.85664916e-01 ... 1.49081087e+00\n",
" 4.13422853e-01 1.12761199e+00]\n",
" [ 1.54148173e+00 6.22704089e-01 1.41886568e+00 ... 1.59678531e+00\n",
" -8.72656107e-01 1.52415514e-01]\n",
" [ 3.30207205e+00 2.89925170e+00 1.91855145e+00 ... 3.18863559e+00\n",
" 1.87347198e+00 9.48901057e-01]\n",
" ...\n",
" [-1.53920484e+00 1.77375078e-02 -1.02018684e-01 ... 1.94011092e+00\n",
" -6.83587790e-01 1.49154460e+00]\n",
" [-2.27719522e+00 1.02481163e+00 -2.11300224e-01 ... -8.18020821e-01\n",
" 1.54248989e+00 -1.46732473e+00]\n",
" [-4.50206220e-01 3.62383485e+00 1.07175660e+00 ... 4.25961137e-01\n",
" 1.12405360e-01 -6.87821358e-02]]\n",
"\n",
" ...\n",
"\n",
" [[-3.40477467e-01 -2.99311423e+00 -2.12096786e+00 ... 2.27393007e+00\n",
" 4.03424358e+00 3.73335361e+00]\n",
" [-6.99971199e-01 -2.97719741e+00 -2.72910309e+00 ... 1.50101089e+00\n",
" 2.29408574e+00 3.14105940e+00]\n",
" [-1.41648722e+00 -1.86292887e+00 -1.84006739e+00 ... 2.78402638e+00\n",
" 3.91481900e+00 5.32456112e+00]\n",
" ...\n",
" [ 5.97958088e-01 1.50512588e+00 6.23718500e-01 ... 2.83813477e+00\n",
" 3.87909842e+00 3.33359623e+00]\n",
" [ 1.65542316e+00 3.56163192e+00 4.01527691e+00 ... 3.38367462e+00\n",
" 1.55827272e+00 2.50741863e+00]\n",
" [ 2.82036042e+00 2.53322673e+00 4.38798475e+00 ... 4.64642382e+00\n",
" 3.28739667e+00 3.02895570e+00]]\n",
"\n",
" [[-3.47941303e+00 -3.49006844e+00 -2.25583363e+00 ... 1.45181656e-01\n",
" 1.52944064e+00 2.08810711e+00]\n",
" [-2.27786446e+00 -4.59218550e+00 -2.74722624e+00 ... -1.73136210e+00\n",
" 7.46028006e-01 1.74789345e+00]\n",
" [-3.35524082e+00 -4.58244705e+00 -2.40820456e+00 ... -5.04051924e-01\n",
" 1.49640536e+00 2.16613841e+00]\n",
" ...\n",
" [ 5.26107132e-01 2.05329061e+00 2.84252572e+00 ... 1.33222675e+00\n",
" 3.87935114e+00 3.69385266e+00]\n",
" [ 4.38092083e-01 2.15028906e+00 3.13363624e+00 ... 3.36048746e+00\n",
" 5.36551809e+00 2.94915986e+00]\n",
" [ 2.75497317e+00 3.25929213e+00 2.33522987e+00 ... 1.69926262e+00\n",
" 3.93462896e+00 3.68200874e+00]]\n",
"\n",
" [[ 1.10951948e+00 5.31419516e-02 -1.58864903e+00 ... 5.24887085e+00\n",
" 1.60273385e+00 4.90113163e+00]\n",
" [-2.94517064e+00 -2.81092644e+00 -4.89631557e+00 ... 3.99868512e+00\n",
" 1.40544355e+00 2.84833241e+00]\n",
" [-3.51893663e-01 -3.53325534e+00 -2.21239805e+00 ... 4.26225853e+00\n",
" 6.87886119e-01 2.58609629e+00]\n",
" ...\n",
" [ 2.92248201e+00 5.40264511e+00 4.65721560e+00 ... 5.24537373e+00\n",
" 2.30406880e+00 1.29892707e+00]\n",
" [ 1.43473256e+00 4.61167526e+00 3.57578802e+00 ... 5.12181854e+00\n",
" 8.59923482e-01 1.38731599e+00]\n",
" [-6.50881350e-01 2.18233657e+00 2.74669623e+00 ... 4.86368895e+00\n",
" 1.44120216e+00 1.79993320e+00]]]\n",
"\n",
"\n",
" [[[ 1.64106202e+00 3.54410499e-01 -3.54172409e-01 ... 2.32646990e+00\n",
" 1.65043330e+00 3.45897645e-01]\n",
" [ 2.16236949e+00 1.28213906e+00 2.26082468e+00 ... 6.10507369e-01\n",
" 9.12241280e-01 1.27429694e-01]\n",
" [ 2.07962990e+00 7.03816175e-01 2.01272345e+00 ... -2.26959705e-01\n",
" 1.00041127e+00 5.87104559e-02]\n",
" ...\n",
" [-1.62972426e+00 -3.04028845e+00 -1.39124167e+00 ... 2.47561097e+00\n",
" 2.35047388e+00 1.61532843e+00]\n",
" [-1.97368932e+00 -5.44541061e-01 -5.92882216e-01 ... 1.39800012e+00\n",
" 2.32770801e+00 9.96662021e-01]\n",
" [-1.15636075e+00 -1.34654212e+00 -8.50648999e-01 ... 1.85655832e+00\n",
" 2.05776072e+00 5.34575820e-01]]\n",
"\n",
" [[-1.02104437e+00 3.08469892e-01 2.81789303e-01 ... -8.24654043e-01\n",
" -9.85817850e-01 -2.05517030e+00]\n",
" [ 9.50192690e-01 3.35105330e-01 5.31637192e-01 ... -1.42974198e-01\n",
" -1.79659498e+00 -1.58266973e+00]\n",
" [-2.51316994e-01 -1.28709340e+00 3.01498562e-01 ... -1.32253516e+00\n",
" -1.55507576e+00 -9.37123299e-01]\n",
" ...\n",
" [ 2.33016998e-01 2.92454743e+00 3.15420461e+00 ... 1.15574491e+00\n",
" 1.27850962e+00 1.35487700e+00]\n",
" [ 3.81013602e-01 1.44239831e+00 6.64825320e-01 ... -3.89374971e-01\n",
" 1.50716826e-01 1.33641326e+00]\n",
" [ 1.71373415e+00 1.67357373e+00 1.76596940e+00 ... 1.57941079e+00\n",
" 1.60940981e+00 1.78091609e+00]]\n",
"\n",
" [[-5.16522598e+00 -1.68099070e+00 -3.24440050e+00 ... -3.46229005e+00\n",
" -2.18273020e+00 -1.98621082e+00]\n",
" [-3.05743694e+00 9.15392339e-01 -1.93508530e+00 ... -1.82306373e+00\n",
" -2.12960863e+00 -3.45255351e+00]\n",
" [-4.32777822e-01 -1.00303245e+00 -1.61397791e+00 ... -2.08376765e+00\n",
" -3.72989595e-01 -1.36516929e+00]\n",
" ...\n",
" [-5.83641946e-01 4.14125490e+00 1.58227599e+00 ... 2.03144050e+00\n",
" 2.13982654e+00 -1.81909311e+00]\n",
" [-1.74230576e+00 2.39347410e+00 2.44080925e+00 ... 5.43732524e-01\n",
" 2.07899213e+00 -3.71748984e-01]\n",
" [ 3.80016506e-01 7.84988403e-01 1.20596504e+00 ... -2.32057095e+00\n",
" -2.81265080e-01 -3.69353056e+00]]\n",
"\n",
" ...\n",
"\n",
" [[-3.48024845e+00 -2.60937548e+00 -3.84952760e+00 ... 6.68736577e-01\n",
" -1.75104141e-02 -3.54720926e+00]\n",
" [-2.59637117e+00 -5.18190145e+00 -2.33887696e+00 ... 9.13373232e-02\n",
" -3.58282638e+00 -2.40778995e+00]\n",
" [-2.50912881e+00 -1.22113395e+00 -2.34372020e+00 ... 1.40071487e+00\n",
" -1.67449510e+00 -1.14655948e+00]\n",
" ...\n",
" [-5.75253534e+00 -6.67348385e+00 -5.05184650e+00 ... -2.73145151e+00\n",
" -1.48933101e+00 -1.36807609e+00]\n",
" [-3.29049587e+00 -3.73956156e+00 -2.85064268e+00 ... -3.92481357e-01\n",
" -8.00529659e-01 -8.39800835e-01]\n",
" [-4.30351114e+00 -4.21471930e+00 -2.41703367e+00 ... -1.27081513e+00\n",
" 1.67839837e+00 8.47821474e-01]]\n",
"\n",
" [[-5.27856112e-01 -1.09752083e+00 3.39107156e-01 ... 2.00062895e+00\n",
" 8.83528054e-01 2.57416844e-01]\n",
" [-1.58655810e+00 -3.36268663e-01 1.16161990e+00 ... 1.54868484e+00\n",
" 2.38878536e+00 1.84097290e+00]\n",
" [ 5.96052647e-01 2.15484858e-01 1.85280466e+00 ... 2.74587560e+00\n",
" 1.61432290e+00 1.13214278e+00]\n",
" ...\n",
" [-4.57659864e+00 -5.42679739e+00 -4.35204458e+00 ... -1.82452416e+00\n",
" -2.18670201e+00 -3.91811800e+00]\n",
" [-1.32477629e+00 -4.19110394e+00 -3.41308069e+00 ... 1.39622003e-01\n",
" -1.59393203e+00 -9.08105671e-01]\n",
" [-3.60161018e+00 -4.05932713e+00 -2.23674798e+00 ... 9.09647286e-01\n",
" 9.73127842e-01 1.19991803e+00]]\n",
"\n",
" [[ 2.04062796e+00 7.95603275e-01 -1.28833270e+00 ... 4.64749050e+00\n",
" 2.25974560e+00 1.02396965e+00]\n",
" [ 1.68882537e+00 2.63353348e+00 2.53597498e-02 ... 4.69063854e+00\n",
" -4.19382691e-01 2.91669458e-01]\n",
" [ 7.71395087e-01 1.20833695e+00 -2.58601785e-01 ... 1.21794045e+00\n",
" -1.51922226e-01 7.44265199e-01]\n",
" ...\n",
" [-6.66095781e+00 -4.81577682e+00 -5.39921665e+00 ... -2.20548606e+00\n",
" 5.72486281e-01 -4.35207397e-01]\n",
" [-7.51608658e+00 -6.67776871e+00 -3.73199415e+00 ... -1.70327055e+00\n",
" 1.01334639e-02 -3.20627165e+00]\n",
" [-5.73050356e+00 -2.74379373e+00 -3.70248461e+00 ... -1.09794116e+00\n",
" -1.73590891e-02 -1.80156028e+00]]]]\n",
"param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: [-1.4305115e-06 0.0000000e+00 -4.0531158e-06 -1.6689301e-06\n",
" 2.3841858e-07 -7.1525574e-07 1.1920929e-06 1.5497208e-06\n",
" -2.3841858e-07 1.6689301e-06 9.5367432e-07 9.5367432e-07\n",
" -2.6226044e-06 1.1920929e-06 1.3113022e-06 1.9669533e-06\n",
" -4.7683716e-07 1.1920929e-06 -1.6689301e-06 -1.5497208e-06\n",
" -2.2649765e-06 4.7683716e-07 2.3841858e-06 -3.5762787e-06\n",
" 2.3841858e-07 2.1457672e-06 -3.5762787e-07 8.3446503e-07\n",
" -3.5762787e-07 -7.1525574e-07 2.6524067e-06 -1.1920929e-06]\n",
"param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: [-3.7669735 1.5226867 1.759756 4.501629 -2.2077336 0.18411277\n",
" 1.3558264 -1.0269645 3.9628277 3.9300344 -2.80754 1.8462183\n",
" -0.03385968 2.1284049 0.46124816 -4.364863 0.78491163 0.25565645\n",
" -5.3538237 3.2606194 0.79100513 -1.4652673 2.769378 1.2283417\n",
" -4.7466464 -1.3404545 -6.9374166 0.710248 2.0944448 0.4334769\n",
" -0.24313992 0.31392363]\n",
"param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: [-0.6251638 2.833331 0.6993131 3.7106915 -2.262496 0.7390424\n",
" 0.5360477 -2.803875 2.1646228 2.117193 -1.9988279 1.5135905\n",
" -2.0181084 2.6450465 0.06302822 -3.0530102 1.4788482 0.5941844\n",
" -3.1690063 1.8753575 -0.0737313 -2.7806277 -0.04483938 0.16129279\n",
" -1.2960215 -0.38020235 -0.55218065 0.10754502 2.065371 -1.4703183\n",
" -0.40964937 -1.4454535 ]\n",
"param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n",
"param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: [[-0.46178514 0.1095643 0.06441769 ... 0.42020613 -0.34181893\n",
" -0.0658682 ]\n",
" [-0.03619978 0.21653323 0.01727325 ... 0.05731536 -0.37822944\n",
" -0.05464617]\n",
" [-0.32397318 0.04158126 -0.08091418 ... 0.0928297 -0.06518176\n",
" -0.40110156]\n",
" ...\n",
" [-0.2702023 0.05126935 0.11825457 ... 0.0069707 -0.36951366\n",
" 0.37071258]\n",
" [-0.11326203 0.19305304 -0.133317 ... -0.13030824 -0.09068564\n",
" 0.32735693]\n",
" [-0.04543798 0.09902512 -0.10745425 ... -0.06685166 -0.3055201\n",
" 0.0752247 ]]\n",
"param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.07338604 0.64991236 0.5465856 ... 0.507725 0.14061031\n",
" 0.3020359 ]\n",
"param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.41395143 -0.28493872 0.36796764 ... 0.2387953 0.06732331\n",
" 0.16263628]\n",
"param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.09370177 -0.12264141 -0.08237482 ... -0.50241685 -0.149155\n",
" -0.25661892]\n",
" [-0.37426725 0.44987115 0.10685667 ... -0.65946174 -0.4499248\n",
" -0.17545304]\n",
" [-0.03753807 0.33422717 0.12750985 ... 0.05405155 -0.17648363\n",
" 0.05315325]\n",
" ...\n",
" [ 0.15721183 0.03064088 -0.00751081 ... 0.27183983 0.3881693\n",
" -0.01544908]\n",
" [ 0.26047793 0.16917065 0.00915196 ... 0.18076143 -0.05080506\n",
" 0.14791614]\n",
" [ 0.19052255 0.03642382 -0.14313167 ... 0.2611448 0.20763844\n",
" 0.26846847]]\n",
"param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.4139514 -0.28493875 0.36796758 ... 0.23879525 0.06732336\n",
" 0.16263627]\n",
"param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[ 0.04214853 -0.1710323 0.17557406 ... 0.11926915 0.21577051\n",
" -0.30598596]\n",
" [-0.02370887 -0.03498494 -0.05991999 ... -0.06049232 -0.14527473\n",
" -0.5335691 ]\n",
" [-0.21417995 -0.10263194 -0.05903128 ... -0.26958284 0.05936668\n",
" 0.25522667]\n",
" ...\n",
" [ 0.31594425 -0.29487017 0.15871571 ... 0.3504135 -0.1418606\n",
" -0.07482046]\n",
" [ 0.22316164 0.7682122 -0.22191924 ... -0.00535548 -0.6497105\n",
" -0.2011079 ]\n",
" [-0.05800886 0.13750821 0.02450509 ... 0.245736 0.07425706\n",
" -0.17761081]]\n",
"param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.45080703 0.19005743 0.077441 ... -0.24504453 0.19666554\n",
" -0.10503208]\n",
"param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.55867654 0.04237206 0.03389215 ... -0.35602498 0.25528812\n",
" 0.11344345]\n",
"param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.48457903 0.04466334 -0.19785863 ... -0.0254025 -0.10338341\n",
" -0.29202533]\n",
" [-0.15261276 0.00412052 0.22198747 ... 0.22460426 -0.03752084\n",
" 0.05170784]\n",
" [-0.09337254 0.02530848 0.1263681 ... -0.02056236 0.33342454\n",
" -0.08760723]\n",
" ...\n",
" [-0.28645608 -0.19169135 -0.1361257 ... -0.00444204 -0.06552711\n",
" -0.14726155]\n",
" [ 0.21883707 0.2049045 0.23723911 ... 0.4626113 -0.14110637\n",
" 0.02569831]\n",
" [ 0.37554163 -0.19249167 0.14591683 ... 0.25602737 0.40088275\n",
" 0.41056633]]\n",
"param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.55867654 0.04237211 0.0338921 ... -0.35602498 0.2552881\n",
" 0.11344352]\n",
"param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[-0.28007814 -0.09206 -0.01297755 ... -0.2557205 -0.2693453\n",
" 0.05862035]\n",
" [-0.34194735 -0.01383794 -0.06490533 ... -0.11063005 0.16226721\n",
" -0.3197178 ]\n",
" [-0.3646778 0.15443833 0.02241019 ... -0.15093157 -0.09886418\n",
" -0.44295847]\n",
" ...\n",
" [-0.01041886 -0.57636976 -0.03988511 ... -0.2260822 0.49646813\n",
" -0.15528557]\n",
" [-0.19385241 -0.56451964 -0.05551083 ... -0.5638106 0.43611372\n",
" -0.61484563]\n",
" [ 0.1051331 -0.4762463 0.11194798 ... -0.26766616 -0.30734932\n",
" 0.17856634]]\n",
"param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.02791309 -0.992517 0.63012564 ... -1.1830902 1.4646478\n",
" 1.6333911 ]\n",
"param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.10834587 -1.7079136 0.81259465 ... -1.4478713 1.455745\n",
" 2.069446 ]\n",
"param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n",
"param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.14363798 -0.06933184 0.02901152 ... -0.19233373 -0.03206367\n",
" -0.00845779]\n",
" [-0.44314507 -0.8921327 -1.031872 ... -0.558997 -0.53070104\n",
" -0.855925 ]\n",
" [ 0.15673254 0.28793585 0.13351494 ... 0.38433537 0.5040767\n",
" 0.11303265]\n",
" ...\n",
" [-0.22923109 -0.62508404 -0.6195032 ... -0.6876448 -0.41718128\n",
" -0.74844164]\n",
" [ 0.18024652 0.45618314 0.81391454 ... 0.5780604 0.87566674\n",
" 0.71526295]\n",
" [ 0.3763076 0.54033077 0.9940485 ... 1.087821 0.72288674\n",
" 1.2852117 ]]\n",
"param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.10834593 -1.7079139 0.8125948 ... -1.4478711 1.4557447\n",
" 2.0694466 ]\n",
"param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n",
"param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n",
"param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: [[ 1.4382483e-02 2.0160766e-02 1.2322801e-02 ... 1.0075266e-02\n",
" 7.4421698e-03 -2.3925617e+01]\n",
" [ 3.7887424e-02 5.7105277e-02 2.8803380e-02 ... 2.4820438e-02\n",
" 1.8560058e-02 -5.0687141e+01]\n",
" [ 4.5566272e-02 5.4415584e-02 3.2858539e-02 ... 3.2725763e-02\n",
" 2.1536341e-02 -6.1036335e+01]\n",
" ...\n",
" [ 2.8015019e-02 3.5967816e-02 2.3228688e-02 ... 2.1284629e-02\n",
" 1.3860047e-02 -5.2543671e+01]\n",
" [ 2.8445240e-02 4.2448867e-02 2.7125146e-02 ... 2.2253662e-02\n",
" 1.7470375e-02 -4.3619675e+01]\n",
" [ 4.7438074e-02 5.8287360e-02 3.4546286e-02 ... 3.0827176e-02\n",
" 2.2168703e-02 -6.7901680e+01]]\n",
"param grad: fc.bias: shape: [4299] stop_grad: False grad: [ 8.8967547e-02 1.0697905e-01 6.5251388e-02 ... 6.1503030e-02\n",
" 4.3404289e-02 -1.3512518e+02]\n"
]
}
],
"source": [
"loss.backward(retain_graph=False)\n",
"for n, p in dp_model.named_parameters():\n",
" print(\n",
" f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "selected-crazy",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1.]\n"
]
}
],
"source": [
"print(loss.grad)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bottom-engineer",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "stuffed-yeast",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}