resolve setup.py conflicts, test=doc

pull/1497/head
lym0302 4 years ago
commit e5aa24fa5a

@ -50,12 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i entry: bash .pre-commit-hooks/clang-format.hook -i
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker - id: copyright_checker
name: copyright_checker name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook entry: python .pre-commit-hooks/copyright-check.hook
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin).*(\.cpp|\.h|\.py)$ exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports - repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0 rev: v2.4.0
hooks: hooks:

@ -80,6 +80,7 @@ parser.add_argument(
args = parser.parse_args() args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix): def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix) print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = [] json_lines = []
@ -128,6 +129,7 @@ def create_manifest(data_dir, manifest_path_prefix):
print(f"{total_text / total_sec} text/sec", file=f) print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f) print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(base_url, data_list, target_dir, manifest_path, def prepare_dataset(base_url, data_list, target_dir, manifest_path,
target_data): target_data):
if not os.path.exists(target_dir): if not os.path.exists(target_dir):
@ -164,6 +166,7 @@ def prepare_dataset(base_url, data_list, target_dir, manifest_path,
# create the manifest file # create the manifest file
create_manifest(data_dir=target_dir, manifest_path_prefix=manifest_path) create_manifest(data_dir=target_dir, manifest_path_prefix=manifest_path)
def main(): def main():
if args.target_dir.startswith('~'): if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir) args.target_dir = os.path.expanduser(args.target_dir)
@ -184,5 +187,6 @@ def main():
print("Manifest prepare done!") print("Manifest prepare done!")
if __name__ == '__main__': if __name__ == '__main__':
main() main()

@ -5,4 +5,4 @@ cfg_path: # [optional]
ckpt_path: # [optional] ckpt_path: # [optional]
decode_method: 'attention_rescoring' decode_method: 'attention_rescoring'
force_yes: True force_yes: True
device: 'gpu:3' # set 'gpu:id' or 'cpu' device: 'cpu' # set 'gpu:id' or 'cpu'

@ -15,7 +15,7 @@ decode_method:
force_yes: True force_yes: True
am_predictor_conf: am_predictor_conf:
device: 'gpu:3' # set 'gpu:id' or 'cpu' device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: True enable_mkldnn: True
switch_ir_optim: True switch_ir_optim: True

@ -29,4 +29,4 @@ voc_stat:
# OTHERS # # OTHERS #
################################################################## ##################################################################
lang: 'zh' lang: 'zh'
device: 'gpu:3' # set 'gpu:id' or 'cpu' device: 'cpu' # set 'gpu:id' or 'cpu'

@ -15,7 +15,7 @@ speaker_dict:
spk_id: 0 spk_id: 0
am_predictor_conf: am_predictor_conf:
device: 'gpu:3' # set 'gpu:id' or 'cpu' device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: False enable_mkldnn: False
switch_ir_optim: False switch_ir_optim: False
@ -30,7 +30,7 @@ voc_params: # the pdiparams file of your vocoder static model (XX.pdipparams)
voc_sample_rate: 24000 voc_sample_rate: 24000
voc_predictor_conf: voc_predictor_conf:
device: 'gpu:3' # set 'gpu:id' or 'cpu' device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: False enable_mkldnn: False
switch_ir_optim: False switch_ir_optim: False

