Merge pull request #1483 from zh794390558/doc

[doc] update ctc loss compare
pull/1485/head
Hui Zhang 4 years ago committed by GitHub
commit a942226066
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -50,12 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i entry: bash .pre-commit-hooks/clang-format.hook -i
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker - id: copyright_checker
name: copyright_checker name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook entry: python .pre-commit-hooks/copyright-check.hook
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin).*(\.cpp|\.h|\.py)$ exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports - repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0 rev: v2.4.0
hooks: hooks:

@ -80,6 +80,7 @@ parser.add_argument(
args = parser.parse_args() args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix): def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix) print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = [] json_lines = []
@ -128,6 +129,7 @@ def create_manifest(data_dir, manifest_path_prefix):
print(f"{total_text / total_sec} text/sec", file=f) print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f) print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(base_url, data_list, target_dir, manifest_path, def prepare_dataset(base_url, data_list, target_dir, manifest_path,
target_data): target_data):
if not os.path.exists(target_dir): if not os.path.exists(target_dir):
@ -164,6 +166,7 @@ def prepare_dataset(base_url, data_list, target_dir, manifest_path,
# create the manifest file # create the manifest file
create_manifest(data_dir=target_dir, manifest_path_prefix=manifest_path) create_manifest(data_dir=target_dir, manifest_path_prefix=manifest_path)
def main(): def main():
if args.target_dir.startswith('~'): if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir) args.target_dir = os.path.expanduser(args.target_dir)
@ -184,5 +187,6 @@ def main():
print("Manifest prepare done!") print("Manifest prepare done!")
if __name__ == '__main__': if __name__ == '__main__':
main() main()

