{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "cloudy-glass", "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['CUDA_VISISBLE_DEVICES'] = '0'" ] }, { "cell_type": "code", "execution_count": 2, "id": "grand-stephen", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " def convert_to_list(value, n, name, dtype=np.int):\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2.0.0\n" ] } ], "source": [ "import paddle\n", "print(paddle.__version__)" ] }, { "cell_type": "code", "execution_count": null, "id": "isolated-prize", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 3, "id": "romance-samuel", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. 
Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", " and should_run_async(code)\n" ] } ], "source": [ "import sys\n", "import argparse\n", "import functools\n", "from utils.utility import add_arguments, print_arguments\n", "parser = argparse.ArgumentParser(description=__doc__)\n", "add_arg = functools.partial(add_arguments, argparser=parser)\n", "# yapf: disable\n", "add_arg('num_samples', int, 5, \"# of samples to infer.\")\n", "add_arg('beam_size', int, 500, \"Beam search width.\")\n", "add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n", "add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n", "add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n", "add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n", "add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n", "add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n", "add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n", "add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n", "add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n", "add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n", "add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n", " \"bi-directional RNNs. Not for GRU.\")\n", "add_arg('infer_manifest', str,\n", " 'examples/aishell/data/manifest.dev',\n", " \"Filepath of manifest to infer.\")\n", "add_arg('mean_std_path', str,\n", " 'examples/aishell/data/mean_std.npz',\n", " \"Filepath of normalizer's mean & std.\")\n", "add_arg('vocab_path', str,\n", " 'examples/aishell/data/vocab.txt',\n", " \"Filepath of vocabulary.\")\n", "add_arg('lang_model_path', str,\n", " 'models/lm/common_crawl_00.prune01111.trie.klm',\n", " \"Filepath for language model.\")\n", "add_arg('model_path', str,\n", " 'examples/aishell/checkpoints/step_final',\n", " \"If None, the training starts from scratch, \"\n", " \"otherwise, it resumes from the pre-trained model.\")\n", "add_arg('decoding_method', str,\n", " 'ctc_beam_search',\n", " \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n", " choices = ['ctc_beam_search', 'ctc_greedy'])\n", "add_arg('error_rate_type', str,\n", " 'wer',\n", " \"Error rate type for evaluation.\",\n", " choices=['wer', 'cer'])\n", "add_arg('specgram_type', str,\n", " 'linear',\n", " \"Audio feature type. Options: linear, mfcc.\",\n", " choices=['linear', 'mfcc'])\n", "# yapf: disable\n", "args = parser.parse_args([])\n", "print(vars(args))" ] }, { "cell_type": "code", "execution_count": 4, "id": "timely-bikini", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n", " from numpy.dual import register_func\n", "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. 
If you wish to review your current use, check the release note link for additional information.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n", "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " long_ = _make_signed(np.long)\n", "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " ulong = _make_unsigned(np.long)\n" ] } ], "source": [ "from data_utils.dataset import create_dataloader\n", "batch_reader = create_dataloader(\n", " manifest_path=args.infer_manifest,\n", " vocab_filepath=args.vocab_path,\n", " mean_std_filepath=args.mean_std_path,\n", " augmentation_config='{}',\n", " #max_duration=float('inf'),\n", " max_duration=27.0,\n", " min_duration=0.0,\n", " stride_ms=10.0,\n", " window_ms=20.0,\n", " max_freq=None,\n", " specgram_type=args.specgram_type,\n", " use_dB_normalization=True,\n", " random_seed=0,\n", " keep_transcription_text=False,\n", " is_training=False,\n", " batch_size=args.num_samples,\n", " sortagrad=True,\n", " shuffle_method=None,\n", " dist=False)" ] }, { "cell_type": "code", "execution_count": 5, "id": "organized-warrior", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", " and should_run_async(code)\n", "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " if arr.dtype == np.object:\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n", " [[14 , 34 , 322 , 233 , 0 , 0 ],\n", " [238 , 38 , 122 , 164 , 0 , 0 ],\n", " [8 , 52 , 49 , 42 , 0 , 0 ],\n", " [109 , 47 , 146 , 193 , 210 , 479 ],\n", " [3330, 1751, 208 , 1923, 0 , 0 ]])\n", "test raw 大时代里的的\n", "test raw 煲汤受宠的的\n", "audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", " [163, 167, 180, 186, 186])\n", "test len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", " [4, 4, 4, 6, 4])\n", "audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", " [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n", " [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n", " [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n", " ...,\n", " [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n", " [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n", " [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n", "\n", " [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n", " [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n", " [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n", " ...,\n", " [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n", " [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n", " [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n", "\n", " [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n", " [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n", " [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n", " ...,\n", " [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n", " [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n", " [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. 
]],\n", "\n", " [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n", " [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n", " [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n", " ...,\n", " [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n", " [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n", " [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n", "\n", " [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n", " [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n", " [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n", " ...,\n", " [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n", " [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n", " [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n" ] } ], "source": [ " for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n", " print('test', text)\n", " print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[0]))\n", " print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[-1]))\n", " print('audio len', audio_len)\n", " print('test len', text_len)\n", " print('audio', audio)\n", " break" ] }, { "cell_type": "code", "execution_count": 6, "id": "confidential-radius", "metadata": {}, "outputs": [], "source": [ "# reader = batch_reader()\n", "# audio, test , audio_len, text_len = reader.next()\n", "# print('test', text)\n", "# print('t len', text_len) #[B, T]\n", "# print('audio len', audio_len)\n", "# print(audio)" ] }, { "cell_type": "code", "execution_count": 7, "id": "future-vermont", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "煲汤受宠\n" ] } ], "source": [ "print(u'\\u7172\\u6c64\\u53d7\\u5ba0')" ] }, { "cell_type": "code", "execution_count": null, "id": "dental-sweden", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "sunrise-contact", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "hispanic-asthma", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "hearing-leadership", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "skilled-friday", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "copyrighted-measure", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 21, "id": "employed-lightweight", "metadata": {}, "outputs": [], "source": [ "from model_utils.network import DeepSpeech2, DeepSpeech2Loss\n", "\n", "from data_utils.dataset import create_dataloader\n", "batch_reader = create_dataloader(\n", " manifest_path=args.infer_manifest,\n", " vocab_filepath=args.vocab_path,\n", " mean_std_filepath=args.mean_std_path,\n", " augmentation_config='{}',\n", " #max_duration=float('inf'),\n", " max_duration=27.0,\n", " min_duration=0.0,\n", " stride_ms=10.0,\n", " window_ms=20.0,\n", " max_freq=None,\n", " specgram_type=args.specgram_type,\n", " use_dB_normalization=True,\n", " random_seed=0,\n", " keep_transcription_text=False,\n", " is_training=False,\n", " 
batch_size=args.num_samples,\n", " sortagrad=True,\n", " shuffle_method=None,\n", " dist=False)\n", "\n", "\n", "import paddle\n", "from paddle import nn\n", "from paddle.nn import functional as F\n", "from paddle.nn import initializer as I\n", "\n", "import math\n", "\n", "def brelu(x, t_min=0.0, t_max=24.0, name=None):\n", " t_min = paddle.to_tensor(t_min)\n", " t_max = paddle.to_tensor(t_max)\n", " return x.maximum(t_min).minimum(t_max)\n", "\n", "def sequence_mask(x_len, max_len=None, dtype='float32'):\n", " max_len = max_len or x_len.max()\n", " x_len = paddle.unsqueeze(x_len, -1)\n", " row_vector = paddle.arange(max_len)\n", " mask = row_vector < x_len # [B, T]: 1 for valid steps, 0 for padding\n", " mask = paddle.cast(mask, dtype)\n", " print(f'seq mask: {mask}')\n", " return mask\n", "\n", "\n", "class ConvBn(nn.Layer):\n", " def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,\n", " padding, act):\n", "\n", " super().__init__()\n", " self.kernel_size = kernel_size\n", " self.stride = stride\n", " self.padding = padding\n", "\n", " self.conv = nn.Conv2D(\n", " num_channels_in,\n", " num_channels_out,\n", " kernel_size=kernel_size,\n", " stride=stride,\n", " padding=padding,\n", " weight_attr=None,\n", " bias_attr=None,\n", " data_format='NCHW')\n", "\n", " self.bn = nn.BatchNorm2D(\n", " num_channels_out,\n", " weight_attr=None,\n", " bias_attr=None,\n", " data_format='NCHW')\n", " self.act = F.relu if act == 'relu' else brelu\n", "\n", " def forward(self, x, x_len):\n", " \"\"\"\n", " x(Tensor): audio, shape [B, C, D, T]\n", " \"\"\"\n", " x = self.conv(x)\n", " x = self.bn(x)\n", " x = self.act(x)\n", "\n", " x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]\n", " ) // self.stride[1] + 1\n", "\n", " # reset padding part to 0\n", " masks = sequence_mask(x_len) #[B, T]\n", " masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]\n", " x = x.multiply(masks)\n", "\n", " return x, x_len\n", "\n", "\n", "class ConvStack(nn.Layer):\n", " def __init__(self, feat_size, num_stacks):\n", " super().__init__()\n", " self.feat_size = feat_size # D\n", " self.num_stacks = num_stacks\n", "\n", " self.conv_in = ConvBn(\n", " num_channels_in=1,\n", " num_channels_out=32,\n", " kernel_size=(41, 11), #[D, T]\n", " stride=(2, 3),\n", " padding=(20, 5),\n", " act='brelu')\n", "\n", " out_channel = 32\n", " self.conv_stack = nn.Sequential([\n", " ConvBn(\n", " num_channels_in=32,\n", " num_channels_out=out_channel,\n", " kernel_size=(21, 11),\n", " stride=(2, 1),\n", " padding=(10, 5),\n", " act='brelu') for i in range(num_stacks - 1)\n", " ])\n", "\n", " # conv output feat_dim\n", " output_height = (feat_size - 1) // 2 + 1\n", " for i in range(self.num_stacks - 1):\n", " output_height = (output_height - 1) // 2 + 1\n", " self.output_height = out_channel * output_height\n", "\n", " def forward(self, x, x_len):\n", " \"\"\"\n", " x: shape [B, C, D, T]\n", " x_len : shape [B]\n", " \"\"\"\n", " print(f\"conv in: {x_len}\")\n", " x, x_len = self.conv_in(x, x_len)\n", " for i, conv in enumerate(self.conv_stack):\n", " print(f\"conv in: {x_len}\")\n", " x, x_len = conv(x, x_len)\n", " print(f\"conv out: {x_len}\")\n", " return x, x_len\n", " \n", " \n", "\n", "class RNNCell(nn.RNNCellBase):\n", " r\"\"\"\n", " Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it \n", " computes the outputs and updates states.\n", " The formula used is as follows:\n", " .. 
math::\n", " h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})\n", " y_{t} & = h_{t}\n", " \n", " where :math:`act` is for :attr:`activation`.\n", " \"\"\"\n", "\n", " def __init__(self,\n", " hidden_size,\n", " activation=\"tanh\",\n", " weight_ih_attr=None,\n", " weight_hh_attr=None,\n", " bias_ih_attr=None,\n", " bias_hh_attr=None,\n", " name=None):\n", " super().__init__()\n", " std = 1.0 / math.sqrt(hidden_size)\n", " self.weight_hh = self.create_parameter(\n", " (hidden_size, hidden_size),\n", " weight_hh_attr,\n", " default_initializer=I.Uniform(-std, std))\n", " # self.bias_ih = self.create_parameter(\n", " # (hidden_size, ),\n", " # bias_ih_attr,\n", " # is_bias=True,\n", " # default_initializer=I.Uniform(-std, std))\n", " self.bias_ih = None\n", " self.bias_hh = self.create_parameter(\n", " (hidden_size, ),\n", " bias_hh_attr,\n", " is_bias=True,\n", " default_initializer=I.Uniform(-std, std))\n", "\n", " self.hidden_size = hidden_size\n", " if activation not in [\"tanh\", \"relu\", \"brelu\"]:\n", " raise ValueError(\n", " \"activation for SimpleRNNCell should be tanh or relu, \"\n", " \"but get {}\".format(activation))\n", " self.activation = activation\n", " self._activation_fn = paddle.tanh \\\n", " if activation == \"tanh\" \\\n", " else F.relu\n", " if activation == 'brelu':\n", " self._activation_fn = brelu\n", "\n", " def forward(self, inputs, states=None):\n", " if states is None:\n", " states = self.get_initial_states(inputs, self.state_shape)\n", " pre_h = states\n", " i2h = inputs\n", " if self.bias_ih is not None:\n", " i2h += self.bias_ih\n", " h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)\n", " if self.bias_hh is not None:\n", " h2h += self.bias_hh\n", " h = self._activation_fn(i2h + h2h)\n", " return h, h\n", "\n", " @property\n", " def state_shape(self):\n", " return (self.hidden_size, )\n", "\n", "\n", "class GRUCellShare(nn.RNNCellBase):\n", " r\"\"\"\n", " Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, \n", " it computes the outputs and updates states.\n", " The formula for GRU used is as follows:\n", " .. 
math::\n", " r_{t} & = \\sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})\n", " z_{t} & = \\sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})\n", " \\widetilde{h}_{t} & = \\tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))\n", " h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \\widetilde{h}_{t}\n", " y_{t} & = h_{t}\n", " \n", " where :math:`\\sigma` is the sigmoid fucntion, and * is the elemetwise \n", " multiplication operator.\n", " \"\"\"\n", "\n", " def __init__(self,\n", " input_size,\n", " hidden_size,\n", " weight_ih_attr=None,\n", " weight_hh_attr=None,\n", " bias_ih_attr=None,\n", " bias_hh_attr=None,\n", " name=None):\n", " super().__init__()\n", " std = 1.0 / math.sqrt(hidden_size)\n", " self.weight_hh = self.create_parameter(\n", " (3 * hidden_size, hidden_size),\n", " weight_hh_attr,\n", " default_initializer=I.Uniform(-std, std))\n", " # self.bias_ih = self.create_parameter(\n", " # (3 * hidden_size, ),\n", " # bias_ih_attr,\n", " # is_bias=True,\n", " # default_initializer=I.Uniform(-std, std))\n", " self.bias_ih = None\n", " self.bias_hh = self.create_parameter(\n", " (3 * hidden_size, ),\n", " bias_hh_attr,\n", " is_bias=True,\n", " default_initializer=I.Uniform(-std, std))\n", "\n", " self.hidden_size = hidden_size\n", " self.input_size = input_size\n", " self._gate_activation = F.sigmoid\n", " #self._activation = paddle.tanh\n", " self._activation = F.relu\n", "\n", " def forward(self, inputs, states=None):\n", " if states is None:\n", " states = self.get_initial_states(inputs, self.state_shape)\n", "\n", " pre_hidden = states\n", " x_gates = inputs\n", " if self.bias_ih is not None:\n", " x_gates = x_gates + self.bias_ih\n", " h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)\n", " if self.bias_hh is not None:\n", " h_gates = h_gates + self.bias_hh\n", "\n", " x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)\n", " h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)\n", "\n", " r = self._gate_activation(x_r + h_r)\n", " z = self._gate_activation(x_z + h_z)\n", " c = self._activation(x_c + r * h_c) # apply reset gate after mm\n", " h = (pre_hidden - c) * z + c\n", " # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru\n", " #h = (1-z) * pre_hidden + z * c\n", "\n", " return h, h\n", "\n", " @property\n", " def state_shape(self):\n", " r\"\"\"\n", " The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch\n", " size would be automatically inserted into shape). 
The shape corresponds\n", " to the shape of :math:`h_{t-1}`.\n", " \"\"\"\n", " return (self.hidden_size, )\n", "\n", "\n", "class BiRNNWithBN(nn.Layer):\n", " \"\"\"Bidirectional simple rnn layer with sequence-wise batch normalization.\n", " The batch normalization is only performed on input-state weights.\n", "\n", " :param name: Name of the layer parameters.\n", " :type name: string\n", " :param size: Dimension of RNN cells.\n", " :type size: int\n", " :param share_weights: Whether to share input-hidden weights between\n", " forward and backward directional RNNs.\n", " :type share_weights: bool\n", " :return: Bidirectional simple rnn layer.\n", " :rtype: Variable\n", " \"\"\"\n", "\n", " def __init__(self, i_size, h_size, share_weights):\n", " super().__init__()\n", " self.share_weights = share_weights\n", " if self.share_weights:\n", " #input-hidden weights shared between bi-directional rnn.\n", " self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n", " # batch norm is only performed on input-state projection\n", " self.fw_bn = nn.BatchNorm1D(\n", " h_size, bias_attr=None, data_format='NLC')\n", " self.bw_fc = self.fw_fc\n", " self.bw_bn = self.fw_bn\n", " else:\n", " self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n", " self.fw_bn = nn.BatchNorm1D(\n", " h_size, bias_attr=None, data_format='NLC')\n", " self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n", " self.bw_bn = nn.BatchNorm1D(\n", " h_size, bias_attr=None, data_format='NLC')\n", "\n", " self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n", " self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n", " self.fw_rnn = nn.RNN(\n", " self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]\n", " self.bw_rnn = nn.RNN(\n", " self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]\n", "\n", " def forward(self, x, x_len):\n", " # x, shape [B, T, D]\n", " fw_x = self.fw_bn(self.fw_fc(x))\n", " bw_x = self.bw_bn(self.bw_fc(x))\n", " fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n", " bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n", " x = paddle.concat([fw_x, bw_x], axis=-1)\n", " return x, x_len\n", "\n", "\n", "class BiGRUWithBN(nn.Layer):\n", " \"\"\"Bidirectional GRU layer with sequence-wise batch normalization.\n", " The batch normalization is only performed on input-state weights.\n", "\n", " :param name: Name of the layer.\n", " :type name: string\n", " :param input: Input layer.\n", " :type input: Variable\n", " :param size: Dimension of GRU cells.\n", " :type size: int\n", " :param act: Activation type.\n", " :type act: string\n", " :return: Bidirectional GRU layer.\n", " :rtype: Variable\n", " \"\"\"\n", "\n", " def __init__(self, i_size, h_size, act):\n", " super().__init__()\n", " hidden_size = h_size * 3\n", " self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n", " self.fw_bn = nn.BatchNorm1D(\n", " hidden_size, bias_attr=None, data_format='NLC')\n", " self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n", " self.bw_bn = nn.BatchNorm1D(\n", " hidden_size, bias_attr=None, data_format='NLC')\n", "\n", " self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n", " self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n", " self.fw_rnn = nn.RNN(\n", " self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]\n", " self.bw_rnn = nn.RNN(\n", " self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]\n", "\n", " def forward(self, x, x_len):\n", " # x, shape [B, T, D]\n", " fw_x = 
self.fw_bn(self.fw_fc(x))\n", " bw_x = self.bw_bn(self.bw_fc(x))\n", " fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n", " bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n", " x = paddle.concat([fw_x, bw_x], axis=-1)\n", " return x, x_len\n", "\n", "\n", "class RNNStack(nn.Layer):\n", " \"\"\"RNN group with stacked bidirectional simple RNN or GRU layers.\n", "\n", " :param input: Input layer.\n", " :type input: Variable\n", " :param size: Dimension of RNN cells in each layer.\n", " :type size: int\n", " :param num_stacks: Number of stacked rnn layers.\n", " :type num_stacks: int\n", " :param use_gru: Use gru if set True. Use simple rnn if set False.\n", " :type use_gru: bool\n", " :param share_rnn_weights: Whether to share input-hidden weights between\n", " forward and backward directional RNNs.\n", " It is only available when use_gru=False.\n", " :type share_weights: bool\n", " :return: Output layer of the RNN group.\n", " :rtype: Variable\n", " \"\"\"\n", "\n", " def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights):\n", " super().__init__()\n", " self.rnn_stacks = nn.LayerList()\n", " for i in range(num_stacks):\n", " if use_gru:\n", " #default:GRU using tanh\n", " self.rnn_stacks.append(\n", " BiGRUWithBN(i_size=i_size, h_size=h_size, act=\"relu\"))\n", " else:\n", " self.rnn_stacks.append(\n", " BiRNNWithBN(\n", " i_size=i_size,\n", " h_size=h_size,\n", " share_weights=share_rnn_weights))\n", " i_size = h_size * 2\n", "\n", " def forward(self, x, x_len):\n", " \"\"\"\n", " x: shape [B, T, D]\n", " x_len: shpae [B]\n", " \"\"\"\n", " for i, rnn in enumerate(self.rnn_stacks):\n", " x, x_len = rnn(x, x_len)\n", " masks = sequence_mask(x_len) #[B, T]\n", " masks = masks.unsqueeze(-1) # [B, T, 1]\n", " x = x.multiply(masks)\n", " return x, x_len\n", "\n", " \n", "class DeepSpeech2Test(DeepSpeech2):\n", " def __init__(self,\n", " feat_size,\n", " dict_size,\n", " num_conv_layers=2,\n", " num_rnn_layers=3,\n", " rnn_size=256,\n", " use_gru=False,\n", " share_rnn_weights=True):\n", " super().__init__(feat_size,\n", " dict_size,\n", " num_conv_layers=2,\n", " num_rnn_layers=3,\n", " rnn_size=256,\n", " use_gru=False,\n", " share_rnn_weights=True)\n", " self.feat_size = feat_size # 161 for linear\n", " self.dict_size = dict_size\n", "\n", " self.conv = ConvStack(feat_size, num_conv_layers)\n", " \n", "# self.fc = nn.Linear(1312, dict_size + 1)\n", "\n", " i_size = self.conv.output_height # H after conv stack\n", " self.rnn = RNNStack(\n", " i_size=i_size,\n", " h_size=rnn_size,\n", " num_stacks=num_rnn_layers,\n", " use_gru=use_gru,\n", " share_rnn_weights=share_rnn_weights)\n", " \n", " self.fc = nn.Linear(rnn_size * 2, dict_size + 1)\n", " \n", " def infer(self, audio, audio_len):\n", " # [B, D, T] -> [B, C=1, D, T]\n", " audio = audio.unsqueeze(1)\n", "\n", " # convolution group\n", " x, audio_len = self.conv(audio, audio_len)\n", " print('conv out', x.shape)\n", "\n", " # convert data from convolution feature map to sequence of vectors\n", " B, C, D, T = paddle.shape(x)\n", " x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]\n", " x = x.reshape([B, T, C * D]) #[B, T, C*D]\n", " print('rnn input', x.shape)\n", "\n", " # remove padding part\n", " x, audio_len = self.rnn(x, audio_len) #[B, T, D]\n", " print('rnn output', x.shape)\n", "\n", " logits = self.fc(x) #[B, T, V + 1]\n", "\n", " #ctcdecoder need probs, not log_probs\n", " probs = F.softmax(logits)\n", "\n", " return logits, probs, audio_len\n", "\n", " def forward(self, audio, audio_len, text, 
text_len):\n", " \"\"\"\n", " audio: shape [B, D, T]\n", " text: shape [B, T]\n", " audio_len: shape [B]\n", " text_len: shape [B]\n", " \"\"\"\n", " return self.infer(audio, audio_len)\n", " \n", "\n", "feat_dim=161\n", "\n", "model = DeepSpeech2Test(\n", " feat_size=feat_dim,\n", " dict_size=batch_reader.dataset.vocab_size,\n", " num_conv_layers=args.num_conv_layers,\n", " num_rnn_layers=args.num_rnn_layers,\n", " rnn_size=1024,\n", " use_gru=args.use_gru,\n", " share_rnn_weights=args.share_rnn_weights,\n", " )\n", "dp_model = model\n", "#dp_model = paddle.DataParallel(model)\n", "\n", "loss_fn = DeepSpeech2Loss(batch_reader.dataset.vocab_size)" ] }, { "cell_type": "code", "execution_count": null, "id": "divided-incentive", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 22, "id": "discrete-conjunction", "metadata": {}, "outputs": [], "source": [ "audio, audio_len, text, text_len = None, None, None, None\n", "\n", "for idx, inputs in enumerate(batch_reader):\n", " audio, audio_len, text, text_len = inputs\n", "# print(idx)\n", "# print('a', audio.shape, audio.place)\n", "# print('t', text)\n", "# print('al', audio_len)\n", "# print('tl', text_len)\n", " break" ] }, { "cell_type": "code", "execution_count": 23, "id": "protected-announcement", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "conv in: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", " [163, 167, 180, 186, 186])\n", "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", "conv in: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", " [55, 56, 60, 62, 62])\n", "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 
1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", "conv out: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", " [55, 56, 60, 62, 62])\n", "conv out [5, 32, 41, 62]\n", "rnn input [5, 62, 1312]\n", "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n", " return (isinstance(seq, collections.Sequence) and\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", "rnn output [5, 62, 2048]\n", "logits len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", " [55, 56, 60, 62, 62])\n", "loss Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", " [2316.82153320])\n" ] } ], "source": [ "outputs = dp_model(audio, audio_len, text, text_len)\n", "logits, _, logits_len = outputs\n", "print('logits len', logits_len)\n", "loss = loss_fn.forward(logits, text, logits_len, text_len)\n", "print('loss', loss)" ] }, { "cell_type": "code", "execution_count": 24, "id": "universal-myrtle", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: None\n", "param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: None\n", "param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: None\n", "param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: None\n", "param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n", "param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n", "param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: None\n", "param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: None\n", "param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: None\n", "param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: None\n", "param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: 
True grad: None\n", "param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n", "param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", "param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", "param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", "param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", "param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", "param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", "param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", "param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: None\n", "param grad: fc.bias: shape: [4299] stop_grad: False grad: None\n" ] } ], "source": [ "for n, p in dp_model.named_parameters():\n", " print(\n", " f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")" ] }, { "cell_type": "code", "execution_count": 25, "id": "referenced-double", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: [[[[ 2.1243238 1.696022 3.770659 ... 5.234652 5.4865217\n", " 4.757795 ]\n", " [ 2.651376 2.3109848 4.428488 ... 5.353201 8.703288\n", " 5.1787405 ]\n", " [ 2.7511077 1.8823049 2.1875212 ... 3.4821286 6.386543\n", " 3.5026932 ]\n", " ...\n", " [ 1.9173846 1.8623551 0.5601456 ... 2.8375719 3.8496673\n", " 2.359191 ]\n", " [ 2.3827765 2.497965 1.5914664 ... 
2.220721 3.4617734\n", " 4.829253 ]\n", " [ 1.6855702 1.5040786 1.8793598 ... 4.0773935 3.176893\n", " 3.7477999 ]]]\n", "\n", "\n", " [[[ 1.8451455 2.0091445 1.5225713 ... 1.524528 0.17764974\n", " 1.0245132 ]\n", " [ 1.9388857 1.3873467 2.044691 ... 0.92544 -0.9746763\n", " -0.41603735]\n", " [ 2.6814485 2.6096234 1.6802506 ... 1.902397 1.6837387\n", " -0.96788657]\n", " ...\n", " [ 4.3675485 1.9822174 1.1695029 ... 1.4672399 3.2029557\n", " 2.6364415 ]\n", " [ 3.2536 1.1792442 -0.5618002 ... 2.101127 1.904225\n", " 3.3839993 ]\n", " [ 1.9118482 1.0651072 0.5409893 ... 2.6783593 1.6871439\n", " 4.1078367 ]]]\n", "\n", "\n", " [[[-4.412424 -1.7111907 -1.7722387 ... -4.3383503 -6.2393785\n", " -6.139402 ]\n", " [-2.260428 -1.0250616 -2.0550888 ... -5.353946 -4.29947\n", " -6.158736 ]\n", " [-1.4927872 0.7552787 -0.0702923 ... -4.485656 -4.0794134\n", " -5.416684 ]\n", " ...\n", " [ 2.9100134 4.156195 4.357041 ... -3.569804 -1.8634341\n", " -0.8772557 ]\n", " [ 1.6895763 3.4314504 4.1192107 ... -1.380024 -2.3234155\n", " -3.6650617 ]\n", " [ 2.4190075 1.007498 3.1173465 ... -0.96318084 -3.6175003\n", " -2.5240796 ]]]\n", "\n", "\n", " ...\n", "\n", "\n", " [[[-0.6865506 -0.60106415 -1.5555015 ... 2.0853553 1.900961\n", " 2.101063 ]\n", " [-0.31686288 -1.4362946 -1.4929098 ... 0.15085456 1.4540495\n", " 1.4128599 ]\n", " [-0.57852304 -0.8204216 -2.3264258 ... 1.4970423 0.54599845\n", " 1.6222539 ]\n", " ...\n", " [ 0.32624918 0.96004546 -0.7476514 ... 2.2786083 2.1000178\n", " 2.7494807 ]\n", " [-1.6967826 -0.78979015 -1.8424999 ... 1.0620685 2.0544293\n", " 2.2483966 ]\n", " [ 0.8192332 2.601636 -2.6636481 ... 0.26625186 1.7610842\n", " 1.7467536 ]]]\n", "\n", "\n", " [[[ 0.9140297 0.42424175 1.4352363 ... -2.3022954 -3.001058\n", " -2.6987422 ]\n", " [ 0.4491998 -0.10698095 1.5089144 ... -3.2831016 -3.6055021\n", " -3.6595795 ]\n", " [ 2.6818252 -1.5750014 -0.34812498 ... -4.4137015 -4.250422\n", " -3.481941 ]\n", " ...\n", " [ 1.4232106 2.9689102 3.9547806 ... -0.481165 0.28190404\n", " -1.2167063 ]\n", " [ 2.2297084 4.8198485 4.2857304 ... 0.57483846 1.4093391\n", " 0.0715822 ]\n", " [ 1.679745 4.768068 5.416195 ... 0.17254728 0.4623217\n", " 1.4772662 ]]]\n", "\n", "\n", " [[[-2.0860114 -2.9508173 -1.4945896 ... -4.067145 -2.5652342\n", " -3.5771027 ]\n", " [-2.697845 -1.9273603 -2.3885014 ... -2.196533 -2.8573706\n", " -2.0113711 ]\n", " [-2.413383 -2.7204053 -1.0502659 ... -3.001385 -3.36447\n", " -4.3225455 ]\n", " ...\n", " [ 1.2754489 0.9560999 1.5239805 ... -0.0105865 -1.00876\n", " 2.6247358 ]\n", " [ 1.1965859 1.0378222 1.1025598 ... -0.5394704 0.49838027\n", " -0.9618193 ]\n", " [ 1.1361816 1.3232857 0.687318 ... 
-0.23925456 -0.43679112\n", " -0.79297894]]]]\n", "param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: [ 5.9604645e-07 -3.9339066e-06 -1.0728836e-06 -1.6689301e-06\n", " 1.1920929e-06 -2.5033951e-06 -2.3841858e-07 4.7683716e-07\n", " 4.2915344e-06 -1.9073486e-06 -1.9073486e-06 3.0994415e-06\n", " -2.6822090e-06 3.3378601e-06 -4.2915344e-06 5.2452087e-06\n", " 3.8146973e-06 2.3841858e-07 7.1525574e-07 -3.6954880e-06\n", " 2.0563602e-06 -2.6226044e-06 3.0994415e-06 -3.5762787e-07\n", " -4.7683716e-06 1.2218952e-06 3.3378601e-06 -2.5629997e-06\n", " 2.3841858e-07 -1.7881393e-06 4.7683716e-07 -2.7418137e-06]\n", "param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: [ 2.363316 3.286464 1.9607866 -1.6367784 -1.6325372 -1.7729434\n", " -0.9261875 2.0950415 0.1155543 -0.8857083 0.70079553 0.33920464\n", " 2.6953902 -0.64524114 0.8845749 -1.2271115 0.6578167 -2.939814\n", " 5.5728893 -1.0917969 0.01470797 1.395206 4.8009634 -0.744532\n", " 0.944651 -1.092311 1.4877632 -3.042566 0.51686054 -5.4768667\n", " -5.628145 -1.0894046 ]\n", "param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: [ 1.5193373 1.8838218 3.7722278 0.28052303 0.5386534 -0.44620085\n", " -1.6977876 3.115642 0.03312349 -2.9121587 3.8925257 0.2288351\n", " -2.273387 -1.3597974 4.3708124 -0.23374033 0.116272 -0.7064927\n", " 6.5267463 -1.5318865 1.0288429 0.7928574 -0.24655592 -2.1116853\n", " 2.922772 -3.3462617 1.7016437 -3.5471547 0.29777628 -3.2820854\n", " -4.116946 -0.9909375 ]\n", "param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n", "param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n", "param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: [[[[ 6.20494843e-01 5.95983505e-01 -1.48909020e+00 ... -6.86620831e-01\n", " 6.71104014e-01 -1.95339048e+00]\n", " [-3.91837955e-03 1.27062631e+00 -1.63248098e+00 ... 1.07290137e+00\n", " -9.42245364e-01 -3.34277248e+00]\n", " [ 2.41821265e+00 2.36212373e-01 -1.84433365e+00 ... 1.23182368e+00\n", " 1.36039746e+00 -2.94621849e+00]\n", " ...\n", " [ 1.55153418e+00 7.25861669e-01 2.08785534e+00 ... -6.40172660e-01\n", " -3.23889256e-02 -2.30832791e+00]\n", " [ 3.69824195e+00 1.27163112e-01 4.09263194e-01 ... -8.60729575e-01\n", " -3.51897454e+00 -2.10093403e+00]\n", " [-4.94779050e-01 -3.74262631e-01 -1.19801068e+00 ... -2.05930543e+00\n", " -7.38576293e-01 -9.44581270e-01]]\n", "\n", " [[-2.04341412e+00 -3.70606273e-01 -1.40429378e+00 ... -1.71711946e+00\n", " -4.09437418e-01 -1.74107194e+00]\n", " [-8.72247815e-01 -1.06301677e+00 -9.19306517e-01 ... -2.98976970e+00\n", " -3.03250861e+00 -2.37099743e+00]\n", " [-5.00457406e-01 -1.11882675e+00 -5.91526508e-01 ... 4.23921436e-01\n", " -2.08650708e+00 -1.82109618e+00]\n", " ...\n", " [ 2.07773042e+00 1.40735030e-01 -2.60543615e-01 ... -1.55956164e-01\n", " -1.31862307e+00 -2.07174897e+00]\n", " [ 7.95007765e-01 1.14988625e-01 -1.43308258e+00 ... 8.29253554e-01\n", " -9.57888126e-01 -3.82121086e-01]\n", " [ 8.34397674e-02 1.38636863e+00 -1.21593380e+00 ... -2.65783578e-01\n", " 1.78124309e-02 -3.40287232e+00]]\n", "\n", " [[ 6.27344131e-01 5.71699142e-02 -3.58010936e+00 ... -4.53077674e-01\n", " 1.65331578e+00 2.58466601e-02]\n", " [ 2.66681361e+00 2.02069378e+00 -1.52052927e+00 ... 2.94914508e+00\n", " 1.94632411e+00 -1.06698799e+00]\n", " [ 1.57839453e+00 -1.03649735e-01 -4.22528505e+00 ... 
-1.09794116e+00\n", " -1.73590891e-02 -1.80156028e+00]]]]\n", "param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: [-1.4305115e-06 0.0000000e+00 -4.0531158e-06 -1.6689301e-06\n", " 2.3841858e-07 -7.1525574e-07 1.1920929e-06 1.5497208e-06\n", " -2.3841858e-07 1.6689301e-06 9.5367432e-07 9.5367432e-07\n", " -2.6226044e-06 1.1920929e-06 1.3113022e-06 1.9669533e-06\n", " -4.7683716e-07 1.1920929e-06 -1.6689301e-06 -1.5497208e-06\n", " -2.2649765e-06 4.7683716e-07 2.3841858e-06 -3.5762787e-06\n", " 2.3841858e-07 2.1457672e-06 -3.5762787e-07 8.3446503e-07\n", " -3.5762787e-07 -7.1525574e-07 2.6524067e-06 -1.1920929e-06]\n", "param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: [-3.7669735 1.5226867 1.759756 4.501629 -2.2077336 0.18411277\n", " 1.3558264 -1.0269645 3.9628277 3.9300344 -2.80754 1.8462183\n", " -0.03385968 2.1284049 0.46124816 -4.364863 0.78491163 0.25565645\n", " -5.3538237 3.2606194 0.79100513 -1.4652673 2.769378 1.2283417\n", " -4.7466464 -1.3404545 -6.9374166 0.710248 2.0944448 0.4334769\n", " -0.24313992 0.31392363]\n", "param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: [-0.6251638 2.833331 0.6993131 3.7106915 -2.262496 0.7390424\n", " 0.5360477 -2.803875 2.1646228 2.117193 -1.9988279 1.5135905\n", " -2.0181084 2.6450465 0.06302822 -3.0530102 1.4788482 0.5941844\n", " -3.1690063 1.8753575 -0.0737313 -2.7806277 -0.04483938 0.16129279\n", " -1.2960215 -0.38020235 -0.55218065 0.10754502 2.065371 -1.4703183\n", " -0.40964937 -1.4454535 ]\n", "param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n", "param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n", "param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: [[-0.46178514 0.1095643 0.06441769 ... 0.42020613 -0.34181893\n", " -0.0658682 ]\n", " [-0.03619978 0.21653323 0.01727325 ... 0.05731536 -0.37822944\n", " -0.05464617]\n", " [-0.32397318 0.04158126 -0.08091418 ... 0.0928297 -0.06518176\n", " -0.40110156]\n", " ...\n", " [-0.2702023 0.05126935 0.11825457 ... 0.0069707 -0.36951366\n", " 0.37071258]\n", " [-0.11326203 0.19305304 -0.133317 ... -0.13030824 -0.09068564\n", " 0.32735693]\n", " [-0.04543798 0.09902512 -0.10745425 ... -0.06685166 -0.3055201\n", " 0.0752247 ]]\n", "param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.07338604 0.64991236 0.5465856 ... 0.507725 0.14061031\n", " 0.3020359 ]\n", "param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.41395143 -0.28493872 0.36796764 ... 0.2387953 0.06732331\n", " 0.16263628]\n", "param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", "param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", "param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.09370177 -0.12264141 -0.08237482 ... -0.50241685 -0.149155\n", " -0.25661892]\n", " [-0.37426725 0.44987115 0.10685667 ... -0.65946174 -0.4499248\n", " -0.17545304]\n", " [-0.03753807 0.33422717 0.12750985 ... 0.05405155 -0.17648363\n", " 0.05315325]\n", " ...\n", " [ 0.15721183 0.03064088 -0.00751081 ... 0.27183983 0.3881693\n", " -0.01544908]\n", " [ 0.26047793 0.16917065 0.00915196 ... 0.18076143 -0.05080506\n", " 0.14791614]\n", " [ 0.19052255 0.03642382 -0.14313167 ... 
0.2611448 0.20763844\n", " 0.26846847]]\n", "param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.4139514 -0.28493875 0.36796758 ... 0.23879525 0.06732336\n", " 0.16263627]\n", "param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[ 0.04214853 -0.1710323 0.17557406 ... 0.11926915 0.21577051\n", " -0.30598596]\n", " [-0.02370887 -0.03498494 -0.05991999 ... -0.06049232 -0.14527473\n", " -0.5335691 ]\n", " [-0.21417995 -0.10263194 -0.05903128 ... -0.26958284 0.05936668\n", " 0.25522667]\n", " ...\n", " [ 0.31594425 -0.29487017 0.15871571 ... 0.3504135 -0.1418606\n", " -0.07482046]\n", " [ 0.22316164 0.7682122 -0.22191924 ... -0.00535548 -0.6497105\n", " -0.2011079 ]\n", " [-0.05800886 0.13750821 0.02450509 ... 0.245736 0.07425706\n", " -0.17761081]]\n", "param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.45080703 0.19005743 0.077441 ... -0.24504453 0.19666554\n", " -0.10503208]\n", "param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.55867654 0.04237206 0.03389215 ... -0.35602498 0.25528812\n", " 0.11344345]\n", "param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", "param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", "param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.48457903 0.04466334 -0.19785863 ... -0.0254025 -0.10338341\n", " -0.29202533]\n", " [-0.15261276 0.00412052 0.22198747 ... 0.22460426 -0.03752084\n", " 0.05170784]\n", " [-0.09337254 0.02530848 0.1263681 ... -0.02056236 0.33342454\n", " -0.08760723]\n", " ...\n", " [-0.28645608 -0.19169135 -0.1361257 ... -0.00444204 -0.06552711\n", " -0.14726155]\n", " [ 0.21883707 0.2049045 0.23723911 ... 0.4626113 -0.14110637\n", " 0.02569831]\n", " [ 0.37554163 -0.19249167 0.14591683 ... 0.25602737 0.40088275\n", " 0.41056633]]\n", "param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.55867654 0.04237211 0.0338921 ... -0.35602498 0.2552881\n", " 0.11344352]\n", "param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[-0.28007814 -0.09206 -0.01297755 ... -0.2557205 -0.2693453\n", " 0.05862035]\n", " [-0.34194735 -0.01383794 -0.06490533 ... -0.11063005 0.16226721\n", " -0.3197178 ]\n", " [-0.3646778 0.15443833 0.02241019 ... -0.15093157 -0.09886418\n", " -0.44295847]\n", " ...\n", " [-0.01041886 -0.57636976 -0.03988511 ... -0.2260822 0.49646813\n", " -0.15528557]\n", " [-0.19385241 -0.56451964 -0.05551083 ... -0.5638106 0.43611372\n", " -0.61484563]\n", " [ 0.1051331 -0.4762463 0.11194798 ... -0.26766616 -0.30734932\n", " 0.17856634]]\n", "param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.02791309 -0.992517 0.63012564 ... -1.1830902 1.4646478\n", " 1.6333911 ]\n", "param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.10834587 -1.7079136 0.81259465 ... 
-1.4478713 1.455745\n", " 2.069446 ]\n", "param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", "param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", "param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.14363798 -0.06933184 0.02901152 ... -0.19233373 -0.03206367\n", " -0.00845779]\n", " [-0.44314507 -0.8921327 -1.031872 ... -0.558997 -0.53070104\n", " -0.855925 ]\n", " [ 0.15673254 0.28793585 0.13351494 ... 0.38433537 0.5040767\n", " 0.11303265]\n", " ...\n", " [-0.22923109 -0.62508404 -0.6195032 ... -0.6876448 -0.41718128\n", " -0.74844164]\n", " [ 0.18024652 0.45618314 0.81391454 ... 0.5780604 0.87566674\n", " 0.71526295]\n", " [ 0.3763076 0.54033077 0.9940485 ... 1.087821 0.72288674\n", " 1.2852117 ]]\n", "param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.10834593 -1.7079139 0.8125948 ... -1.4478711 1.4557447\n", " 2.0694466 ]\n", "param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", "param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", "param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: [[ 1.4382483e-02 2.0160766e-02 1.2322801e-02 ... 1.0075266e-02\n", " 7.4421698e-03 -2.3925617e+01]\n", " [ 3.7887424e-02 5.7105277e-02 2.8803380e-02 ... 2.4820438e-02\n", " 1.8560058e-02 -5.0687141e+01]\n", " [ 4.5566272e-02 5.4415584e-02 3.2858539e-02 ... 3.2725763e-02\n", " 2.1536341e-02 -6.1036335e+01]\n", " ...\n", " [ 2.8015019e-02 3.5967816e-02 2.3228688e-02 ... 2.1284629e-02\n", " 1.3860047e-02 -5.2543671e+01]\n", " [ 2.8445240e-02 4.2448867e-02 2.7125146e-02 ... 2.2253662e-02\n", " 1.7470375e-02 -4.3619675e+01]\n", " [ 4.7438074e-02 5.8287360e-02 3.4546286e-02 ... 3.0827176e-02\n", " 2.2168703e-02 -6.7901680e+01]]\n", "param grad: fc.bias: shape: [4299] stop_grad: False grad: [ 8.8967547e-02 1.0697905e-01 6.5251388e-02 ... 6.1503030e-02\n", " 4.3404289e-02 -1.3512518e+02]\n" ] } ], "source": [ "loss.backward(retain_graph=False)\n", "for n, p in dp_model.named_parameters():\n", " print(\n", " f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")" ] }, { "cell_type": "code", "execution_count": 26, "id": "selected-crazy", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1.]\n" ] } ], "source": [ "print(loss.grad)" ] }, { "cell_type": "code", "execution_count": null, "id": "bottom-engineer", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "stuffed-yeast", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 5 }