@ -30,12 +30,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Cloning into 'warp-ctc'...\n", "fatal: destination path 'warp-ctc' already exists and is not an empty directory.\r\n"
"remote: Enumerating objects: 829, done.\u001b[K\n",
"remote: Total 829 (delta 0), reused 0 (delta 0), pack-reused 829\u001b[K\n",
"Receiving objects: 100% (829/829), 388.85 KiB | 140.00 KiB/s, done.\n",
"Resolving deltas: 100% (419/419), done.\n",
"Checking connectivity... done.\n"
] ]
} }
], ],
@ -99,30 +94,6 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"-- The C compiler identification is GNU 5.4.0\n",
"-- The CXX compiler identification is GNU 5.4.0\n",
"-- Check for working C compiler: /usr/bin/cc\n",
"-- Check for working C compiler: /usr/bin/cc -- works\n",
"-- Detecting C compiler ABI info\n",
"-- Detecting C compiler ABI info - done\n",
"-- Detecting C compile features\n",
"-- Detecting C compile features - done\n",
"-- Check for working CXX compiler: /usr/bin/c++\n",
"-- Check for working CXX compiler: /usr/bin/c++ -- works\n",
"-- Detecting CXX compiler ABI info\n",
"-- Detecting CXX compiler ABI info - done\n",
"-- Detecting CXX compile features\n",
"-- Detecting CXX compile features - done\n",
"-- Looking for pthread.h\n",
"-- Looking for pthread.h - found\n",
"-- Performing Test CMAKE_HAVE_LIBC_PTHREAD\n",
"-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed\n",
"-- Looking for pthread_create in pthreads\n",
"-- Looking for pthread_create in pthreads - not found\n",
"-- Looking for pthread_create in pthread\n",
"-- Looking for pthread_create in pthread - found\n",
"-- Found Threads: TRUE \n",
"-- Found CUDA: /usr/local/cuda (found suitable version \"10.2\", minimum required is \"6.5\") \n",
"-- cuda found TRUE\n", "-- cuda found TRUE\n",
"-- Building shared library with GPU support\n", "-- Building shared library with GPU support\n",
"-- Configuring done\n", "-- Configuring done\n",
@ -145,20 +116,11 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[ 11%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_reduce.cu.o\u001b[0m\n", "[ 11%] \u001b[32m\u001b[1mLinking CXX shared library libwarpctc.so\u001b[0m\n",
"[ 22%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_ctc_entrypoint.cu.o\u001b[0m\n",
"\u001b[35m\u001b[1mScanning dependencies of target warpctc\u001b[0m\n",
"[ 33%] \u001b[32m\u001b[1mLinking CXX shared library libwarpctc.so\u001b[0m\n",
"[ 33%] Built target warpctc\n", "[ 33%] Built target warpctc\n",
"[ 44%] \u001b[34m\u001b[1mBuilding NVCC (Device) object CMakeFiles/test_gpu.dir/tests/test_gpu_generated_test_gpu.cu.o\u001b[0m\n", "[ 44%] \u001b[32m\u001b[1mLinking CXX executable test_cpu\u001b[0m\n",
"\u001b[35m\u001b[1mScanning dependencies of target test_cpu\u001b[0m\n", "[ 55%] \u001b[32m\u001b[1mLinking CXX executable test_gpu\u001b[0m\n",
"[ 55%] \u001b[32mBuilding CXX object CMakeFiles/test_cpu.dir/tests/test_cpu.cpp.o\u001b[0m\n",
"[ 66%] \u001b[32mBuilding CXX object CMakeFiles/test_cpu.dir/tests/random.cpp.o\u001b[0m\n",
"[ 77%] \u001b[32m\u001b[1mLinking CXX executable test_cpu\u001b[0m\n",
"[ 77%] Built target test_cpu\n", "[ 77%] Built target test_cpu\n",
"\u001b[35m\u001b[1mScanning dependencies of target test_gpu\u001b[0m\n",
"[ 88%] \u001b[32mBuilding CXX object CMakeFiles/test_gpu.dir/tests/random.cpp.o\u001b[0m\n",
"[100%] \u001b[32m\u001b[1mLinking CXX executable test_gpu\u001b[0m\n",
"[100%] Built target test_gpu\n" "[100%] Built target test_gpu\n"
] ]
} }
@ -169,7 +131,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 9,
"id": "31761a31", "id": "31761a31",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -187,7 +149,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 10,
"id": "f53316f6", "id": "f53316f6",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -205,7 +167,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 11,
"id": "084f1e49", "id": "084f1e49",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -216,29 +178,20 @@
"running install\n", "running install\n",
"running bdist_egg\n", "running bdist_egg\n",
"running egg_info\n", "running egg_info\n",
"creating warpctc_pytorch.egg-info\n",
"writing warpctc_pytorch.egg-info/PKG-INFO\n", "writing warpctc_pytorch.egg-info/PKG-INFO\n",
"writing dependency_links to warpctc_pytorch.egg-info/dependency_links.txt\n", "writing dependency_links to warpctc_pytorch.egg-info/dependency_links.txt\n",
"writing top-level names to warpctc_pytorch.egg-info/top_level.txt\n", "writing top-level names to warpctc_pytorch.egg-info/top_level.txt\n",
"writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n", "writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n",
"writing manifest file 'warpctc_pytorch.egg-info/SOURCES.txt'\n",
"installing library code to build/bdist.linux-x86_64/egg\n", "installing library code to build/bdist.linux-x86_64/egg\n",
"running install_lib\n", "running install_lib\n",
"running build_py\n", "running build_py\n",
"creating build\n",
"creating build/lib.linux-x86_64-3.9\n",
"creating build/lib.linux-x86_64-3.9/warpctc_pytorch\n",
"copying warpctc_pytorch/__init__.py -> build/lib.linux-x86_64-3.9/warpctc_pytorch\n",
"running build_ext\n", "running build_ext\n",
"building 'warpctc_pytorch._warp_ctc' extension\n", "building 'warpctc_pytorch._warp_ctc' extension\n",
"creating /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9\n",
"creating /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src\n",
"Emitting ninja build file /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/build.ninja...\n", "Emitting ninja build file /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/build.ninja...\n",
"Compiling objects...\n", "Compiling objects...\n",
"Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n",
"[1/1] c++ -MMD -MF /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o.d -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /workspace/zhanghui/DeepSpeech-2.x/tools/venv/include -fPIC -O2 -isystem /workspace/zhanghui/DeepSpeech-2.x/tools/venv/include -fPIC -I/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/TH -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/workspace/zhanghui/DeepSpeech-2.x/tools/venv/include/python3.9 -c -c /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/src/binding.cpp -o /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -std=c++14 -fPIC -DWARPCTC_ENABLE_GPU -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE=\"_gcc\"' '-DPYBIND11_STDLIB=\"_libstdcpp\"' '-DPYBIND11_BUILD_ABI=\"_cxxabi1011\"' -DTORCH_EXTENSION_NAME=_warp_ctc -D_GLIBCXX_USE_CXX11_ABI=0\n", "ninja: no work to do.\n",
"g++ -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -shared -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -L/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/lib -L/usr/local/cuda/lib64 -lwarpctc -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n", "g++ -pthread -B /workspace/zhanghui/DeepSpeech-2.x/tools/venv/compiler_compat -Wl,--sysroot=/ -shared -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -Wl,-rpath-link,/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib /workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/pytorch_binding/build/temp.linux-x86_64-3.9/src/binding.o -L/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build -L/workspace/zhanghui/DeepSpeech-2.x/tools/venv/lib/python3.9/site-packages/torch/lib -L/usr/local/cuda/lib64 -lwarpctc -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-3.9/warpctc_pytorch/_warp_ctc.cpython-39-x86_64-linux-gnu.so -Wl,-rpath,/workspace/zhanghui/DeepSpeech-2.x/docs/topic/ctc/warp-ctc/build\n",
"creating build/bdist.linux-x86_64\n",
"creating build/bdist.linux-x86_64/egg\n", "creating build/bdist.linux-x86_64/egg\n",
"creating build/bdist.linux-x86_64/egg/warpctc_pytorch\n", "creating build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
"copying build/lib.linux-x86_64-3.9/warpctc_pytorch/__init__.py -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n", "copying build/lib.linux-x86_64-3.9/warpctc_pytorch/__init__.py -> build/bdist.linux-x86_64/egg/warpctc_pytorch\n",
@ -254,7 +207,6 @@
"writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt\n", "writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt\n",
"zip_safe flag not set; analyzing archive contents...\n", "zip_safe flag not set; analyzing archive contents...\n",
"warpctc_pytorch.__pycache__._warp_ctc.cpython-39: module references __file__\n", "warpctc_pytorch.__pycache__._warp_ctc.cpython-39: module references __file__\n",
"creating dist\n",
"creating 'dist/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it\n", "creating 'dist/warpctc_pytorch-0.1-py3.9-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it\n",
"removing 'build/bdist.linux-x86_64/egg' (and everything under it)\n", "removing 'build/bdist.linux-x86_64/egg' (and everything under it)\n",
"Processing warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n", "Processing warpctc_pytorch-0.1-py3.9-linux-x86_64.egg\n",
@ -275,7 +227,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 12,
"id": "ee4ca9e3", "id": "ee4ca9e3",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -293,7 +245,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 13,
"id": "59255ed8", "id": "59255ed8",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -311,21 +263,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 22,
"id": "1dae09b9", "id": "1dae09b9",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stderr",
"output_type": "stream",
"text": [
"grep: warning: GREP_OPTIONS is deprecated; please use an alias or script\n"
]
}
],
"source": [ "source": [
"import torch\n", "import torch\n",
"import torch.nn as nn\n", "import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import warpctc_pytorch as wp\n", "import warpctc_pytorch as wp\n",
"import paddle.nn as pn\n", "import paddle.nn as pn\n",
"import paddle" "import paddle"
@ -333,7 +278,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 15,
"id": "83d0762e", "id": "83d0762e",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -343,7 +288,7 @@
"'1.10.0+cu102'" "'1.10.0+cu102'"
] ]
}, },
"execution_count": 16, "execution_count": 15,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -354,17 +299,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 16,
"id": "62501e2c", "id": "62501e2c",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'2.2.0'" "'2.2.1'"
] ]
}, },
"execution_count": 17, "execution_count": 16,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -375,7 +320,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 17,
"id": "9e8e0f40", "id": "9e8e0f40",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -392,6 +337,7 @@
} }
], ],
"source": [ "source": [
"# warpctc_pytorch CTCLoss\n",
"probs = torch.FloatTensor([[\n", "probs = torch.FloatTensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n", " [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
" ]]).transpose(0, 1).contiguous()\n", " ]]).transpose(0, 1).contiguous()\n",
@ -412,7 +358,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 18,
"id": "2cd46569", "id": "2cd46569",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -428,6 +374,7 @@
} }
], ],
"source": [ "source": [
"# pytorch CTCLoss\n",
"probs = torch.FloatTensor([[\n", "probs = torch.FloatTensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n", " [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
" ]]).transpose(0, 1).contiguous()\n", " ]]).transpose(0, 1).contiguous()\n",
@ -449,7 +396,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 27,
"id": "85c3461a", "id": "85c3461a",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -467,6 +414,7 @@
} }
], ],
"source": [ "source": [
"# Paddle CTCLoss\n",
"paddle.set_device('cpu')\n", "paddle.set_device('cpu')\n",
"probs = paddle.to_tensor([[\n", "probs = paddle.to_tensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1],\n", " [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1],\n",
@ -490,7 +438,55 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "d390cd91", "id": "8cdf76c2",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 26,
"id": "2c305eaf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([2, 1, 5])\n",
"2.4628584384918213\n",
"[[[ 0.17703117 -0.7081247 0.17703117 0.17703117 0.17703117]]\n",
"\n",
" [[ 0.17703117 0.17703117 -0.7081247 0.17703117 0.17703117]]]\n"
]
}
],
"source": [
"# warpctc_pytorch CTCLoss, log_softmax idempotent\n",
"probs = torch.FloatTensor([[\n",
" [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]\n",
" ]]).transpose(0, 1).contiguous()\n",
"print(probs.size())\n",
"labels = torch.IntTensor([1, 2])\n",
"label_sizes = torch.IntTensor([2])\n",
"probs_sizes = torch.IntTensor([2])\n",
"probs.requires_grad_(True)\n",
"bs = probs.size(1)\n",
"\n",
"ctc_loss = wp.CTCLoss(size_average=False, length_average=False)\n",
"\n",
"log_probs = torch.log_softmax(probs, axis=-1)\n",
"cost = ctc_loss(log_probs, labels, probs_sizes, label_sizes)\n",
"cost = cost.sum() / bs\n",
"print(cost.item())\n",
"cost.backward()\n",
"print(probs.grad.numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "443336f0",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []

@ -22,19 +22,17 @@ Authors
* qingenz123@126.com (Qingen ZHAO) 2022 * qingenz123@126.com (Qingen ZHAO) 2022
""" """
import os
import logging
import argparse import argparse
import xml.etree.ElementTree as et
import glob import glob
import json import json
from ami_splits import get_AMI_split import logging
import os
import xml.etree.ElementTree as et
from distutils.util import strtobool from distutils.util import strtobool
from dataio import ( from ami_splits import get_AMI_split
load_pkl, from dataio import load_pkl
save_pkl, ) from dataio import save_pkl
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SAMPLERATE = 16000 SAMPLERATE = 16000

@ -12,28 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
Make VoxCeleb1 trial of kaldi format Make VoxCeleb1 trial of kaldi format
this script creat the test trial from kaldi trial voxceleb1_test_v2.txt or official trial veri_test2.txt this script creat the test trial from kaldi trial voxceleb1_test_v2.txt or official trial veri_test2.txt
to kaldi trial format to kaldi trial format
""" """
import argparse import argparse
import codecs import codecs
import os import os
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--voxceleb_trial", parser.add_argument(
default="voxceleb1_test_v2", "--voxceleb_trial",
type=str, default="voxceleb1_test_v2",
help="VoxCeleb trial file. Default we use the kaldi trial voxceleb1_test_v2.txt") type=str,
parser.add_argument("--trial", help="VoxCeleb trial file. Default we use the kaldi trial voxceleb1_test_v2.txt"
default="data/test/trial", )
type=str, parser.add_argument(
help="Kaldi format trial file") "--trial",
default="data/test/trial",
type=str,
help="Kaldi format trial file")
args = parser.parse_args() args = parser.parse_args()
def main(voxceleb_trial, trial): def main(voxceleb_trial, trial):
""" """
VoxCeleb provide several trial file, which format is different with kaldi format. VoxCeleb provide several trial file, which format is different with kaldi format.
@ -58,7 +60,9 @@ def main(voxceleb_trial, trial):
""" """
print("Start convert the voxceleb trial to kaldi format") print("Start convert the voxceleb trial to kaldi format")
if not os.path.exists(voxceleb_trial): if not os.path.exists(voxceleb_trial):
raise RuntimeError("{} does not exist. Pleas input the correct file path".format(voxceleb_trial)) raise RuntimeError(
"{} does not exist. Pleas input the correct file path".format(
voxceleb_trial))
trial_dirname = os.path.dirname(trial) trial_dirname = os.path.dirname(trial)
if not os.path.exists(trial_dirname): if not os.path.exists(trial_dirname):
@ -66,9 +70,9 @@ def main(voxceleb_trial, trial):
with codecs.open(voxceleb_trial, 'r', encoding='utf-8') as f, \ with codecs.open(voxceleb_trial, 'r', encoding='utf-8') as f, \
codecs.open(trial, 'w', encoding='utf-8') as w: codecs.open(trial, 'w', encoding='utf-8') as w:
for line in f: for line in f:
target_or_nontarget, path1, path2 = line.strip().split() target_or_nontarget, path1, path2 = line.strip().split()
utt_id1 = "-".join(path1.split("/")) utt_id1 = "-".join(path1.split("/"))
utt_id2 = "-".join(path2.split("/")) utt_id2 = "-".join(path2.split("/"))
target = "nontarget" target = "nontarget"
@ -77,5 +81,6 @@ def main(voxceleb_trial, trial):
w.write("{} {} {}\n".format(utt_id1, utt_id2, target)) w.write("{} {} {}\n".format(utt_id1, utt_id2, target))
print("Convert the voxceleb trial to kaldi format successfully") print("Convert the voxceleb trial to kaldi format successfully")
if __name__ == "__main__": if __name__ == "__main__":
main(args.voxceleb_trial, args.trial) main(args.voxceleb_trial, args.trial)

