You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/docs/topic/ctc/ctc_loss_compare.ipynb

517 lines
14 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "ff6ff1e0",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "33af5f76",
"metadata": {},
"outputs": [],
"source": [
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9b566b73",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fatal: destination path 'warp-ctc' already exists and is not an empty directory.\r\n"
]
}
],
"source": [
"!git clone https://github.com/SeanNaren/warp-ctc.git"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4a087a09",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n"
]
}
],
"source": [
"%cd warp-ctc"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f55dc29a",
"metadata": {},
"outputs": [],
"source": [
"mkdir -p build"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fe79f4cf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n"
]
}
],
"source": [
"cd build"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3d25c718",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-- cuda found TRUE\n",
"-- Building shared library with GPU support\n",
"-- Configuring done\n",
"-- Generating done\n",
"-- Build files have been written to: /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n"
]
}
],
"source": [
"!cmake .."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7a4238f1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 11%] \u001b[32m\u001b[1mLinking CXX shared library libwarpctc.so\u001b[0m\n",
"[ 33%] Built target warpctc\n",
"[ 44%] \u001b[32m\u001b[1mLinking CXX executable test_cpu\u001b[0m\n",
"[ 55%] \u001b[32m\u001b[1mLinking CXX executable test_gpu\u001b[0m\n",
"[ 77%] Built target test_cpu\n",
"[100%] Built target test_gpu\n"
]
}
],
"source": [
"!make -j"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "31761a31",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n"
]
}
],
"source": [
"cd .."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f53316f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding\n"
]
}
],
"source": [
"cd pytorch_binding"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "084f1e49",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"running install\n",
"running bdist_egg\n",
"running egg_info\n",
"writing warpctc_pytorch.egg-info/PKG-INFO\n",
"writing dependency_links to warpctc_pytorch.egg-info/dependency_links.txt\n",
"writing top-level names to warpctc_pytorch.egg-info/top_level.txt\n",
"writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n",
"installing library code to build/bdist.linux-x86_64/egg\n",
"running install_lib\n",
"running build_py\n",
"running build_ext\n",
"building 'warpctc_pytorch._warp_ctc' extension\n",
"Emitting ninja build file /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/build.ninja...\n",
"Compiling objects...\n",
"Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n",
"ninja: no work to do.\n",
"g++ -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -shared -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -L/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/lib -L/usr/local/cuda/lib64 -lwarpctc -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n",
"creating build/bdist.linux-x86_64/egg\n",
"creating build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
"copying build/lib.linux-x86_64-3.9/warpctc_pytorch/__init__.py -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
"copying build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
"byte-compiling build/bdist.linux-x86_64/egg/warpctc_pytorch/__init__.py to __init__.cpython-39.pyc\n",
"creating stub loader for warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so\n",
"byte-compiling build/bdist.linux-x86_64/egg/warpctc_pytorch/_warp_ctc.py to _warp_ctc.cpython-39.pyc\n",
"creating build/bdist.linux-x86_64/egg/EGG-INFO\n",
"copying warpctc_pytorch.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
"copying warpctc_pytorch.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
"copying warpctc_pytorch.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
"copying warpctc_pytorch.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
"writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt\n",
"zip_safe flag not set; analyzing archive contents...\n",
"warpctc_pytorch.__pycache__._warp_ctc.cpython-39: module references __file__\n",
"creating 'dist/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it\n",
"removing 'build/bdist.linux-x86_64/egg' (and everything under it)\n",
"Processing warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
"removing '/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' (and everything under it)\n",
"creating /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
"Extracting warpctc_pytorch-0.1-py3.9-linux-x86_64.egg to /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages\n",
"warpctc-pytorch 0.1 is already the active version in easy-install.pth\n",
"\n",
"Installed /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
"Processing dependencies for warpctc-pytorch==0.1\n",
"Finished processing dependencies for warpctc-pytorch==0.1\n"
]
}
],
"source": [
"!python setup.py install"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ee4ca9e3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Python 3.9.5\r\n"
]
}
],
"source": [
"!python -V"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "59255ed8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n"
]
}
],
"source": [
"cd .."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "1dae09b9",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import warpctc_pytorch as wp\n",
"import paddle.nn as pn\n",
"import paddle"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "83d0762e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.10.0+cu102'"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.__version__"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "62501e2c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'2.2.1'"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"paddle.__version__"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "9e8e0f40",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([2, 1, 5])\n",
"2.4628584384918213\n",
"[[[ 0.17703122 -0.70812464 0.17703122 0.17703122 0.17703122]]\n",
"\n",
" [[ 0.17703122 0.17703122 -0.70812464 0.17703122 0.17703122]]]\n"
]
}
],
"source": [
"# warpctc_pytorch CTCLoss\n",
"probs = torch.FloatTensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
" ]]).transpose(0, 1).contiguous()\n",
"print(probs.size())\n",
"labels = torch.IntTensor([1, 2])\n",
"label_sizes = torch.IntTensor([2])\n",
"probs_sizes = torch.IntTensor([2])\n",
"probs.requires_grad_(True)\n",
"bs = probs.size(1)\n",
"\n",
"ctc_loss = wp.CTCLoss(size_average=False, length_average=False)\n",
"cost = ctc_loss(probs, labels, probs_sizes, label_sizes)\n",
"cost = cost.sum() / bs\n",
"print(cost.item())\n",
"cost.backward()\n",
"print(probs.grad.numpy())"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "2cd46569",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.4628584384918213\n",
"[[[ 0.1770312 -0.7081248 0.1770312 0.1770312 0.1770312]]\n",
"\n",
" [[ 0.1770312 0.1770312 -0.7081248 0.1770312 0.1770312]]]\n"
]
}
],
"source": [
"# pytorch CTCLoss\n",
"probs = torch.FloatTensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
" ]]).transpose(0, 1).contiguous()\n",
"labels = torch.IntTensor([1, 2])\n",
"label_sizes = torch.IntTensor([2])\n",
"probs_sizes = torch.IntTensor([2])\n",
"probs.requires_grad_(True)\n",
"bs = probs.size(1)\n",
"\n",
"log_probs = torch.log_softmax(probs, axis=-1)\n",
"\n",
"ctc_loss1 = nn.CTCLoss(reduction='none')\n",
"cost = ctc_loss1(log_probs, labels, probs_sizes, label_sizes)\n",
"cost = cost.sum() / bs\n",
"print(cost.item())\n",
"cost.backward()\n",
"print(probs.grad.numpy())"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "85c3461a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2, 1, 5]\n",
"[1, 2]\n",
"2.4628584384918213\n",
"[[[ 0.17703122 -0.70812464 0.17703122 0.17703122 0.17703122]]\n",
"\n",
" [[ 0.17703122 0.17703122 -0.70812464 0.17703122 0.17703122]]]\n"
]
}
],
"source": [
"# Paddle CTCLoss\n",
"paddle.set_device('cpu')\n",
"probs = paddle.to_tensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1],\n",
" ]]).transpose([1,0,2])\n",
"print(probs.shape) # (T, B, D)\n",
"labels = paddle.to_tensor([[1, 2]], dtype='int32') # (B, L)\n",
"print(labels.shape)\n",
"label_sizes = paddle.to_tensor([2], dtype='int64')\n",
"probs_sizes = paddle.to_tensor([2], dtype='int64')\n",
"bs = paddle.shape(probs)[1]\n",
"probs.stop_gradient=False\n",
"\n",
"ctc_loss = pn.CTCLoss(reduction='none')\n",
"cost = ctc_loss(probs, labels, probs_sizes, label_sizes)\n",
"cost = cost.sum() / bs\n",
"print(cost.item())\n",
"cost.backward()\n",
"print(probs.grad.numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8cdf76c2",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 26,
"id": "2c305eaf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([2, 1, 5])\n",
"2.4628584384918213\n",
"[[[ 0.17703117 -0.7081247 0.17703117 0.17703117 0.17703117]]\n",
"\n",
" [[ 0.17703117 0.17703117 -0.7081247 0.17703117 0.17703117]]]\n"
]
}
],
"source": [
"# warpctc_pytorch CTCLoss fed log_softmax(probs): log_softmax is idempotent, so the loss matches the raw-probs warpctc run\n",
"probs = torch.FloatTensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
" ]]).transpose(0, 1).contiguous()\n",
"print(probs.size())\n",
"labels = torch.IntTensor([1, 2])\n",
"label_sizes = torch.IntTensor([2])\n",
"probs_sizes = torch.IntTensor([2])\n",
"probs.requires_grad_(True)\n",
"bs = probs.size(1)\n",
"\n",
"ctc_loss = wp.CTCLoss(size_average=False, length_average=False)\n",
"\n",
"log_probs = torch.log_softmax(probs, axis=-1)\n",
"cost = ctc_loss(log_probs, labels, probs_sizes, label_sizes)\n",
"cost = cost.sum() / bs\n",
"print(cost.item())\n",
"cost.backward()\n",
"print(probs.grad.numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "443336f0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}