{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ff6ff1e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load IPython's autoreload extension (the reload mode is selected in the next cell).\n",
    "%load_ext autoreload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "33af5f76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mode 2: reload all imported modules before each cell executes, so edits to\n",
    "# source files are picked up without restarting the kernel.\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9b566b73",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fatal: destination path 'warp-ctc' already exists and is not an empty directory.\r\n"
     ]
    }
   ],
   "source": [
    "# Fetch the warp-ctc sources. Clone only when the checkout is missing, so that\n",
    "# re-running this cell does not fail because the destination already exists.\n",
    "![ -d warp-ctc ] || git clone https://github.com/SeanNaren/warp-ctc.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4a087a09",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n"
     ]
    }
   ],
   "source": [
    "# Enter the checkout; the build cells below use paths relative to it.\n",
    "%cd warp-ctc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f55dc29a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the out-of-source CMake build directory (no error if it already exists).\n",
    "# Explicit shell escape, like the other shell commands in this notebook, instead\n",
    "# of relying on IPython's automagic alias for mkdir.\n",
    "!mkdir -p build"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fe79f4cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n"
     ]
    }
   ],
   "source": [
    "# Explicit %cd magic: a bare `cd` only works while automagic is enabled and no\n",
    "# variable named `cd` shadows it.\n",
    "%cd build"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3d25c718",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-- cuda found TRUE\n",
      "-- Building shared library with GPU support\n",
      "-- Configuring done\n",
      "-- Generating done\n",
      "-- Build files have been written to: /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n"
     ]
    }
   ],
   "source": [
    "# Configure the build from the parent directory's CMakeLists. In the recorded run\n",
    "# CUDA was detected, so the shared library is configured with GPU support.\n",
    "!cmake .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7a4238f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 11%] \u001b[32m\u001b[1mLinking CXX shared library libwarpctc.so\u001b[0m\n",
      "[ 33%] Built target warpctc\n",
      "[ 44%] \u001b[32m\u001b[1mLinking CXX executable test_cpu\u001b[0m\n",
      "[ 55%] \u001b[32m\u001b[1mLinking CXX executable test_gpu\u001b[0m\n",
      "[ 77%] Built target test_cpu\n",
      "[100%] Built target test_gpu\n"
     ]
    }
   ],
   "source": [
    "# Build libwarpctc and its test binaries. A bare `-j` places no limit on the\n",
    "# number of parallel jobs, so cap it at the number of available CPUs.\n",
    "!make -j$(nproc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "31761a31",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n"
     ]
    }
   ],
   "source": [
    "# Back to the warp-ctc checkout root (explicit magic rather than automagic).\n",
    "%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f53316f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding\n"
     ]
    }
   ],
   "source": [
    "# Enter the PyTorch binding sources (explicit magic rather than automagic).\n",
    "%cd pytorch_binding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "084f1e49",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "running install\n",
      "running bdist_egg\n",
      "running egg_info\n",
      "writing warpctc_pytorch.egg-info/PKG-INFO\n",
      "writing dependency_links to warpctc_pytorch.egg-info/dependency_links.txt\n",
      "writing top-level names to warpctc_pytorch.egg-info/top_level.txt\n",
      "writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n",
      "installing library code to build/bdist.linux-x86_64/egg\n",
      "running install_lib\n",
      "running build_py\n",
      "running build_ext\n",
      "building 'warpctc_pytorch._warp_ctc' extension\n",
      "Emitting ninja build file /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/build.ninja...\n",
      "Compiling objects...\n",
      "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n",
      "ninja: no work to do.\n",
      "g++ -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -shared -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -L/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/lib -L/usr/local/cuda/lib64 -lwarpctc -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n",
      "creating build/bdist.linux-x86_64/egg\n",
      "creating build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
      "copying build/lib.linux-x86_64-3.9/warpctc_pytorch/__init__.py -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
      "copying build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
      "byte-compiling build/bdist.linux-x86_64/egg/warpctc_pytorch/__init__.py to __init__.cpython-39.pyc\n",
      "creating stub loader for warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so\n",
      "byte-compiling build/bdist.linux-x86_64/egg/warpctc_pytorch/_warp_ctc.py to _warp_ctc.cpython-39.pyc\n",
      "creating build/bdist.linux-x86_64/egg/EGG-INFO\n",
      "copying warpctc_pytorch.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
      "copying warpctc_pytorch.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
      "copying warpctc_pytorch.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
      "copying warpctc_pytorch.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
      "writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt\n",
      "zip_safe flag not set; analyzing archive contents...\n",
      "warpctc_pytorch.__pycache__._warp_ctc.cpython-39: module references __file__\n",
      "creating 'dist/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it\n",
      "removing 'build/bdist.linux-x86_64/egg' (and everything under it)\n",
      "Processing warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
      "removing '/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' (and everything under it)\n",
      "creating /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
      "Extracting warpctc_pytorch-0.1-py3.9-linux-x86_64.egg to /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages\n",
      "warpctc-pytorch 0.1 is already the active version in easy-install.pth\n",
      "\n",
      "Installed /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
      "Processing dependencies for warpctc-pytorch==0.1\n",
      "Finished processing dependencies for warpctc-pytorch==0.1\n"
     ]
    }
   ],
   "source": [
    "# Build and install the warpctc_pytorch extension into the active environment.\n",
    "# The recorded link line pulls in -lwarpctc from the warp-ctc/build directory, so\n",
    "# the CMake build above has to be finished before this cell is run.\n",
    "!python setup.py install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ee4ca9e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Python 3.9.5\r\n"
     ]
    }
   ],
   "source": [
    "# Report the version of the `python` executable that the shell resolves, i.e. the\n",
    "# one the setup.py step above was invoked with.\n",
    "!python -V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "59255ed8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n"
     ]
    }
   ],
   "source": [
    "# Leave pytorch_binding and return to the checkout root (explicit magic rather\n",
    "# than automagic).\n",
    "%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "1dae09b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import warpctc_pytorch as wp\n",
    "import paddle.nn as pn\n",
    "import paddle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "83d0762e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.10.0+cu102'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Record the PyTorch version the numbers below were produced with.\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "62501e2c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.2.1'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Record the PaddlePaddle version the numbers below were produced with.\n",
    "paddle.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "9e8e0f40",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 1, 5])\n",
      "2.4628584384918213\n",
      "[[[ 0.17703122 -0.70812464  0.17703122  0.17703122  0.17703122]]\n",
      "\n",
      " [[ 0.17703122  0.17703122 -0.70812464  0.17703122  0.17703122]]]\n"
     ]
    }
   ],
   "source": [
    "# Reference run: CTCLoss from the warpctc_pytorch binding on one utterance with\n",
    "# two frames and five classes.\n",
    "frame_scores = [\n",
    "    [0.1, 0.6, 0.1, 0.1, 0.1],\n",
    "    [0.1, 0.1, 0.6, 0.1, 0.1],\n",
    "]\n",
    "# Written batch-first, then the first two axes are swapped to get (T, B, D).\n",
    "probs = torch.FloatTensor([frame_scores]).transpose(0, 1).contiguous()\n",
    "probs.requires_grad_(True)\n",
    "print(probs.size())\n",
    "\n",
    "labels = torch.IntTensor([1, 2])\n",
    "probs_sizes = torch.IntTensor([2])\n",
    "label_sizes = torch.IntTensor([2])\n",
    "bs = probs.size(1)\n",
    "\n",
    "ctc_loss = wp.CTCLoss(size_average=False, length_average=False)\n",
    "per_utterance = ctc_loss(probs, labels, probs_sizes, label_sizes)\n",
    "# Both built-in averaging modes are switched off, so average over the batch here.\n",
    "cost = per_utterance.sum() / bs\n",
    "print(cost.item())\n",
    "\n",
    "cost.backward()\n",
    "print(probs.grad.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "2cd46569",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.4628584384918213\n",
      "[[[ 0.1770312 -0.7081248  0.1770312  0.1770312  0.1770312]]\n",
      "\n",
      " [[ 0.1770312  0.1770312 -0.7081248  0.1770312  0.1770312]]]\n"
     ]
    }
   ],
   "source": [
    "# torch.nn.CTCLoss on the same two-frame example, for comparison with warp-ctc.\n",
    "probs = torch.FloatTensor([[\n",
    "        [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
    "    ]]).transpose(0, 1).contiguous()\n",
    "labels = torch.IntTensor([1, 2])\n",
    "label_sizes = torch.IntTensor([2])\n",
    "probs_sizes = torch.IntTensor([2])\n",
    "probs.requires_grad_(True)\n",
    "bs = probs.size(1)\n",
    "\n",
    "# nn.CTCLoss takes log-probabilities, so normalise explicitly. `dim` is the\n",
    "# documented keyword of torch.log_softmax; `axis` is only a NumPy-style alias.\n",
    "log_probs = torch.log_softmax(probs, dim=-1)\n",
    "\n",
    "ctc_loss1 = nn.CTCLoss(reduction='none')\n",
    "cost = ctc_loss1(log_probs, labels, probs_sizes, label_sizes)\n",
    "# reduction='none' returns one loss per sequence; average over the batch here.\n",
    "cost = cost.sum() / bs\n",
    "print(cost.item())\n",
    "cost.backward()\n",
    "# Gradient with respect to the un-normalised inputs, not to log_probs.\n",
    "print(probs.grad.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "85c3461a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 1, 5]\n",
      "[1, 2]\n",
      "2.4628584384918213\n",
      "[[[ 0.17703122 -0.70812464  0.17703122  0.17703122  0.17703122]]\n",
      "\n",
      " [[ 0.17703122  0.17703122 -0.70812464  0.17703122  0.17703122]]]\n"
     ]
    }
   ],
   "source": [
    "# paddle.nn.CTCLoss on the same two-frame example, run on CPU.\n",
    "paddle.set_device('cpu')\n",
    "probs = paddle.to_tensor([[\n",
    "        [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1],\n",
    "    ]]).transpose([1,0,2])\n",
    "print(probs.shape) # (T, B, D)\n",
    "labels = paddle.to_tensor([[1, 2]], dtype='int32') #(B,L)\n",
    "print(labels.shape)\n",
    "label_sizes = paddle.to_tensor([2], dtype='int64')\n",
    "probs_sizes = paddle.to_tensor([2], dtype='int64')\n",
    "# Batch size as a plain Python int (probs.shape is a list, as printed above),\n",
    "# mirroring probs.size(1) in the torch cells. paddle.shape() would give an integer\n",
    "# tensor and turn the averaging below into a mixed-dtype tensor division.\n",
    "bs = probs.shape[1]\n",
    "# Track gradients for probs so that probs.grad is populated by backward().\n",
    "probs.stop_gradient = False\n",
    "\n",
    "ctc_loss = pn.CTCLoss(reduction='none')\n",
    "cost = ctc_loss(probs, labels, probs_sizes, label_sizes)\n",
    "# reduction='none' returns one loss per sequence; average over the batch here.\n",
    "cost = cost.sum() / bs\n",
    "print(cost.item())\n",
    "cost.backward()\n",
    "print(probs.grad.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cdf76c2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "2c305eaf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 1, 5])\n",
      "2.4628584384918213\n",
      "[[[ 0.17703117 -0.7081247   0.17703117  0.17703117  0.17703117]]\n",
      "\n",
      " [[ 0.17703117  0.17703117 -0.7081247   0.17703117  0.17703117]]]\n"
     ]
    }
   ],
   "source": [
    "# warpctc_pytorch CTCLoss fed with log_softmax(probs) instead of probs.\n",
    "# Loss and gradient come out the same as in the plain warpctc_pytorch cell above\n",
    "# (up to float rounding). That is the expected result if the binding normalises\n",
    "# its input with a softmax, because softmax(log_softmax(x)) == softmax(x).\n",
    "probs = torch.FloatTensor([[\n",
    "        [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
    "    ]]).transpose(0, 1).contiguous()\n",
    "print(probs.size())\n",
    "labels = torch.IntTensor([1, 2])\n",
    "label_sizes = torch.IntTensor([2])\n",
    "probs_sizes = torch.IntTensor([2])\n",
    "probs.requires_grad_(True)\n",
    "bs = probs.size(1)\n",
    "\n",
    "ctc_loss = wp.CTCLoss(size_average=False, length_average=False)\n",
    "\n",
    "# `dim` is the documented keyword of torch.log_softmax; `axis` is only a\n",
    "# NumPy-style alias.\n",
    "log_probs = torch.log_softmax(probs, dim=-1)\n",
    "cost = ctc_loss(log_probs, labels, probs_sizes, label_sizes)\n",
    "cost = cost.sum() / bs\n",
    "print(cost.item())\n",
    "cost.backward()\n",
    "# Gradient with respect to the original inputs, not to log_probs.\n",
    "print(probs.grad.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "443336f0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}