@ -11,14 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

@ -413,7 +413,8 @@ class ASRExecutor(BaseExecutor):
def _check(self, audio_file: str, sample_rate: int, force_yes: bool): def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
self.sample_rate = sample_rate self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000: if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error("invalid sample rate, please input --sr 8000 or --sr 16000") logger.error(
"invalid sample rate, please input --sr 8000 or --sr 16000")
return False return False
if isinstance(audio_file, (str, os.PathLike)): if isinstance(audio_file, (str, os.PathLike)):

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List
from io import BytesIO from io import BytesIO
from typing import List
import numpy as np import numpy as np

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import uvicorn import uvicorn
import yaml
from fastapi import FastAPI from fastapi import FastAPI
from paddlespeech.server.engine.engine_pool import init_engine_pool from paddlespeech.server.engine.engine_pool import init_engine_pool

@ -48,8 +48,9 @@ class TTSClientExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--input', '--input',
type=str, type=str,
default="你好,欢迎使用语音合成服务", default=None,
help='A sentence to be synthesized.') help='Text to be synthesized.',
required=True)
self.parser.add_argument( self.parser.add_argument(
'--spk_id', type=int, default=0, help='Speaker id') '--spk_id', type=int, default=0, help='Speaker id')
self.parser.add_argument( self.parser.add_argument(
@ -123,7 +124,7 @@ class TTSClientExecutor(BaseExecutor):
logger.info("RTF: %f " % (time_consume / duration)) logger.info("RTF: %f " % (time_consume / duration))
return True return True
except: except BaseException:
logger.error("Failed to synthesized audio.") logger.error("Failed to synthesized audio.")
return False return False
@ -163,7 +164,7 @@ class TTSClientExecutor(BaseExecutor):
print("Audio duration: %f s." % (duration)) print("Audio duration: %f s." % (duration))
print("Response time: %f s." % (time_consume)) print("Response time: %f s." % (time_consume))
print("RTF: %f " % (time_consume / duration)) print("RTF: %f " % (time_consume / duration))
except: except BaseException:
print("Failed to synthesized audio.") print("Failed to synthesized audio.")
@ -181,8 +182,9 @@ class ASRClientExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--input', '--input',
type=str, type=str,
default="./paddlespeech/server/tests/16_audio.wav", default=None,
help='Audio file to be recognized') help='Audio file to be recognized',
required=True)
self.parser.add_argument( self.parser.add_argument(
'--sample_rate', type=int, default=16000, help='audio sample rate') '--sample_rate', type=int, default=16000, help='audio sample rate')
self.parser.add_argument( self.parser.add_argument(
@ -209,7 +211,7 @@ class ASRClientExecutor(BaseExecutor):
logger.info(r.json()) logger.info(r.json())
logger.info("time cost %f s." % (time_end - time_start)) logger.info("time cost %f s." % (time_end - time_start))
return True return True
except: except BaseException:
logger.error("Failed to speech recognition.") logger.error("Failed to speech recognition.")
return False return False
@ -240,5 +242,5 @@ class ASRClientExecutor(BaseExecutor):
time_end = time.time() time_end = time.time()
print(r.json()) print(r.json())
print("time cost %f s." % (time_end - time_start)) print("time cost %f s." % (time_end - time_start))
except: except BaseException:
print("Failed to speech recognition.") print("Failed to speech recognition.")

@ -41,7 +41,8 @@ class ServerExecutor(BaseExecutor):
"--config_file", "--config_file",
action="store", action="store",
help="yaml file of the app", help="yaml file of the app",
default="./conf/application.yaml") default=None,
required=True)
self.parser.add_argument( self.parser.add_argument(
"--log_file", "--log_file",

@ -5,4 +5,4 @@ cfg_path: # [optional]
ckpt_path: # [optional] ckpt_path: # [optional]
decode_method: 'attention_rescoring' decode_method: 'attention_rescoring'
force_yes: True force_yes: True
device: 'gpu:3' # set 'gpu:id' or 'cpu' device: 'cpu' # set 'gpu:id' or 'cpu'

@ -15,7 +15,7 @@ decode_method:
force_yes: True force_yes: True
am_predictor_conf: am_predictor_conf:
device: 'gpu:3' # set 'gpu:id' or 'cpu' device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: True enable_mkldnn: True
switch_ir_optim: True switch_ir_optim: True

@ -29,4 +29,4 @@ voc_stat:
# OTHERS # # OTHERS #
################################################################## ##################################################################
lang: 'zh' lang: 'zh'
device: 'gpu:3' # set 'gpu:id' or 'cpu' device: 'cpu' # set 'gpu:id' or 'cpu'

@ -15,7 +15,7 @@ speaker_dict:
spk_id: 0 spk_id: 0
am_predictor_conf: am_predictor_conf:
device: 'gpu:3' # set 'gpu:id' or 'cpu' device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: False enable_mkldnn: False
switch_ir_optim: False switch_ir_optim: False
@ -30,7 +30,7 @@ voc_params: # the pdiparams file of your vocoder static model (XX.pdipparams)
voc_sample_rate: 24000 #must match the model voc_sample_rate: 24000 #must match the model
voc_predictor_conf: voc_predictor_conf:
device: 'gpu:3' # set 'gpu:id' or 'cpu' device: 'cpu' # set 'gpu:id' or 'cpu'
enable_mkldnn: False enable_mkldnn: False
switch_ir_optim: False switch_ir_optim: False

@ -13,31 +13,24 @@
# limitations under the License. # limitations under the License.
import io import io
import os import os
from typing import List
from typing import Optional from typing import Optional
from typing import Union
import librosa
import paddle import paddle
import soundfile
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.config import get_config from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.paddle_predictor import init_predictor from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model from paddlespeech.server.utils.paddle_predictor import run_model
from paddlespeech.server.engine.base_engine import BaseEngine
__all__ = ['ASREngine'] __all__ = ['ASREngine']
pretrained_models = { pretrained_models = {
"deepspeech2offline_aishell-zh-16k": { "deepspeech2offline_aishell-zh-16k": {
'url': 'url':
@ -143,7 +136,6 @@ class ASRServerExecutor(ASRExecutor):
batch_average=True, # sum / batch_size batch_average=True, # sum / batch_size
grad_norm_type=self.config.get('ctc_grad_norm_type', None)) grad_norm_type=self.config.get('ctc_grad_norm_type', None))
@paddle.no_grad() @paddle.no_grad()
def infer(self, model_type: str): def infer(self, model_type: str):
""" """
@ -161,9 +153,8 @@ class ASRServerExecutor(ASRExecutor):
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch) cfg.num_proc_bsearch)
output_data = run_model( output_data = run_model(self.am_predictor,
self.am_predictor, [audio.numpy(), audio_len.numpy()])
[audio.numpy(), audio_len.numpy()])
probs = output_data[0] probs = output_data[0]
eouts_len = output_data[1] eouts_len = output_data[1]
@ -208,14 +199,14 @@ class ASREngine(BaseEngine):
paddle.set_device(paddle.get_device()) paddle.set_device(paddle.get_device())
self.executor._init_from_path( self.executor._init_from_path(
model_type=self.config.model_type, model_type=self.config.model_type,
am_model=self.config.am_model, am_model=self.config.am_model,
am_params=self.config.am_params, am_params=self.config.am_params,
lang=self.config.lang, lang=self.config.lang,
sample_rate=self.config.sample_rate, sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path, cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method, decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf) am_predictor_conf=self.config.am_predictor_conf)
logger.info("Initialize ASR server engine successfully.") logger.info("Initialize ASR server engine successfully.")
return True return True
@ -230,7 +221,8 @@ class ASREngine(BaseEngine):
io.BytesIO(audio_data), self.config.sample_rate, io.BytesIO(audio_data), self.config.sample_rate,
self.config.force_yes): self.config.force_yes):
logger.info("start running asr engine") logger.info("start running asr engine")
self.executor.preprocess(self.config.model_type, io.BytesIO(audio_data)) self.executor.preprocess(self.config.model_type,
io.BytesIO(audio_data))
self.executor.infer(self.config.model_type) self.executor.infer(self.config.model_type)
self.output = self.executor.postprocess() # Retrieve result of asr. self.output = self.executor.postprocess() # Retrieve result of asr.
logger.info("end inferring asr engine") logger.info("end inferring asr engine")

