{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "academic-surname",
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "from paddle import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fundamental-treasure",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paddle feed-forward pair under test (256 -> 2048 -> 256). Both layers get\n",
    "# their weight and bias overwritten with the reference values further down,\n",
    "# so the random init here does not affect the comparison.\n",
    "L = nn.Linear(256, 2048)\n",
    "L2 = nn.Linear(2048, 256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "consolidated-elephant",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "moderate-noise",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the global NumPy RNG so this input (and the later np.random call) is\n",
    "# reproducible under Restart & Run All.\n",
    "np.random.seed(0)\n",
    "x = np.random.randn(2, 51, 256)  # float64\n",
    "\n",
    "# Feed both frameworks the same float32 copy of x.\n",
    "px = paddle.to_tensor(x, dtype='float32')\n",
    "tx = torch.tensor(x, dtype=torch.float32)\n",
    "np.testing.assert_allclose(px.numpy(), tx.numpy())\n",
    "print(x.dtype, px.shape, tuple(tx.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "mechanical-prisoner",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference tensors (t_* = presumably dumped from the torch implementation).\n",
    "# allow_pickle=True is presumably needed because 'ps' holds differently shaped\n",
    "# arrays (an object array); unpickling can execute arbitrary code, so only load\n",
    "# trusted files. The `with` block closes the NpzFile handle once all arrays are read.\n",
    "with np.load('enc_0_ff_out.npz', allow_pickle=True) as data:\n",
    "    t_norm_ff = data['norm_ff']\n",
    "    t_ff_out = data['ff_out']\n",
    "    t_ff_l_x = data['ff_l_x']\n",
    "    t_ff_l_a_x = data['ff_l_a_x']\n",
    "    t_ff_l_a_l_x = data['ff_l_a_l_x']\n",
    "    t_ps = data['ps']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "assured-zambia",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n",
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "# Load the reference weights into the paddle layers. paddle.nn.Linear uses an\n",
    "# [in, out] weight layout, while the reference matrices are consumed as-is by\n",
    "# torch.nn.Linear below, hence the .T here.\n",
    "L.set_state_dict({'weight': t_ps[0].T, 'bias': t_ps[1]})\n",
    "L2.set_state_dict({'weight': t_ps[2].T, 'bias': t_ps[3]})\n",
    "\n",
    "# Read back weight/bias of L, then L2, the same way for both layers and check\n",
    "# them against the reference (.T is a no-op on the 1-D biases).\n",
    "ps = [p.numpy() for layer in (L, L2) for p in layer.state_dict().values()]\n",
    "for p, tp in zip(ps, t_ps):\n",
    "    print(np.allclose(p, tp.T))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "viral-indian",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n",
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "# Torch reference layers with the same weights. torch.nn.Linear takes the\n",
    "# reference matrices without a transpose.\n",
    "TL = torch.nn.Linear(256, 2048)\n",
    "TL2 = torch.nn.Linear(2048, 256)\n",
    "TL.load_state_dict({'weight': torch.tensor(t_ps[0]), 'bias': torch.tensor(t_ps[1])})\n",
    "TL2.load_state_dict({'weight': torch.tensor(t_ps[2]), 'bias': torch.tensor(t_ps[3])})\n",
    "\n",
    "# Read back weight/bias of TL, then TL2, and check they match the reference.\n",
    "# state_dict() tensors are already detached, so .numpy() works directly.\n",
    "ps = [p.numpy() for layer in (TL, TL2) for p in layer.state_dict().values()]\n",
    "for p, tp in zip(ps, t_ps):\n",
    "    print(np.allclose(p, tp))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "skilled-vietnamese",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward the same float32 input through the paddle layer and the torch reference.\n",
    "y = L(px)\n",
    "ty = TL(tx)\n",
    "y_np = y.numpy()\n",
    "ty_np = ty.detach().numpy()\n",
    "\n",
    "# Inputs must be identical for the output comparison to mean anything.\n",
    "print(np.allclose(px.numpy(), tx.detach().numpy()))\n",
    "\n",
    "# Outputs agree only up to float32 rounding noise (~1e-7 observed). For near-zero\n",
    "# entries the default tolerance (atol=1e-8 + rtol=1e-5 * |b|) is tighter than\n",
    "# that noise, so plain np.allclose reports False even though the layers agree.\n",
    "# Report the actual gap and compare with an explicit tolerance instead.\n",
    "print('max abs diff:', np.abs(y_np - ty_np).max())\n",
    "print(np.allclose(y_np, ty_np, atol=1e-5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "governmental-surge",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.04476918  0.554463   -0.3027508  ... -0.49600336  0.3751858\n",
      "   0.8254095 ]\n",
      " [ 0.95594174 -0.29528382 -1.2899452  ...  0.43718258  0.05584608\n",
      "  -0.06974669]]\n",
      "[[ 0.04476918  0.5544631  -0.3027507  ... -0.49600336  0.37518573\n",
      "   0.8254096 ]\n",
      " [ 0.95594174 -0.29528376 -1.2899454  ...  0.4371827   0.05584623\n",
      "  -0.0697467 ]]\n",
      "True\n",
      "False\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "# Same check on a 2-D (2, 256) input. Separate *_2d names keep the 3-D\n",
    "# x/px/tx/y/ty from the cells above intact if those cells are re-run.\n",
    "x_2d = np.random.randn(2, 256)\n",
    "px_2d = paddle.to_tensor(x_2d, dtype='float32')\n",
    "tx_2d = torch.tensor(x_2d, dtype=torch.float32)\n",
    "y_2d = L(px_2d)\n",
    "print(y_2d.numpy())\n",
    "ty_2d = TL(tx_2d)\n",
    "print(ty_2d.detach().numpy())\n",
    "print(np.allclose(px_2d.numpy(), tx_2d.detach().numpy()))\n",
    "# The default tolerance can be tighter than float32 rounding noise;\n",
    "# atol=1e-5 shows the two layers agree.\n",
    "print(np.allclose(y_2d.numpy(), ty_2d.detach().numpy()))\n",
    "print(np.allclose(y_2d.numpy(), ty_2d.detach().numpy(), atol=1e-5))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "env-provenance",
   "metadata": {},
   "source": [
    "## Environment\n",
    "\n",
    "The outputs saved in this notebook were recorded with paddle 2.1.0 (commit `5e7e7c9fde8350084abf1898cf52651cfc84b17a`) and torch 1.6.0 (git `b31f58de6fa8bbda5353b3c77d9be4914399724d`, CUDA 10.2)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "improved-civilization",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the framework builds so the comparison can be reproduced against the same versions.\n",
    "print('paddle', paddle.version.full_version, 'commit', paddle.version.commit)\n",
    "print('torch', torch.__version__, 'git', torch.version.git_version, 'cuda', torch.version.cuda)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}