@ -30,12 +30,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Cloning into 'warp-ctc'...\n", "fatal: destination path 'warp-ctc' already exists and is not an empty directory.\r\n"
"remote: Enumerating objects: 829, done.\u001b[K\n",
"remote: Total 829 (delta 0), reused 0 (delta 0), pack-reused 829\u001b[K\n",
"Receiving objects: 100% (829/829), 388.85 KiB | 140.00 KiB/s, done.\n",
"Resolving deltas: 100% (419/419), done.\n",
"Checking connectivity... done.\n"
] ]
} }
], ],
@ -99,30 +94,6 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"-- The C compiler identification is GNU 5.4.0\n",
"-- The CXX compiler identification is GNU 5.4.0\n",
"-- Check for working C compiler: /usr/bin/cc\n",
"-- Check for working C compiler: /usr/bin/cc -- works\n",
"-- Detecting C compiler ABI info\n",
"-- Detecting C compiler ABI info - done\n",
"-- Detecting C compile features\n",
"-- Detecting C compile features - done\n",
"-- Check for working CXX compiler: /usr/bin/c++\n",
"-- Check for working CXX compiler: /usr/bin/c++ -- works\n",
"-- Detecting CXX compiler ABI info\n",
"-- Detecting CXX compiler ABI info - done\n",
"-- Detecting CXX compile features\n",
"-- Detecting CXX compile features - done\n",
"-- Looking for pthread.h\n",
"-- Looking for pthread.h - found\n",
"-- Performing Test CMAKE_HAVE_LIBC_PTHREAD\n",
"-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed\n",
"-- Looking for pthread_create in pthreads\n",
"-- Looking for pthread_create in pthreads - not found\n",
"-- Looking for pthread_create in pthread\n",
"-- Looking for pthread_create in pthread - found\n",
"-- Found Threads: TRUE \n",
"-- Found CUDA: /usr/local/cuda (found suitable version \"10.2\", minimum required is \"6.5\") \n",
"-- cuda found TRUE\n", "-- cuda found TRUE\n",
"-- Building shared library with GPU support\n", "-- Building shared library with GPU support\n",
"-- Configuring done\n", "-- Configuring done\n",
@ -145,20 +116,11 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[ 11%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_reduce.cu.o\u001b[0m\n", "[ 11%] \u001b[32m\u001b[1mLinking CXX shared library libwarpctc.so\u001b[0m\n",
"[ 22%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_ctc_entrypoint.cu.o\u001b[0m\n",
"\u001b[35m\u001b[1mScanning dependencies of target warpctc\u001b[0m\n",
"[ 33%] \u001b[32m\u001b[1mLinking CXX shared library libwarpctc.so\u001b[0m\n",
"[ 33%] Built target warpctc\n", "[ 33%] Built target warpctc\n",
"[ 44%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/test_gpu.dir/tests/test_gpu_generated_test_gpu.cu.o\u001b[0m\n", "[ 44%] \u001b[32m\u001b[1mLinking CXX executable test_cpu\u001b[0m\n",
"\u001b[35m\u001b[1mScanning dependencies of target test_cpu\u001b[0m\n", "[ 55%] \u001b[32m\u001b[1mLinking CXX executable test_gpu\u001b[0m\n",
"[ 55%] \u001b[32mBuilding CXX object CMakeFiles/test_cpu.dir/tests/test_cpu.cpp.o\u001b[0m\n",
"[ 66%] \u001b[32mBuilding CXX object CMakeFiles/test_cpu.dir/tests/random.cpp.o\u001b[0m\n",
"[ 77%] \u001b[32m\u001b[1mLinking CXX executable test_cpu\u001b[0m\n",
"[ 77%] Built target test_cpu\n", "[ 77%] Built target test_cpu\n",
"\u001b[35m\u001b[1mScanning dependencies of target test_gpu\u001b[0m\n",
"[ 88%] \u001b[32mBuilding CXX object CMakeFiles/test_gpu.dir/tests/random.cpp.o\u001b[0m\n",
"[100%] \u001b[32m\u001b[1mLinking CXX executable test_gpu\u001b[0m\n",
"[100%] Built target test_gpu\n" "[100%] Built target test_gpu\n"
] ]
} }
@ -169,7 +131,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 9,
"id": "31761a31", "id": "31761a31",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -187,7 +149,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 10,
"id": "f53316f6", "id": "f53316f6",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -205,7 +167,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 11,
"id": "084f1e49", "id": "084f1e49",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -216,29 +178,20 @@
"running install\n", "running install\n",
"running bdist_egg\n", "running bdist_egg\n",
"running egg_info\n", "running egg_info\n",
"creating warpctc_pytorch.egg-info\n",
"writing warpctc_pytorch.egg-info/PKG-INFO\n", "writing warpctc_pytorch.egg-info/PKG-INFO\n",
"writing dependency_links to warpctc_pytorch.egg-info/dependency_links.txt\n", "writing dependency_links to warpctc_pytorch.egg-info/dependency_links.txt\n",
"writing top-level names to warpctc_pytorch.egg-info/top_level.txt\n", "writing top-level names to warpctc_pytorch.egg-info/top_level.txt\n",
"writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n", "writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n",
"writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n",
"installing library code to build/bdist.linux-x86_64/egg\n", "installing library code to build/bdist.linux-x86_64/egg\n",
"running install_lib\n", "running install_lib\n",
"running build_py\n", "running build_py\n",
"creating build\n",
"creating build/lib.linux-x86_64-3.9\n",
"creating build/lib.linux-x86_64-3.9/warpctc_pytorch\n",
"copying warpctc_pytorch/__init__.py -> build/lib.linux-x86_64-3.9/warpctc_pytorch\n",
"running build_ext\n", "running build_ext\n",
"building 'warpctc_pytorch._warp_ctc' extension\n", "building 'warpctc_pytorch._warp_ctc' extension\n",
"creating /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9\n",
"creating /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src\n",
"Emitting ninja build file /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/build.ninja...\n", "Emitting ninja build file /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/build.ninja...\n",
"Compiling objects...\n", "Compiling objects...\n",
"Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n",
"[1/1] c++ -MMD -MF /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o.d -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /workspace/zhanghui/DeepSpeech-2.x/tools/venv/include -fPIC -O2 -isystem /workspace/zhanghui/DeepSpeech-2.x/tools/venv/include -fPIC -I/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/TH -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/include/python3.9 -c -c /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/src/binding.cpp -o /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -std=c++14 -fPIC -DWARPCTC_ENABLE_GPU -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE=\"_gcc\"' '-DPYBIND11_STDLIB=\"_libstdcpp\"' '-DPYBIND11_BUILD_ABI=\"_cxxabi1011\"' -DTORCH_EXTENSION_NAME=_warp_ctc -D_GLIBCXX_USE_CXX11_ABI=0\n", "ninja: no work to do.\n",
"g++ -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -shared -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -L/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/lib -L/usr/local/cuda/lib64 -lwarpctc -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n", "g++ -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -shared -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -L/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/lib -L/usr/local/cuda/lib64 -lwarpctc -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n",
"creating build/bdist.linux-x86_64\n",
"creating build/bdist.linux-x86_64/egg\n", "creating build/bdist.linux-x86_64/egg\n",
"creating build/bdist.linux-x86_64/egg/warpctc_pytorch\n", "creating build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
"copying build/lib.linux-x86_64-3.9/warpctc_pytorch/__init__.py -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n", "copying build/lib.linux-x86_64-3.9/warpctc_pytorch/__init__.py -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
@ -254,7 +207,6 @@
"writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt\n", "writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt\n",
"zip_safe flag not set; analyzing archive contents...\n", "zip_safe flag not set; analyzing archive contents...\n",
"warpctc_pytorch.__pycache__._warp_ctc.cpython-39: module references __file__\n", "warpctc_pytorch.__pycache__._warp_ctc.cpython-39: module references __file__\n",
"creating dist\n",
"creating 'dist/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it\n", "creating 'dist/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it\n",
"removing 'build/bdist.linux-x86_64/egg' (and everything under it)\n", "removing 'build/bdist.linux-x86_64/egg' (and everything under it)\n",
"Processing warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n", "Processing warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
@ -275,7 +227,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 12,
"id": "ee4ca9e3", "id": "ee4ca9e3",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -293,7 +245,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 13,
"id": "59255ed8", "id": "59255ed8",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -311,21 +263,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 22,
"id": "1dae09b9", "id": "1dae09b9",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stderr",
"output_type": "stream",
"text": [
"grep: warning: GREP_OPTIONS is deprecated; please use an alias or script\n"
]
}
],
"source": [ "source": [
"import torch\n", "import torch\n",
"import torch.nn as nn\n", "import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import warpctc_pytorch as wp\n", "import warpctc_pytorch as wp\n",
"import paddle.nn as pn\n", "import paddle.nn as pn\n",
"import paddle" "import paddle"
@ -333,7 +278,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 15,
"id": "83d0762e", "id": "83d0762e",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -343,7 +288,7 @@
"'1.10.0+cu102'" "'1.10.0+cu102'"
] ]
}, },
"execution_count": 16, "execution_count": 15,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -354,17 +299,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 16,
"id": "62501e2c", "id": "62501e2c",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'2.2.0'" "'2.2.1'"
] ]
}, },
"execution_count": 17, "execution_count": 16,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -375,7 +320,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 17,
"id": "9e8e0f40", "id": "9e8e0f40",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -392,6 +337,7 @@
} }
], ],
"source": [ "source": [
"# warpctc_pytorch CTCLoss\n",
"probs = torch.FloatTensor([[\n", "probs = torch.FloatTensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n", " [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
" ]]).transpose(0, 1).contiguous()\n", " ]]).transpose(0, 1).contiguous()\n",
@ -412,7 +358,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 18,
"id": "2cd46569", "id": "2cd46569",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -428,6 +374,7 @@
} }
], ],
"source": [ "source": [
"# pytorch CTCLoss\n",
"probs = torch.FloatTensor([[\n", "probs = torch.FloatTensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n", " [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
" ]]).transpose(0, 1).contiguous()\n", " ]]).transpose(0, 1).contiguous()\n",
@ -449,7 +396,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 27,
"id": "85c3461a", "id": "85c3461a",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -467,6 +414,7 @@
} }
], ],
"source": [ "source": [
"# Paddle CTCLoss\n",
"paddle.set_device('cpu')\n", "paddle.set_device('cpu')\n",
"probs = paddle.to_tensor([[\n", "probs = paddle.to_tensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1],\n", " [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1],\n",
@ -490,7 +438,55 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "d390cd91", "id": "8cdf76c2",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 26,
"id": "2c305eaf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([2, 1, 5])\n",
"2.4628584384918213\n",
"[[[ 0.17703117 -0.7081247 0.17703117 0.17703117 0.17703117]]\n",
"\n",
" [[ 0.17703117 0.17703117 -0.7081247 0.17703117 0.17703117]]]\n"
]
}
],
"source": [
"# warpctc_pytorch CTCLoss, log_softmax idempotent\n",
"probs = torch.FloatTensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
" ]]).transpose(0, 1).contiguous()\n",
"print(probs.size())\n",
"labels = torch.IntTensor([1, 2])\n",
"label_sizes = torch.IntTensor([2])\n",
"probs_sizes = torch.IntTensor([2])\n",
"probs.requires_grad_(True)\n",
"bs = probs.size(1)\n",
"\n",
"ctc_loss = wp.CTCLoss(size_average=False, length_average=False)\n",
"\n",
"log_probs = torch.log_softmax(probs, axis=-1)\n",
"cost = ctc_loss(log_probs, labels, probs_sizes, label_sizes)\n",
"cost = cost.sum() / bs\n",
"print(cost.item())\n",
"cost.backward()\n",
"print(probs.grad.numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "443336f0",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []

@ -22,19 +22,17 @@ Authors
* qingenz123@126.com (Qingen ZHAO) 2022 * qingenz123@126.com (Qingen ZHAO) 2022
""" """
import os
import logging
import argparse import argparse
import xml.etree.ElementTree as et
import glob import glob
import json import json
from ami_splits import get_AMI_split import logging
import os
import xml.etree.ElementTree as et
from distutils.util import strtobool from distutils.util import strtobool
from dataio import ( from ami_splits import get_AMI_split
load_pkl, from dataio import load_pkl
save_pkl, ) from dataio import save_pkl
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SAMPLERATE = 16000 SAMPLERATE = 16000

@ -12,28 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
Make VoxCeleb1 trial of kaldi format Make VoxCeleb1 trial of kaldi format
this script creat the test trial from kaldi trial voxceleb1_test_v2.txt or official trial veri_test2.txt this script creat the test trial from kaldi trial voxceleb1_test_v2.txt or official trial veri_test2.txt
to kaldi trial format to kaldi trial format
""" """
import argparse import argparse
import codecs import codecs
import os import os
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--voxceleb_trial", parser.add_argument(
"--voxceleb_trial",
default="voxceleb1_test_v2", default="voxceleb1_test_v2",
type=str, type=str,
help="VoxCeleb trial file. Default we use the kaldi trial voxceleb1_test_v2.txt") help="VoxCeleb trial file. Default we use the kaldi trial voxceleb1_test_v2.txt"
parser.add_argument("--trial", )
parser.add_argument(
"--trial",
default="data/test/trial", default="data/test/trial",
type=str, type=str,
help="Kaldi format trial file") help="Kaldi format trial file")
args = parser.parse_args() args = parser.parse_args()
def main(voxceleb_trial, trial): def main(voxceleb_trial, trial):
""" """
VoxCeleb provide several trial file, which format is different with kaldi format. VoxCeleb provide several trial file, which format is different with kaldi format.
@ -58,7 +60,9 @@ def main(voxceleb_trial, trial):
""" """
print("Start convert the voxceleb trial to kaldi format") print("Start convert the voxceleb trial to kaldi format")
if not os.path.exists(voxceleb_trial): if not os.path.exists(voxceleb_trial):
raise RuntimeError("{} does not exist. Pleas input the correct file path".format(voxceleb_trial)) raise RuntimeError(
"{} does not exist. Pleas input the correct file path".format(
voxceleb_trial))
trial_dirname = os.path.dirname(trial) trial_dirname = os.path.dirname(trial)
if not os.path.exists(trial_dirname): if not os.path.exists(trial_dirname):
@ -77,5 +81,6 @@ def main(voxceleb_trial, trial):
w.write("{} {} {}\n".format(utt_id1, utt_id2, target)) w.write("{} {} {}\n".format(utt_id1, utt_id2, target))
print("Convert the voxceleb trial to kaldi format successfully") print("Convert the voxceleb trial to kaldi format successfully")
if __name__ == "__main__": if __name__ == "__main__":
main(args.voxceleb_trial, args.trial) main(args.voxceleb_trial, args.trial)

@ -11,14 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

@ -413,7 +413,8 @@ class ASRExecutor(BaseExecutor):
def _check(self, audio_file: str, sample_rate: int, force_yes: bool): def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
self.sample_rate = sample_rate self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000: if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error("invalid sample rate, please input --sr 8000 or --sr 16000") logger.error(
"invalid sample rate, please input --sr 8000 or --sr 16000")
return False return False
if isinstance(audio_file, (str, os.PathLike)): if isinstance(audio_file, (str, os.PathLike)):

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List
from io import BytesIO from io import BytesIO
from typing import List
import numpy as np import numpy as np

@ -23,10 +23,11 @@ Credits
This code is adapted from https://github.com/nryant/dscore This code is adapted from https://github.com/nryant/dscore
""" """
import argparse import argparse
from distutils.util import strtobool
import os import os
import re import re
import subprocess import subprocess
from distutils.util import strtobool
import numpy as np import numpy as np
FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)") FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")

Loading…
Cancel
Save