@ -53,7 +53,10 @@ class ASREngine(BaseEngine):
self.executor = ASRServerExecutor() self.executor = ASRServerExecutor()
self.config = get_config(config_file) self.config = get_config(config_file)
paddle.set_device(self.config.device) if self.config.device is None:
paddle.set_device(paddle.get_device())
else:
paddle.set_device(self.config.device)
self.executor._init_from_path( self.executor._init_from_path(
self.config.model, self.config.lang, self.config.sample_rate, self.config.model, self.config.lang, self.config.sample_rate,
self.config.cfg_path, self.config.decode_method, self.config.cfg_path, self.config.decode_method,

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
from typing import Any
from typing import List
from typing import Union from typing import Union
from pattern_singleton import Singleton from pattern_singleton import Singleton

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
from typing import Text from typing import Text
__all__ = ['EngineFactory'] __all__ = ['EngineFactory']

@ -29,8 +29,10 @@ def init_engine_pool(config) -> bool:
""" """
global ENGINE_POOL global ENGINE_POOL
for engine in config.engine_backend: for engine in config.engine_backend:
ENGINE_POOL[engine] = EngineFactory.get_engine(engine_name=engine, engine_type=config.engine_type[engine]) ENGINE_POOL[engine] = EngineFactory.get_engine(
if not ENGINE_POOL[engine].init(config_file=config.engine_backend[engine]): engine_name=engine, engine_type=config.engine_type[engine])
if not ENGINE_POOL[engine].init(
config_file=config.engine_backend[engine]):
return False return False
return True return True

@ -360,8 +360,8 @@ class TTSEngine(BaseEngine):
am_predictor_conf=self.config.am_predictor_conf, am_predictor_conf=self.config.am_predictor_conf,
voc_predictor_conf=self.config.voc_predictor_conf, ) voc_predictor_conf=self.config.voc_predictor_conf, )
except: except BaseException:
logger.info("Initialize TTS server engine Failed.") logger.error("Initialize TTS server engine Failed.")
return False return False
logger.info("Initialize TTS server engine successfully.") logger.info("Initialize TTS server engine successfully.")
@ -405,11 +405,13 @@ class TTSEngine(BaseEngine):
# transform speed # transform speed
try: # windows not support soxbindings try: # windows not support soxbindings
wav_speed = change_speed(wav_vol, speed, target_fs) wav_speed = change_speed(wav_vol, speed, target_fs)
except: except ServerBaseException:
raise ServerBaseException( raise ServerBaseException(
ErrorCode.SERVER_INTERNAL_ERR, ErrorCode.SERVER_INTERNAL_ERR,
"Transform speed failed. Can not install soxbindings on your system. \ "Transform speed failed. Can not install soxbindings on your system. \
You need to set speed value 1.0.") You need to set speed value 1.0.")
except BaseException:
logger.error("Transform speed failed.")
# wav to base64 # wav to base64
buf = io.BytesIO() buf = io.BytesIO()
@ -462,9 +464,11 @@ class TTSEngine(BaseEngine):
try: try:
self.executor.infer( self.executor.infer(
text=sentence, lang=lang, am=self.config.am, spk_id=spk_id) text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
except: except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.") "tts infer failed.")
except BaseException:
logger.error("tts infer failed.")
try: try:
target_sample_rate, wav_base64 = self.postprocess( target_sample_rate, wav_base64 = self.postprocess(
@ -474,8 +478,10 @@ class TTSEngine(BaseEngine):
volume=volume, volume=volume,
speed=speed, speed=speed,
audio_path=save_path) audio_path=save_path)
except: except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.") "tts postprocess failed.")
except BaseException:
logger.error("tts postprocess failed.")
return lang, target_sample_rate, wav_base64 return lang, target_sample_rate, wav_base64

@ -54,7 +54,10 @@ class TTSEngine(BaseEngine):
try: try:
self.config = get_config(config_file) self.config = get_config(config_file)
paddle.set_device(self.config.device) if self.config.device is None:
paddle.set_device(paddle.get_device())
else:
paddle.set_device(self.config.device)
self.executor._init_from_path( self.executor._init_from_path(
am=self.config.am, am=self.config.am,
@ -69,8 +72,8 @@ class TTSEngine(BaseEngine):
voc_ckpt=self.config.voc_ckpt, voc_ckpt=self.config.voc_ckpt,
voc_stat=self.config.voc_stat, voc_stat=self.config.voc_stat,
lang=self.config.lang) lang=self.config.lang)
except: except BaseException:
logger.info("Initialize TTS server engine Failed.") logger.error("Initialize TTS server engine Failed.")
return False return False
logger.info("Initialize TTS server engine successfully.") logger.info("Initialize TTS server engine successfully.")
@ -114,10 +117,13 @@ class TTSEngine(BaseEngine):
# transform speed # transform speed
try: # windows not support soxbindings try: # windows not support soxbindings
wav_speed = change_speed(wav_vol, speed, target_fs) wav_speed = change_speed(wav_vol, speed, target_fs)
except: except ServerBaseException:
raise ServerBaseException( raise ServerBaseException(
ErrorCode.SERVER_INTERNAL_ERR, ErrorCode.SERVER_INTERNAL_ERR,
"Can not install soxbindings on your system.") "Transform speed failed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
except BaseException:
logger.error("Transform speed failed.")
# wav to base64 # wav to base64
buf = io.BytesIO() buf = io.BytesIO()
@ -170,9 +176,11 @@ class TTSEngine(BaseEngine):
try: try:
self.executor.infer( self.executor.infer(
text=sentence, lang=lang, am=self.config.am, spk_id=spk_id) text=sentence, lang=lang, am=self.config.am, spk_id=spk_id)
except: except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts infer failed.") "tts infer failed.")
except BaseException:
logger.error("tts infer failed.")
try: try:
target_sample_rate, wav_base64 = self.postprocess( target_sample_rate, wav_base64 = self.postprocess(
@ -182,8 +190,10 @@ class TTSEngine(BaseEngine):
volume=volume, volume=volume,
speed=speed, speed=speed,
audio_path=save_path) audio_path=save_path)
except: except ServerBaseException:
raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR, raise ServerBaseException(ErrorCode.SERVER_INTERNAL_ERR,
"tts postprocess failed.") "tts postprocess failed.")
except BaseException:
logger.error("tts postprocess failed.")
return lang, target_sample_rate, wav_base64 return lang, target_sample_rate, wav_base64

@ -14,6 +14,7 @@
import base64 import base64
import traceback import traceback
from typing import Union from typing import Union
from fastapi import APIRouter from fastapi import APIRouter
from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.engine.engine_pool import get_engine_pool
@ -83,7 +84,7 @@ def asr(request_body: ASRRequest):
except ServerBaseException as e: except ServerBaseException as e:
response = failed_response(e.error_code, e.msg) response = failed_response(e.error_code, e.msg)
except: except BaseException:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc() traceback.print_exc()

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel

@ -11,9 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List
from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
__all__ = ['ASRResponse', 'TTSResponse'] __all__ = ['ASRResponse', 'TTSResponse']

@ -114,7 +114,7 @@ def tts(request_body: TTSRequest):
} }
except ServerBaseException as e: except ServerBaseException as e:
response = failed_response(e.error_code, e.msg) response = failed_response(e.error_code, e.msg)
except: except BaseException:
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
traceback.print_exc() traceback.print_exc()

