|
|
|
@ -0,0 +1,520 @@
|
|
|
|
|
{
|
|
|
|
|
"cells": [
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 1,
|
|
|
|
|
"id": "ff6ff1e0",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"%load_ext autoreload"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 2,
|
|
|
|
|
"id": "33af5f76",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"%autoreload 2"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 3,
|
|
|
|
|
"id": "9b566b73",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"Cloning into 'warp-ctc'...\n",
|
|
|
|
|
"remote: Enumerating objects: 829, done.\u001b[K\n",
|
|
|
|
|
"remote: Total 829 (delta 0), reused 0 (delta 0), pack-reused 829\u001b[K\n",
|
|
|
|
|
"Receiving objects: 100% (829/829), 388.85 KiB | 140.00 KiB/s, done.\n",
|
|
|
|
|
"Resolving deltas: 100% (419/419), done.\n",
|
|
|
|
|
"Checking connectivity... done.\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"!git clone https://github.com/SeanNaren/warp-ctc.git"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 4,
|
|
|
|
|
"id": "4a087a09",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"%cd warp-ctc"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 5,
|
|
|
|
|
"id": "f55dc29a",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"mkdir -p build"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 6,
|
|
|
|
|
"id": "fe79f4cf",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"cd build"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 7,
|
|
|
|
|
"id": "3d25c718",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"-- The C compiler identification is GNU 5.4.0\n",
|
|
|
|
|
"-- The CXX compiler identification is GNU 5.4.0\n",
|
|
|
|
|
"-- Check for working C compiler: /usr/bin/cc\n",
|
|
|
|
|
"-- Check for working C compiler: /usr/bin/cc -- works\n",
|
|
|
|
|
"-- Detecting C compiler ABI info\n",
|
|
|
|
|
"-- Detecting C compiler ABI info - done\n",
|
|
|
|
|
"-- Detecting C compile features\n",
|
|
|
|
|
"-- Detecting C compile features - done\n",
|
|
|
|
|
"-- Check for working CXX compiler: /usr/bin/c++\n",
|
|
|
|
|
"-- Check for working CXX compiler: /usr/bin/c++ -- works\n",
|
|
|
|
|
"-- Detecting CXX compiler ABI info\n",
|
|
|
|
|
"-- Detecting CXX compiler ABI info - done\n",
|
|
|
|
|
"-- Detecting CXX compile features\n",
|
|
|
|
|
"-- Detecting CXX compile features - done\n",
|
|
|
|
|
"-- Looking for pthread.h\n",
|
|
|
|
|
"-- Looking for pthread.h - found\n",
|
|
|
|
|
"-- Performing Test CMAKE_HAVE_LIBC_PTHREAD\n",
|
|
|
|
|
"-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed\n",
|
|
|
|
|
"-- Looking for pthread_create in pthreads\n",
|
|
|
|
|
"-- Looking for pthread_create in pthreads - not found\n",
|
|
|
|
|
"-- Looking for pthread_create in pthread\n",
|
|
|
|
|
"-- Looking for pthread_create in pthread - found\n",
|
|
|
|
|
"-- Found Threads: TRUE \n",
|
|
|
|
|
"-- Found CUDA: /usr/local/cuda (found suitable version \"10.2\", minimum required is \"6.5\") \n",
|
|
|
|
|
"-- cuda found TRUE\n",
|
|
|
|
|
"-- Building shared library with GPU support\n",
|
|
|
|
|
"-- Configuring done\n",
|
|
|
|
|
"-- Generating done\n",
|
|
|
|
|
"-- Build files have been written to: /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"!cmake .."
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 8,
|
|
|
|
|
"id": "7a4238f1",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"[ 11%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_reduce.cu.o\u001b[0m\n",
|
|
|
|
|
"[ 22%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_ctc_entrypoint.cu.o\u001b[0m\n",
|
|
|
|
|
"\u001b[35m\u001b[1mScanning dependencies of target warpctc\u001b[0m\n",
|
|
|
|
|
"[ 33%] \u001b[32m\u001b[1mLinking CXX shared library libwarpctc.so\u001b[0m\n",
|
|
|
|
|
"[ 33%] Built target warpctc\n",
|
|
|
|
|
"[ 44%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/test_gpu.dir/tests/test_gpu_generated_test_gpu.cu.o\u001b[0m\n",
|
|
|
|
|
"\u001b[35m\u001b[1mScanning dependencies of target test_cpu\u001b[0m\n",
|
|
|
|
|
"[ 55%] \u001b[32mBuilding CXX object CMakeFiles/test_cpu.dir/tests/test_cpu.cpp.o\u001b[0m\n",
|
|
|
|
|
"[ 66%] \u001b[32mBuilding CXX object CMakeFiles/test_cpu.dir/tests/random.cpp.o\u001b[0m\n",
|
|
|
|
|
"[ 77%] \u001b[32m\u001b[1mLinking CXX executable test_cpu\u001b[0m\n",
|
|
|
|
|
"[ 77%] Built target test_cpu\n",
|
|
|
|
|
"\u001b[35m\u001b[1mScanning dependencies of target test_gpu\u001b[0m\n",
|
|
|
|
|
"[ 88%] \u001b[32mBuilding CXX object CMakeFiles/test_gpu.dir/tests/random.cpp.o\u001b[0m\n",
|
|
|
|
|
"[100%] \u001b[32m\u001b[1mLinking CXX executable test_gpu\u001b[0m\n",
|
|
|
|
|
"[100%] Built target test_gpu\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"!make -j"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 10,
|
|
|
|
|
"id": "31761a31",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"cd .."
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 11,
|
|
|
|
|
"id": "f53316f6",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"cd pytorch_binding"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 12,
|
|
|
|
|
"id": "084f1e49",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"running install\n",
|
|
|
|
|
"running bdist_egg\n",
|
|
|
|
|
"running egg_info\n",
|
|
|
|
|
"creating warpctc_pytorch.egg-info\n",
|
|
|
|
|
"writing warpctc_pytorch.egg-info/PKG-INFO\n",
|
|
|
|
|
"writing dependency_links to warpctc_pytorch.egg-info/dependency_links.txt\n",
|
|
|
|
|
"writing top-level names to warpctc_pytorch.egg-info/top_level.txt\n",
|
|
|
|
|
"writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n",
|
|
|
|
|
"writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n",
|
|
|
|
|
"installing library code to build/bdist.linux-x86_64/egg\n",
|
|
|
|
|
"running install_lib\n",
|
|
|
|
|
"running build_py\n",
|
|
|
|
|
"creating build\n",
|
|
|
|
|
"creating build/lib.linux-x86_64-3.9\n",
|
|
|
|
|
"creating build/lib.linux-x86_64-3.9/warpctc_pytorch\n",
|
|
|
|
|
"copying warpctc_pytorch/__init__.py -> build/lib.linux-x86_64-3.9/warpctc_pytorch\n",
|
|
|
|
|
"running build_ext\n",
|
|
|
|
|
"building 'warpctc_pytorch._warp_ctc' extension\n",
|
|
|
|
|
"creating /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9\n",
|
|
|
|
|
"creating /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src\n",
|
|
|
|
|
"Emitting ninja build file /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/build.ninja...\n",
|
|
|
|
|
"Compiling objects...\n",
|
|
|
|
|
"Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n",
|
|
|
|
|
"[1/1] c++ -MMD -MF /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o.d -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /workspace/zhanghui/DeepSpeech-2.x/tools/venv/include -fPIC -O2 -isystem /workspace/zhanghui/DeepSpeech-2.x/tools/venv/include -fPIC -I/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/TH -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/include/python3.9 -c -c /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/src/binding.cpp -o /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -std=c++14 -fPIC -DWARPCTC_ENABLE_GPU -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE=\"_gcc\"' '-DPYBIND11_STDLIB=\"_libstdcpp\"' '-DPYBIND11_BUILD_ABI=\"_cxxabi1011\"' -DTORCH_EXTENSION_NAME=_warp_ctc -D_GLIBCXX_USE_CXX11_ABI=0\n",
|
|
|
|
|
"g++ -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -shared -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -L/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/lib -L/usr/local/cuda/lib64 -lwarpctc -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n",
|
|
|
|
|
"creating build/bdist.linux-x86_64\n",
|
|
|
|
|
"creating build/bdist.linux-x86_64/egg\n",
|
|
|
|
|
"creating build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
|
|
|
|
|
"copying build/lib.linux-x86_64-3.9/warpctc_pytorch/__init__.py -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
|
|
|
|
|
"copying build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
|
|
|
|
|
"byte-compiling build/bdist.linux-x86_64/egg/warpctc_pytorch/__init__.py to __init__.cpython-39.pyc\n",
|
|
|
|
|
"creating stub loader for warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so\n",
|
|
|
|
|
"byte-compiling build/bdist.linux-x86_64/egg/warpctc_pytorch/_warp_ctc.py to _warp_ctc.cpython-39.pyc\n",
|
|
|
|
|
"creating build/bdist.linux-x86_64/egg/EGG-INFO\n",
|
|
|
|
|
"copying warpctc_pytorch.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
|
|
|
|
|
"copying warpctc_pytorch.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
|
|
|
|
|
"copying warpctc_pytorch.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
|
|
|
|
|
"copying warpctc_pytorch.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n",
|
|
|
|
|
"writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt\n",
|
|
|
|
|
"zip_safe flag not set; analyzing archive contents...\n",
|
|
|
|
|
"warpctc_pytorch.__pycache__._warp_ctc.cpython-39: module references __file__\n",
|
|
|
|
|
"creating dist\n",
|
|
|
|
|
"creating 'dist/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it\n",
|
|
|
|
|
"removing 'build/bdist.linux-x86_64/egg' (and everything under it)\n",
|
|
|
|
|
"Processing warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
|
|
|
|
|
"removing '/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' (and everything under it)\n",
|
|
|
|
|
"creating /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
|
|
|
|
|
"Extracting warpctc_pytorch-0.1-py3.9-linux-x86_64.egg to /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages\n",
|
|
|
|
|
"warpctc-pytorch 0.1 is already the active version in easy-install.pth\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"Installed /workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
|
|
|
|
|
"Processing dependencies for warpctc-pytorch==0.1\n",
|
|
|
|
|
"Finished processing dependencies for warpctc-pytorch==0.1\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"!python setup.py install"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 13,
|
|
|
|
|
"id": "ee4ca9e3",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"Python 3.9.5\r\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"!python -V"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 14,
|
|
|
|
|
"id": "59255ed8",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"cd .."
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 15,
|
|
|
|
|
"id": "1dae09b9",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stderr",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"grep: warning: GREP_OPTIONS is deprecated; please use an alias or script\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"import torch\n",
|
|
|
|
|
"import torch.nn as nn\n",
|
|
|
|
|
"import warpctc_pytorch as wp\n",
|
|
|
|
|
"import paddle.nn as pn\n",
|
|
|
|
|
"import paddle"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 16,
|
|
|
|
|
"id": "83d0762e",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"'1.10.0+cu102'"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"execution_count": 16,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"torch.__version__"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 17,
|
|
|
|
|
"id": "62501e2c",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"'2.2.0'"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"execution_count": 17,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"paddle.__version__"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 18,
|
|
|
|
|
"id": "9e8e0f40",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"torch.Size([2, 1, 5])\n",
|
|
|
|
|
"2.4628584384918213\n",
|
|
|
|
|
"[[[ 0.17703122 -0.70812464 0.17703122 0.17703122 0.17703122]]\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" [[ 0.17703122 0.17703122 -0.70812464 0.17703122 0.17703122]]]\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"probs = torch.FloatTensor([[\n",
|
|
|
|
|
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
|
|
|
|
|
" ]]).transpose(0, 1).contiguous()\n",
|
|
|
|
|
"print(probs.size())\n",
|
|
|
|
|
"labels = torch.IntTensor([1, 2])\n",
|
|
|
|
|
"label_sizes = torch.IntTensor([2])\n",
|
|
|
|
|
"probs_sizes = torch.IntTensor([2])\n",
|
|
|
|
|
"probs.requires_grad_(True)\n",
|
|
|
|
|
"bs = probs.size(1)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"ctc_loss = wp.CTCLoss(size_average=False, length_average=False)\n",
|
|
|
|
|
"cost = ctc_loss(probs, labels, probs_sizes, label_sizes)\n",
|
|
|
|
|
"cost = cost.sum() / bs\n",
|
|
|
|
|
"print(cost.item())\n",
|
|
|
|
|
"cost.backward()\n",
|
|
|
|
|
"print(probs.grad.numpy())"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 19,
|
|
|
|
|
"id": "2cd46569",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"2.4628584384918213\n",
|
|
|
|
|
"[[[ 0.1770312 -0.7081248 0.1770312 0.1770312 0.1770312]]\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" [[ 0.1770312 0.1770312 -0.7081248 0.1770312 0.1770312]]]\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"probs = torch.FloatTensor([[\n",
|
|
|
|
|
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
|
|
|
|
|
" ]]).transpose(0, 1).contiguous()\n",
|
|
|
|
|
"labels = torch.IntTensor([1, 2])\n",
|
|
|
|
|
"label_sizes = torch.IntTensor([2])\n",
|
|
|
|
|
"probs_sizes = torch.IntTensor([2])\n",
|
|
|
|
|
"probs.requires_grad_(True)\n",
|
|
|
|
|
"bs = probs.size(1)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"log_probs = torch.log_softmax(probs, axis=-1)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"ctc_loss1 = nn.CTCLoss(reduction='none')\n",
|
|
|
|
|
"cost = ctc_loss1(log_probs, labels, probs_sizes, label_sizes)\n",
|
|
|
|
|
"cost = cost.sum() / bs\n",
|
|
|
|
|
"print(cost.item())\n",
|
|
|
|
|
"cost.backward()\n",
|
|
|
|
|
"print(probs.grad.numpy())"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 22,
|
|
|
|
|
"id": "85c3461a",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"[2, 1, 5]\n",
|
|
|
|
|
"[1, 2]\n",
|
|
|
|
|
"2.4628584384918213\n",
|
|
|
|
|
"[[[ 0.17703122 -0.70812464 0.17703122 0.17703122 0.17703122]]\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" [[ 0.17703122 0.17703122 -0.70812464 0.17703122 0.17703122]]]\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"paddle.set_device('cpu')\n",
|
|
|
|
|
"probs = paddle.to_tensor([[\n",
|
|
|
|
|
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1],\n",
|
|
|
|
|
" ]]).transpose([1,0,2])\n",
|
|
|
|
|
"print(probs.shape) # (T, B, D)\n",
|
|
|
|
|
"labels = paddle.to_tensor([[1, 2]], dtype='int32') #(B,L)\n",
|
|
|
|
|
"print(labels.shape)\n",
|
|
|
|
|
"label_sizes = paddle.to_tensor([2], dtype='int64')\n",
|
|
|
|
|
"probs_sizes = paddle.to_tensor([2], dtype='int64')\n",
|
|
|
|
|
"bs = paddle.shape(probs)[1]\n",
|
|
|
|
|
"probs.stop_gradient=False\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"ctc_loss = pn.CTCLoss(reduction='none')\n",
|
|
|
|
|
"cost = ctc_loss(probs, labels, probs_sizes, label_sizes)\n",
|
|
|
|
|
"cost = cost.sum() / bs\n",
|
|
|
|
|
"print(cost.item())\n",
|
|
|
|
|
"cost.backward()\n",
|
|
|
|
|
"print(probs.grad.numpy())"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "d390cd91",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": []
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"metadata": {
|
|
|
|
|
"kernelspec": {
|
|
|
|
|
"display_name": "Python 3 (ipykernel)",
|
|
|
|
|
"language": "python",
|
|
|
|
|
"name": "python3"
|
|
|
|
|
},
|
|
|
|
|
"language_info": {
|
|
|
|
|
"codemirror_mode": {
|
|
|
|
|
"name": "ipython",
|
|
|
|
|
"version": 3
|
|
|
|
|
},
|
|
|
|
|
"file_extension": ".py",
|
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
|
"name": "python",
|
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
|
"version": "3.9.5"
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"nbformat": 4,
|
|
|
|
|
"nbformat_minor": 5
|
|
|
|
|
}
|