@ -10,11 +10,11 @@
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the # See the License for the
import requests import base64
import json import json
import time import time
import base64
import io import requests
def readwav2base64(wav_file): def readwav2base64(wav_file):
@ -34,23 +34,23 @@ def main():
url = "http://127.0.0.1:8090/paddlespeech/asr" url = "http://127.0.0.1:8090/paddlespeech/asr"
# start Timestamp # start Timestamp
time_start=time.time() time_start = time.time()
test_audio_dir = "./16_audio.wav" test_audio_dir = "./16_audio.wav"
audio = readwav2base64(test_audio_dir) audio = readwav2base64(test_audio_dir)
data = { data = {
"audio": audio, "audio": audio,
"audio_format": "wav", "audio_format": "wav",
"sample_rate": 16000, "sample_rate": 16000,
"lang": "zh_cn", "lang": "zh_cn",
} }
r = requests.post(url=url, data=json.dumps(data)) r = requests.post(url=url, data=json.dumps(data))
# ending Timestamp # ending Timestamp
time_end=time.time() time_end = time.time()
print('time cost',time_end - time_start, 's') print('time cost', time_end - time_start, 's')
print(r.json()) print(r.json())

@ -25,6 +25,7 @@ import soundfile
from paddlespeech.server.utils.audio_process import wav2pcm from paddlespeech.server.utils.audio_process import wav2pcm
# Request and response # Request and response
def tts_client(args): def tts_client(args):
""" Request and response """ Request and response
@ -99,5 +100,5 @@ if __name__ == "__main__":
print("Inference time: %f" % (time_consume)) print("Inference time: %f" % (time_consume))
print("The duration of synthesized audio: %f" % (duration)) print("The duration of synthesized audio: %f" % (duration))
print("The RTF is: %f" % (rtf)) print("The RTF is: %f" % (rtf))
except: except BaseException:
print("Failed to synthesized audio.") print("Failed to synthesized audio.")

@ -219,7 +219,7 @@ class ConfigCache:
try: try:
cfg = yaml.load(file, Loader=yaml.FullLoader) cfg = yaml.load(file, Loader=yaml.FullLoader)
self._data.update(cfg) self._data.update(cfg)
except: except BaseException:
self.flush() self.flush()
@property @property

@ -258,4 +258,4 @@ class ChainDataset(Dataset):
return dataset[i] return dataset[i]
i -= len(dataset) i -= len(dataset)
raise IndexError("dataset index out of range") raise IndexError("dataset index out of range")

@ -1,48 +0,0 @@
ConfigArgParse
coverage
editdistance
g2p_en
g2pM
gpustat
h5py
inflect
jieba
jsonlines
kaldiio
librosa
loguru
matplotlib
nara_wpe
nltk
paddleaudio
paddlenlp
paddlespeech_ctcdecoders
paddlespeech_feat
pandas
phkit
Pillow
praatio==5.0.0
pre-commit
pybind11
pypi-kenlm
pypinyin
python-dateutil
pyworld
resampy==0.2.2
sacrebleu
scipy
sentencepiece~=0.1.96
snakeviz
soundfile~=0.10
sox
soxbindings
textgrid
timer
tqdm
typeguard
unidecode
visualdl
webrtcvad
yacs~=0.1.8
yq
zhon

@ -27,47 +27,53 @@ from setuptools.command.install import install
HERE = Path(os.path.abspath(os.path.dirname(__file__))) HERE = Path(os.path.abspath(os.path.dirname(__file__)))
VERSION = '0.1.1' VERSION = '0.1.2'
base = [
"editdistance",
"g2p_en",
"g2pM",
"h5py",
"inflect",
"jieba",
"jsonlines",
"kaldiio",
"librosa==0.8.1",
"loguru",
"matplotlib",
"nara_wpe",
"pandas",
"paddleaudio",
"paddlenlp",
"paddlespeech_feat",
"praatio==5.0.0",
"pypinyin",
"python-dateutil",
"pyworld",
"resampy==0.2.2",
"sacrebleu",
"scipy",
"sentencepiece~=0.1.96",
"soundfile~=0.10",
"textgrid",
"timer",
"tqdm",
"typeguard",
"visualdl",
"webrtcvad",
"yacs~=0.1.8",
]
server = [
"fastapi",
"uvicorn",
"pattern_singleton",
"prettytable",
]
requirements = { requirements = {
"install": [ "install":
"editdistance", base + server,
"g2p_en",
"g2pM",
"h5py",
"inflect",
"jieba",
"jsonlines",
"kaldiio",
"librosa",
"loguru",
"matplotlib",
"nara_wpe",
"pandas",
"paddleaudio",
"paddlenlp",
"paddlespeech_feat",
"praatio==5.0.0",
"pypinyin",
"python-dateutil",
"pyworld",
"resampy==0.2.2",
"sacrebleu",
"scipy",
"sentencepiece~=0.1.96",
"soundfile~=0.10",
"textgrid",
"timer",
"tqdm",
"typeguard",
"visualdl",
"webrtcvad",
"yacs~=0.1.8",
# fastapi server
"fastapi",
"uvicorn",
"prettytable"
],
"develop": [ "develop": [
"ConfigArgParse", "ConfigArgParse",
"coverage", "coverage",

@ -23,10 +23,11 @@ Credits
This code is adapted from https://github.com/nryant/dscore This code is adapted from https://github.com/nryant/dscore
""" """
import argparse import argparse
from distutils.util import strtobool
import os import os
import re import re
import subprocess import subprocess
from distutils.util import strtobool
import numpy as np import numpy as np
FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)") FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")

Loading…
Cancel
Save