Merge pull request #2524 from zh794390558/u2

[speechx] add u2/u2pp asr inference
pull/2585/head
YangZhou 2 years ago committed by GitHub
commit bbf2401e3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -50,13 +50,20 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i entry: bash .pre-commit-hooks/clang-format.hook -i
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$ exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.py)$
#- id: copyright_checker #- id: copyright_checker
# name: copyright_checker # name: copyright_checker
# entry: python .pre-commit-hooks/copyright-check.hook # entry: python .pre-commit-hooks/copyright-check.hook
# language: system # language: system
# files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ # files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
# exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$ # exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
- id: cpplint
name: cpplint
description: Static code analysis of C/C++ files
language: python
files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$
exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.py)$
entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent
- repo: https://github.com/asottile/reorder_python_imports - repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0 rev: v2.4.0
hooks: hooks:

@ -42,6 +42,7 @@ for type in attention_rescoring; do
output_dir=${ckpt_prefix} output_dir=${ckpt_prefix}
mkdir -p ${output_dir} mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test_wav.py \ python3 -u ${BIN_DIR}/test_wav.py \
--debug True \
--ngpu ${ngpu} \ --ngpu ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--decode_cfg ${decode_config_path} \ --decode_cfg ${decode_config_path} \

@ -16,6 +16,8 @@ import os
import sys import sys
from pathlib import Path from pathlib import Path
import distutils
import numpy as np
import paddle import paddle
import soundfile import soundfile
from yacs.config import CfgNode from yacs.config import CfgNode
@ -74,6 +76,8 @@ class U2Infer():
# fbank # fbank
feat = self.preprocessing(audio, **self.preprocess_args) feat = self.preprocessing(audio, **self.preprocess_args)
logger.info(f"feat shape: {feat.shape}") logger.info(f"feat shape: {feat.shape}")
if self.args.debug:
np.savetxt("feat.transform.txt", feat)
ilen = paddle.to_tensor(feat.shape[0]) ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0) xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
@ -126,6 +130,11 @@ if __name__ == "__main__":
"--result_file", type=str, help="path of save the asr result") "--result_file", type=str, help="path of save the asr result")
parser.add_argument( parser.add_argument(
"--audio_file", type=str, help="path of the input audio file") "--audio_file", type=str, help="path of the input audio file")
parser.add_argument(
"--debug",
type=distutils.util.strtobool,
default=False,
help="for debug.")
args = parser.parse_args() args = parser.parse_args()
config = CfgNode(new_allowed=True) config = CfgNode(new_allowed=True)

@ -75,6 +75,7 @@ base = [
"braceexpand", "braceexpand",
"pyyaml", "pyyaml",
"pybind11", "pybind11",
"paddleslim==2.3.4",
] ]
server = ["fastapi", "uvicorn", "pattern_singleton", "websockets"] server = ["fastapi", "uvicorn", "pattern_singleton", "websockets"]

@ -0,0 +1,29 @@
# This file is used by clang-format to autoformat paddle source code
#
# The clang-format is part of llvm toolchain.
# It need to install llvm and clang to format source code style.
#
# The basic usage is,
# clang-format -i -style=file PATH/TO/SOURCE/CODE
#
# The -style=file implicit use ".clang-format" file located in one of
# parent directory.
# The -i means inplace change.
#
# The document of clang-format is
# http://clang.llvm.org/docs/ClangFormat.html
# http://clang.llvm.org/docs/ClangFormatStyleOptions.html
---
Language: Cpp
BasedOnStyle: Google
IndentWidth: 4
TabWidth: 4
ContinuationIndentWidth: 4
MaxEmptyLinesToKeep: 2
AccessModifierOffset: -2 # The private/protected/public has no indent in class
Standard: Cpp11
AllowAllParametersOfDeclarationOnNextLine: true
BinPackParameters: false
BinPackArguments: false
...

@ -1 +1,2 @@
tools/valgrind* tools/valgrind*
*log

@ -13,7 +13,6 @@ set(CMAKE_CXX_STANDARD 14)
set(speechx_cmake_dir ${PROJECT_SOURCE_DIR}/cmake) set(speechx_cmake_dir ${PROJECT_SOURCE_DIR}/cmake)
# Modules # Modules
list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir}/external)
list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir}) list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir})
include(FetchContent) include(FetchContent)
include(ExternalProject) include(ExternalProject)
@ -32,9 +31,13 @@ SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall
############################################################################### ###############################################################################
# Option Configurations # Option Configurations
############################################################################### ###############################################################################
# option configurations
option(TEST_DEBUG "option for debug" OFF) option(TEST_DEBUG "option for debug" OFF)
option(USE_PROFILING "enable c++ profling" OFF)
option(USING_U2 "compile u2 model." ON)
option(USING_DS2 "compile with ds2 model." ON)
option(USING_GPU "u2 compute on GPU." OFF)
############################################################################### ###############################################################################
# Include third party # Include third party
@ -83,48 +86,65 @@ add_dependencies(openfst gflags glog)
# paddle lib # paddle lib
set(paddle_SOURCE_DIR ${fc_patch}/paddle-lib) include(paddleinference)
set(paddle_PREFIX_DIR ${fc_patch}/paddle-lib-prefix)
ExternalProject_Add(paddle
URL https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/CPU/gcc8.2_avx_mkl/paddle_inference.tgz # paddle core.so
URL_HASH SHA256=7c6399e778c6554a929b5a39ba2175e702e115145e8fa690d2af974101d98873 find_package(Threads REQUIRED)
PREFIX ${paddle_PREFIX_DIR} find_package(PythonLibs REQUIRED)
SOURCE_DIR ${paddle_SOURCE_DIR} find_package(Python3 REQUIRED)
CONFIGURE_COMMAND "" find_package(pybind11 CONFIG)
BUILD_COMMAND ""
INSTALL_COMMAND "" message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
) message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}")
message(STATUS "Pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}, pybind11_LIBRARIES=${pybind11_LIBRARIES}, pybind11_DEFINITIONS=${pybind11_DEFINITIONS}")
set(PADDLE_LIB ${fc_patch}/paddle-lib)
include_directories("${PADDLE_LIB}/paddle/include") # paddle include and link option
set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/") # -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include") execute_process(
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include") COMMAND python -c "\
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include") import os;\
import paddle;\
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib") include_dir=paddle.sysconfig.get_include();\
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib") paddle_dir=os.path.split(include_dir)[0];\
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib") libs_dir=os.path.join(paddle_dir, 'libs');\
link_directories("${PADDLE_LIB}/paddle/lib") fluid_dir=os.path.join(paddle_dir, 'fluid');\
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mklml/lib") out=' '.join([\"-L\" + libs_dir, \"-L\" + fluid_dir]);\
out += \" -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so\"; print(out);\
##paddle with mkl "
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") OUTPUT_VARIABLE PADDLE_LINK_FLAGS
set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml") RESULT_VARIABLE SUCESS)
include_directories("${MATH_LIB_PATH}/include")
set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} message(STATUS PADDLE_LINK_FLAGS= ${PADDLE_LINK_FLAGS})
${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) string(STRIP ${PADDLE_LINK_FLAGS} PADDLE_LINK_FLAGS)
set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
include_directories("${MKLDNN_PATH}/include") # paddle compile option
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) # -I/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/include
set(EXTERNAL_LIB "-lrt -ldl -lpthread") execute_process(
COMMAND python -c "\
set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX}) import paddle; \
set(DEPS ${DEPS} include_dir = paddle.sysconfig.get_include(); \
${MATH_LIB} ${MKLDNN_LIB} print(f\"-I{include_dir}\"); \
glog gflags protobuf xxhash cryptopp "
${EXTERNAL_LIB}) OUTPUT_VARIABLE PADDLE_COMPILE_FLAGS)
message(STATUS PADDLE_COMPILE_FLAGS= ${PADDLE_COMPILE_FLAGS})
string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS)
# for LD_LIBRARY_PATH
# set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/)
execute_process(
COMMAND python -c " \
import os; \
import paddle; \
include_dir=paddle.sysconfig.get_include(); \
paddle_dir=os.path.split(include_dir)[0]; \
libs_dir=os.path.join(paddle_dir, 'libs'); \
fluid_dir=os.path.join(paddle_dir, 'fluid'); \
out=':'.join([libs_dir, fluid_dir]); print(out); \
"
OUTPUT_VARIABLE PADDLE_LIB_DIRS)
message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS})
############################################################################### ###############################################################################

@ -3,11 +3,14 @@
## Environment ## Environment
We develop under: We develop under:
* python - 3.7
* docker - `registry.baidubce.com/paddlepaddle/paddle:2.2.2-gpu-cuda10.2-cudnn7` * docker - `registry.baidubce.com/paddlepaddle/paddle:2.2.2-gpu-cuda10.2-cudnn7`
* os - Ubuntu 16.04.7 LTS * os - Ubuntu 16.04.7 LTS
* gcc/g++/gfortran - 8.2.0 * gcc/g++/gfortran - 8.2.0
* cmake - 3.16.0 * cmake - 3.16.0
> Please use `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build speechx.
> We make sure all things work fun under docker, and recommend using it to develop and deploy. > We make sure all things work fun under docker, and recommend using it to develop and deploy.
* [How to Install Docker](https://docs.docker.com/engine/install/) * [How to Install Docker](https://docs.docker.com/engine/install/)
@ -24,16 +27,23 @@ docker run --privileged --net=host --ipc=host -it --rm -v $PWD:/workspace --nam
* More `Paddle` docker images you can see [here](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html). * More `Paddle` docker images you can see [here](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html).
2. Create python environment.
2. Build `speechx` and `examples`. ```
bash tools/venv.sh
```
> Do not source venv. 2. Build `speechx` and `examples`.
For now we are using feature under `develop` branch of paddle, so we need to install `paddlepaddle` nightly build version.
For example:
``` ```
pushd /path/to/speechx source venv/bin/activate
python -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html
./build.sh ./build.sh
``` ```
3. Go to `examples` to have a fun. 3. Go to `examples` to have a fun.
More details please see `README.md` under `examples`. More details please see `README.md` under `examples`.

@ -1,4 +1,5 @@
#!/usr/bin/env bash #!/usr/bin/env bash
set -xe
# the build script had verified in the paddlepaddle docker image. # the build script had verified in the paddlepaddle docker image.
# please follow the instruction below to install PaddlePaddle image. # please follow the instruction below to install PaddlePaddle image.
@ -17,11 +18,6 @@ fi
#rm -rf build #rm -rf build
mkdir -p build mkdir -p build
cd build
cmake .. -DBOOST_ROOT:STRING=${boost_SOURCE_DIR} cmake -B build -DBOOST_ROOT:STRING=${boost_SOURCE_DIR}
#cmake .. cmake --build build -j
make -j
cd -

@ -1,12 +0,0 @@
include(FetchContent)
FetchContent_Declare(
gflags
URL https://github.com/gflags/gflags/archive/v2.2.1.zip
URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a
)
FetchContent_MakeAvailable(gflags)
# openfst need
include_directories(${gflags_BINARY_DIR}/include)

@ -0,0 +1,11 @@
include(FetchContent)
FetchContent_Declare(
gflags
URL https://github.com/gflags/gflags/archive/v2.2.2.zip
URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5
)
FetchContent_MakeAvailable(gflags)
# openfst need
include_directories(${gflags_BINARY_DIR}/include)

@ -1,8 +1,8 @@
include(FetchContent) include(FetchContent)
FetchContent_Declare( FetchContent_Declare(
gtest gtest
URL https://github.com/google/googletest/archive/release-1.10.0.zip URL https://github.com/google/googletest/archive/release-1.11.0.zip
URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91 URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a
) )
FetchContent_MakeAvailable(gtest) FetchContent_MakeAvailable(gtest)

@ -1,7 +1,7 @@
include(FetchContent) include(FetchContent)
set(OpenBLAS_SOURCE_DIR ${fc_patch}/OpenBLAS-src) set(OpenBLAS_SOURCE_DIR ${fc_patch}/openblas-src)
set(OpenBLAS_PREFIX ${fc_patch}/OpenBLAS-prefix) set(OpenBLAS_PREFIX ${fc_patch}/openblas-prefix)
# ###################################################################################################################### # ######################################################################################################################
# OPENBLAS https://github.com/lattice/quda/blob/develop/CMakeLists.txt#L575 # OPENBLAS https://github.com/lattice/quda/blob/develop/CMakeLists.txt#L575
@ -43,6 +43,7 @@ ExternalProject_Add(
# https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition # https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition
ExternalProject_Get_Property(OPENBLAS INSTALL_DIR) ExternalProject_Get_Property(OPENBLAS INSTALL_DIR)
message(STATUS "OPENBLAS install dir: ${INSTALL_DIR}")
set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR}) set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR})
add_library(openblas STATIC IMPORTED) add_library(openblas STATIC IMPORTED)
add_dependencies(openblas OPENBLAS) add_dependencies(openblas OPENBLAS)
@ -55,4 +56,6 @@ set_target_properties(openblas PROPERTIES IMPORTED_LOCATION ${OpenBLAS_INSTALL_P
# ${CMAKE_INSTALL_LIBDIR} lib # ${CMAKE_INSTALL_LIBDIR} lib
# ${CMAKE_INSTALL_INCLUDEDIR} include # ${CMAKE_INSTALL_INCLUDEDIR} include
link_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}) link_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
include_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}) # include_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR})
# fix for can not find `cblas.h`
include_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/openblas)

@ -0,0 +1,49 @@
set(paddle_SOURCE_DIR ${fc_patch}/paddle-lib)
set(paddle_PREFIX_DIR ${fc_patch}/paddle-lib-prefix)
include(FetchContent)
FetchContent_Declare(
paddle
URL https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/CPU/gcc8.2_avx_mkl/paddle_inference.tgz
URL_HASH SHA256=7c6399e778c6554a929b5a39ba2175e702e115145e8fa690d2af974101d98873
PREFIX ${paddle_PREFIX_DIR}
SOURCE_DIR ${paddle_SOURCE_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
)
FetchContent_MakeAvailable(paddle)
set(PADDLE_LIB_THIRD_PARTY_PATH "${paddle_SOURCE_DIR}/third_party/install/")
include_directories("${paddle_SOURCE_DIR}/paddle/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
link_directories("${paddle_SOURCE_DIR}/paddle/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mklml/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn/lib")
##paddle with mkl
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml")
include_directories("${MATH_LIB_PATH}/include")
set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
include_directories("${MKLDNN_PATH}/include")
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
set(EXTERNAL_LIB "-lrt -ldl -lpthread")
# global vars
set(DEPS ${paddle_SOURCE_DIR}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX} CACHE INTERNAL "deps")
set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB}
glog gflags protobuf xxhash cryptopp
${EXTERNAL_LIB} CACHE INTERNAL "deps")
message(STATUS "Deps libraries: ${DEPS}")

@ -1,20 +1,42 @@
# Examples for SpeechX # Examples for SpeechX
> `u2pp_ol` is recommended.
* `u2pp_ol` - u2++ streaming asr test under `aishell-1` test dataset.
* `ds2_ol` - ds2 streaming test under `aishell-1` test dataset. * `ds2_ol` - ds2 streaming test under `aishell-1` test dataset.
## How to run ## How to run
`run.sh` is the entry point. ### Create env
Using `tools/evn.sh` under `speechx` to create python env.
```
bash tools/env.sh
```
Source env before play with example.
```
. venv/bin/activate
```
### Play with example
`run.sh` is the entry point for every example.
Example to play `ds2_ol`: Example to play `u2pp_ol`:
``` ```
pushd ds2_ol/aishell pushd u2pp_ol/wenetspeech
bash run.sh bash run.sh --stop_stage 4
``` ```
## Display Model with [Netron](https://github.com/lutzroeder/netron) ## Display Model with [Netron](https://github.com/lutzroeder/netron)
If you have a model, we can using this commnd to show model graph.
For example:
``` ```
pip install netron pip install netron
netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host 10.21.55.20 netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host 10.21.55.20

@ -1,8 +1,9 @@
# Codelab # Codelab
## introduction > The below is for developing and offline testing.
> Do not run it only if you know what it is.
> The below is for developing and offline testing. Do not run it only if you know what it is.
* nnet * nnet
* feat * feat
* decoder * decoder
* u2

@ -69,7 +69,7 @@ compute_linear_spectrogram_main \
echo "compute linear spectrogram feature." echo "compute linear spectrogram feature."
# run ctc beam search decoder as streaming # run ctc beam search decoder as streaming
ctc_prefix_beam_search_decoder_main \ ctc_beam_search_decoder_main \
--result_wspecifier=ark,t:$exp_dir/result.txt \ --result_wspecifier=ark,t:$exp_dir/result.txt \
--feature_rspecifier=ark:$feat_wspecifier \ --feature_rspecifier=ark:$feat_wspecifier \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \

@ -1,12 +1,12 @@
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../../../ SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } [ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }
export LC_AL=C export LC_AL=C

@ -42,8 +42,8 @@ mkdir -p $exp_dir
export GLOG_logtostderr=1 export GLOG_logtostderr=1
cmvn_json2kaldi_main \ cmvn_json2kaldi_main \
--json_file $model_dir/data/mean_std.json \ --json_file=$model_dir/data/mean_std.json \
--cmvn_write_path $exp_dir/cmvn.ark \ --cmvn_write_path=$exp_dir/cmvn.ark \
--binary=false --binary=false
echo "convert json cmvn to kaldi ark." echo "convert json cmvn to kaldi ark."
@ -54,4 +54,10 @@ compute_linear_spectrogram_main \
--cmvn_file=$exp_dir/cmvn.ark --cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature." echo "compute linear spectrogram feature."
compute_fbank_main \
--num_bins=161 \
--wav_rspecifier=scp:$data_dir/wav.scp \
--feature_wspecifier=ark,t:$exp_dir/fbank.ark \
--cmvn_file=$exp_dir/cmvn.ark
echo "compute fbank feature."

@ -6,7 +6,7 @@ SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } [ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }
export LC_AL=C export LC_AL=C

@ -0,0 +1 @@
# u2/u2pp Streaming Test

@ -0,0 +1,22 @@
#!/bin/bash
set +x
set -e
. path.sh
data=data
exp=exp
mkdir -p $exp
ckpt_dir=$data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
ctc_prefix_beam_search_decoder_main \
--model_path=$model_dir/export.jit \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--vocab_path=$model_dir/unit.txt \
--feature_rspecifier=ark,t:$exp/fbank.ark \
--result_wspecifier=ark,t:$exp/result.ark
echo "u2 ctc prefix beam search decode."

@ -0,0 +1,27 @@
#!/bin/bash
set -x
set -e
. path.sh
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
cmvn_json2kaldi_main \
--json_file $model_dir/mean_std.json \
--cmvn_write_path $exp/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
compute_fbank_main \
--num_bins 80 \
--wav_rspecifier=scp:$data/wav.scp \
--cmvn_file=$exp/cmvn.ark \
--feature_wspecifier=ark,t:$exp/fbank.ark
echo "compute fbank feature."

@ -0,0 +1,23 @@
#!/bin/bash
set -x
set -e
. path.sh
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
u2_nnet_main \
--model_path=$model_dir/export.jit \
--feature_rspecifier=ark,t:$exp/fbank.ark \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--acoustic_scale=1.0 \
--nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \
--nnet_prob_wspecifier=ark,t:$exp/logprobs.ark
echo "u2 nnet decode."

@ -0,0 +1,22 @@
#!/bin/bash
set -e
. path.sh
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
u2_recognizer_main \
--use_fbank=true \
--num_bins=80 \
--cmvn_file=$exp/cmvn.ark \
--model_path=$model_dir/export.jit \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--vocab_path=$model_dir/unit.txt \
--wav_rspecifier=scp:$data/wav.scp \
--result_wspecifier=ark,t:$exp/result.ark

@ -0,0 +1,18 @@
# This contains the locations of binarys build required for running the examples.
unset GREP_OPTIONS
SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }
export LC_AL=C
export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio:$SPEECHX_BUILD/recognizer
PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);")
export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH

@ -0,0 +1,43 @@
#!/bin/bash
set -x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -f data/model/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz ]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
mkdir -p data
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
./local/feat.sh
./local/nnet.sh
./local/decode.sh

@ -1,5 +1,5 @@
#!/bin/bash #!/bin/bash
set +x set -x
set -e set -e
. path.sh . path.sh
@ -11,7 +11,7 @@ stop_stage=100
. utils/parse_options.sh . utils/parse_options.sh
# 1. compile # 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then if [ ! -d ${SPEECHX_BUILD} ]; then
pushd ${SPEECHX_ROOT} pushd ${SPEECHX_ROOT}
bash build.sh bash build.sh
popd popd
@ -84,7 +84,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# recognizer # recognizer
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \
ctc_prefix_beam_search_decoder_main \ ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \
@ -103,7 +103,7 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# decode with lm # decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \
ctc_prefix_beam_search_decoder_main \ ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \
@ -135,7 +135,7 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# TLG decoder # TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \
tlg_decoder_main \ ctc_tlg_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \

@ -84,7 +84,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# recognizer # recognizer
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wolm.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wolm.log \
ctc_prefix_beam_search_decoder_main \ ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \ --model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \ --param_path=$model_dir/avg_5.jit.pdiparams \
@ -102,7 +102,7 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# decode with lm # decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.lm.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.lm.log \
ctc_prefix_beam_search_decoder_main \ ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \ --model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \ --param_path=$model_dir/avg_5.jit.pdiparams \
@ -133,7 +133,7 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# TLG decoder # TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \
tlg_decoder_main \ ctc_tlg_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \ --model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \ --param_path=$model_dir/avg_5.jit.pdiparams \

@ -0,0 +1,5 @@
# U2/U2++ Streaming ASR
## Examples
* `wenetspeech` - Streaming Decoding with wenetspeech u2/u2++ model. Using aishell test data for testing.

@ -0,0 +1,28 @@
# u2/u2pp Streaming ASR
## Testing with Aishell Test Data
## Download wav and model
```
run.sh --stop_stage 0
```
### compute feature
```
./run.sh --stage 1 --stop_stage 1
```
### decoding using feature
```
./run.sh --stage 2 --stop_stage 2
```
### decoding using wav
```
./run.sh --stage 3 --stop_stage 3
```

@ -0,0 +1,71 @@
#!/bin/bash
# To be run from one directory above this script.
. ./path.sh
nj=40
text=data/local/lm/text
lexicon=data/local/dict/lexicon.txt
for f in "$text" "$lexicon"; do
[ ! -f $x ] && echo "$0: No such file $f" && exit 1;
done
# Check SRILM tools
if ! which ngram-count > /dev/null; then
echo "srilm tools are not found, please download it and install it from: "
echo "http://www.speech.sri.com/projects/srilm/download.html"
echo "Then add the tools to your PATH"
exit 1
fi
# This script takes no arguments. It assumes you have already run
# aishell_data_prep.sh.
# It takes as input the files
# data/local/lm/text
# data/local/dict/lexicon.txt
dir=data/local/lm
mkdir -p $dir
cleantext=$dir/text.no_oov
# oov to <SPOKEN_NOISE>
# lexicon line: word char0 ... charn
# text line: utt word0 ... wordn -> line: <SPOKEN_NOISE> word0 ... wordn
text_dir=$(dirname $text)
split_name=$(basename $text)
./local/split_data.sh $text_dir $text $split_name $nj
utils/run.pl JOB=1:$nj $text_dir/split${nj}/JOB/${split_name}.no_oov.log \
cat ${text_dir}/split${nj}/JOB/${split_name} \| awk -v lex=$lexicon 'BEGIN{while((getline<lex) >0){ seen[$1]=1; } }
{for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf("<SPOKEN_NOISE> ");} } printf("\n");}' \
\> ${text_dir}/split${nj}/JOB/${split_name}.no_oov || exit 1;
cat ${text_dir}/split${nj}/*/${split_name}.no_oov > $cleantext
# compute word counts, sort in descending order
# line: count word
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort --parallel=`nproc` | uniq -c | \
sort --parallel=`nproc` -nr > $dir/word.counts || exit 1;
# Get counts from acoustic training transcripts, and add one-count
# for each word in the lexicon (but not silence, we don't want it
# in the LM-- we'll add it optionally later).
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \
cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \
sort --parallel=`nproc` | uniq -c | sort --parallel=`nproc` -nr > $dir/unigram.counts || exit 1;
# word with <s> </s>
cat $dir/unigram.counts | awk '{print $2}' | cat - <(echo "<s>"; echo "</s>" ) > $dir/wordlist
# hold out to compute ppl
heldout_sent=10000 # Don't change this if you want result to be comparable with kaldi_lm results
mkdir -p $dir
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
head -$heldout_sent > $dir/heldout
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
tail -n +$heldout_sent > $dir/train
ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
-map-unk "<UNK>" -kndiscount -interpolate -lm $dir/lm.arpa
ngram -lm $dir/lm.arpa -ppl $dir/heldout

@ -0,0 +1,25 @@
#!/bin/bash
set -e
. path.sh
data=data
exp=exp
nj=20
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/decoder.fbank.wolm.log \
ctc_prefix_beam_search_decoder_main \
--model_path=$model_dir/export.jit \
--vocab_path=$model_dir/unit.txt \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank.scp \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_decode.ark
cat $data/split${nj}/*/result_decode.ark > $exp/${label_file}
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer}
tail -n 7 $exp/${wer}

@ -0,0 +1,31 @@
#!/bin/bash
set -e
. path.sh
data=data
exp=exp
nj=20
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
aishell_wav_scp=aishell_test.scp
cmvn_json2kaldi_main \
--json_file $model_dir/mean_std.json \
--cmvn_write_path $exp/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
compute_fbank_main \
--num_bins 80 \
--cmvn_file=$exp/cmvn.ark \
--streaming_chunk=36 \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank.ark,$data/split${nj}/JOB/fbank.scp
echo "compute fbank feature."

@ -0,0 +1,23 @@
#!/bin/bash
set -x
set -e
. path.sh
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
u2_nnet_main \
--model_path=$model_dir/export.jit \
--feature_rspecifier=ark,t:$exp/fbank.ark \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--acoustic_scale=1.0 \
--nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \
--nnet_prob_wspecifier=ark,t:$exp/logprobs.ark
echo "u2 nnet decode."

@ -0,0 +1,37 @@
#!/bin/bash
set -e
. path.sh
data=data
exp=exp
nj=20
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
aishell_wav_scp=aishell_test.scp
text=$data/test/text
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \
u2_recognizer_main \
--use_fbank=true \
--num_bins=80 \
--cmvn_file=$exp/cmvn.ark \
--model_path=$model_dir/export.jit \
--vocab_path=$model_dir/unit.txt \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_recognizer.ark
cat $data/split${nj}/*/result_recognizer.ark > $exp/aishell_recognizer
utils/compute-wer.py --char=1 --v=1 $text $exp/aishell_recognizer > $exp/aishell.recognizer.err
echo "recognizer test have finished!!!"
echo "please checkout in $exp/aishell.recognizer.err"
tail -n 7 $exp/aishell.recognizer.err

@ -0,0 +1,30 @@
#!/usr/bin/env bash
set -eo pipefail
data=$1
scp=$2
split_name=$3
numsplit=$4
# save in $data/split{n}
# $scp to split
#
if [[ ! $numsplit -gt 0 ]]; then
echo "$0: Invalid num-split argument";
exit 1;
fi
directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done)
scp_splits=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/${split_name}; done)
# if this mkdir fails due to argument-list being too long, iterate.
if ! mkdir -p $directories >&/dev/null; then
for n in `seq $numsplit`; do
mkdir -p $data/split${numsplit}/$n
done
fi
echo "utils/split_scp.pl $scp $scp_splits"
utils/split_scp.pl $scp $scp_splits

@ -0,0 +1,18 @@
# This contains the locations of binarys build required for running the examples.
unset GREP_OPTIONS
SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }
export LC_AL=C
export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio:$SPEECHX_BUILD/recognizer
PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);")
export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH

@ -0,0 +1,76 @@
#!/bin/bash
set +x
set -e
. path.sh
nj=40
stage=0
stop_stage=5
. utils/parse_options.sh
# input
data=data
exp=exp
mkdir -p $exp $data
# 1. compile
if [ ! -d ${SPEECHX_BUILD} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
ckpt_dir=$data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
# download model
if [ ! -f $ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz ]; then
mkdir -p $ckpt_dir
pushd $ckpt_dir
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
popd
fi
# test wav scp
if [ ! -f data/wav.scp ]; then
mkdir -p $data
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
# aishell wav scp
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
./local/feat.sh
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
./local/decode.sh
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
./loca/recognizer.sh
fi

@ -32,6 +32,12 @@ ${CMAKE_CURRENT_SOURCE_DIR}/decoder
) )
add_subdirectory(decoder) add_subdirectory(decoder)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/recognizer
)
add_subdirectory(recognizer)
include_directories( include_directories(
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/protocol ${CMAKE_CURRENT_SOURCE_DIR}/protocol

@ -14,47 +14,47 @@
#pragma once #pragma once
#include "kaldi/base/kaldi-types.h"
#include <limits> #include <limits>
#include "kaldi/base/kaldi-types.h"
typedef float BaseFloat; typedef float BaseFloat;
typedef double double64; typedef double double64;
typedef signed char int8; typedef signed char int8;
typedef short int16; typedef short int16; // NOLINT
typedef int int32; typedef int int32; // NOLINT
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
typedef long int64; typedef long int64; // NOLINT
#else #else
typedef long long int64; typedef long long int64; // NOLINT
#endif #endif
typedef unsigned char uint8; typedef unsigned char uint8; // NOLINT
typedef unsigned short uint16; typedef unsigned short uint16; // NOLINT
typedef unsigned int uint32; typedef unsigned int uint32; // NOLINT
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
typedef unsigned long uint64; typedef unsigned long uint64; // NOLINT
#else #else
typedef unsigned long long uint64; typedef unsigned long long uint64; // NOLINT
#endif #endif
typedef signed int char32; typedef signed int char32;
const uint8 kuint8max = ((uint8)0xFF); const uint8 kuint8max = static_cast<uint8>(0xFF);
const uint16 kuint16max = ((uint16)0xFFFF); const uint16 kuint16max = static_cast<uint16>(0xFFFF);
const uint32 kuint32max = ((uint32)0xFFFFFFFF); const uint32 kuint32max = static_cast<uint32>(0xFFFFFFFF);
const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL)); const uint64 kuint64max = static_cast<uint64>(0xFFFFFFFFFFFFFFFFLL);
const int8 kint8min = ((int8)0x80); const int8 kint8min = static_cast<int8>(0x80);
const int8 kint8max = ((int8)0x7F); const int8 kint8max = static_cast<int8>(0x7F);
const int16 kint16min = ((int16)0x8000); const int16 kint16min = static_cast<int16>(0x8000);
const int16 kint16max = ((int16)0x7FFF); const int16 kint16max = static_cast<int16>(0x7FFF);
const int32 kint32min = ((int32)0x80000000); const int32 kint32min = static_cast<int32>(0x80000000);
const int32 kint32max = ((int32)0x7FFFFFFF); const int32 kint32max = static_cast<int32>(0x7FFFFFFF);
const int64 kint64min = ((int64)(0x8000000000000000LL)); const int64 kint64min = static_cast<int64>(0x8000000000000000LL);
const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL)); const int64 kint64max = static_cast<int64>(0x7FFFFFFFFFFFFFFFLL);
const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max(); const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min(); const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();

@ -14,21 +14,30 @@
#pragma once #pragma once
#include <algorithm>
#include <cassert>
#include <cmath>
#include <condition_variable> #include <condition_variable>
#include <cstring>
#include <deque> #include <deque>
#include <fstream> #include <fstream>
#include <iomanip>
#include <iostream> #include <iostream>
#include <istream> #include <istream>
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <numeric>
#include <ostream> #include <ostream>
#include <queue> #include <queue>
#include <set> #include <set>
#include <sstream> #include <sstream>
#include <stack> #include <stack>
#include <stdexcept>
#include <string> #include <string>
#include <thread> #include <thread>
#include <tuple>
#include <type_traits>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
@ -38,3 +47,5 @@
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "base/macros.h" #include "base/macros.h"
#include "utils/file_utils.h"
#include "utils/math.h"

@ -14,6 +14,9 @@
#pragma once #pragma once
#include <limits>
#include <string>
namespace ppspeech { namespace ppspeech {
#ifndef DISALLOW_COPY_AND_ASSIGN #ifndef DISALLOW_COPY_AND_ASSIGN
@ -22,4 +25,7 @@ namespace ppspeech {
void operator=(const TypeName&) = delete void operator=(const TypeName&) = delete
#endif #endif
} // namespace pp_speech // kSpaceSymbol in UTF-8 is: ▁
const char kSpaceSymbo[] = "\xe2\x96\x81";
} // namespace ppspeech

@ -35,7 +35,7 @@
class ThreadPool { class ThreadPool {
public: public:
ThreadPool(size_t); explicit ThreadPool(size_t);
template <class F, class... Args> template <class F, class... Args>
auto enqueue(F&& f, Args&&... args) auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>; -> std::future<typename std::result_of<F(Args...)>::type>;

@ -17,7 +17,7 @@
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
// Initialize Googles logging library. // Initialize Googles logging library.
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1; FLAGS_logtostderr = 1;
LOG(INFO) << "Found " << 10 << " cookies"; LOG(INFO) << "Found " << 10 << " cookies";

@ -21,6 +21,7 @@
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include <thread> #include <thread>
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h"
@ -63,8 +64,8 @@ void model_forward_test() {
; ;
std::string model_graph = FLAGS_model_path; std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path; std::string model_params = FLAGS_param_path;
CHECK(model_graph != ""); CHECK_NE(model_graph, "");
CHECK(model_params != ""); CHECK_NE(model_params, "");
cout << "model path: " << model_graph << endl; cout << "model path: " << model_graph << endl;
cout << "model param path : " << model_params << endl; cout << "model param path : " << model_params << endl;
@ -195,8 +196,11 @@ void model_forward_test() {
} }
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
model_forward_test(); model_forward_test();
return 0; return 0;

@ -1,21 +1,32 @@
project(decoder)
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders}) include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
add_library(decoder STATIC
ctc_beam_search_decoder.cc set(srcs)
if (USING_DS2)
list(APPEND srcs
ctc_decoders/decoder_utils.cpp ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp ctc_decoders/scorer.cpp
ctc_beam_search_decoder.cc
ctc_tlg_decoder.cc ctc_tlg_decoder.cc
recognizer.cc
) )
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder) endif()
if (USING_U2)
list(APPEND srcs
ctc_prefix_beam_search_decoder.cc
)
endif()
add_library(decoder STATIC ${srcs})
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings)
# test
if (USING_DS2)
set(BINS set(BINS
ctc_prefix_beam_search_decoder_main ctc_beam_search_decoder_main
nnet_logprob_decoder_main nnet_logprob_decoder_main
recognizer_main ctc_tlg_decoder_main
tlg_decoder_main
) )
foreach(bin_name IN LISTS BINS) foreach(bin_name IN LISTS BINS)
@ -23,3 +34,22 @@ foreach(bin_name IN LISTS BINS)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
endforeach() endforeach()
endif()
if (USING_U2)
set(TEST_BINS
ctc_prefix_beam_search_decoder_main
)
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
endforeach()
endif()

@ -1,3 +1,4 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
@ -12,10 +13,36 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "base/basic_types.h" #pragma once
#include "base/common.h"
struct DecoderResult { struct DecoderResult {
BaseFloat acoustic_score; BaseFloat acoustic_score;
std::vector<int32> words_idx; std::vector<int32> words_idx;
std::vector<pair<int32, int32>> time_stamp; std::vector<std::pair<int32, int32>> time_stamp;
};
namespace ppspeech {
struct WordPiece {
std::string word;
int start = -1;
int end = -1;
WordPiece(std::string word, int start, int end)
: word(std::move(word)), start(start), end(end) {}
}; };
struct DecodeResult {
float score = -kBaseFloatMax;
std::string sentence;
std::vector<WordPiece> word_pieces;
static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) {
return a.score > b.score;
}
};
} // namespace ppspeech

@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "decoder/ctc_beam_search_decoder.h" #include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h" #include "base/common.h"
#include "decoder/ctc_decoders/decoder_utils.h" #include "decoder/ctc_decoders/decoder_utils.h"
#include "utils/file_utils.h" #include "utils/file_utils.h"
@ -24,12 +25,7 @@ using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts), : opts_(opts), init_ext_scorer_(nullptr), space_id_(-1), root_(nullptr) {
init_ext_scorer_(nullptr),
blank_id_(-1),
space_id_(-1),
num_frame_decoded_(0),
root_(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file; LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
LOG(INFO) << "load the dict failed"; LOG(INFO) << "load the dict failed";
@ -43,12 +39,12 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
} }
blank_id_ = 0; CHECK_EQ(opts_.blank, 0);
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
space_id_ = it - vocabulary_.begin(); space_id_ = it - vocabulary_.begin();
// if no space in vocabulary // if no space in vocabulary
if ((size_t)space_id_ >= vocabulary_.size()) { if (static_cast<size_t>(space_id_) >= vocabulary_.size()) {
space_id_ = -2; space_id_ = -2;
} }
} }
@ -84,8 +80,6 @@ void CTCBeamSearch::Decode(
return; return;
} }
int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 1; }
// todo rename, refactor // todo rename, refactor
void CTCBeamSearch::AdvanceDecode( void CTCBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) { const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
@ -110,17 +104,21 @@ void CTCBeamSearch::ResetPrefixes() {
} }
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs, int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
vector<string>& nbest_words) { const vector<string>& nbest_words) {
kaldi::Timer timer; kaldi::Timer timer;
timer.Reset();
AdvanceDecoding(probs); AdvanceDecoding(probs);
LOG(INFO) << "ctc decoding elapsed time(s) " LOG(INFO) << "ctc decoding elapsed time(s) "
<< static_cast<float>(timer.Elapsed()) / 1000.0f; << static_cast<float>(timer.Elapsed()) / 1000.0f;
return 0; return 0;
} }
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath(int n) {
int beam_size = n == -1 ? opts_.beam_size : std::min(n, opts_.beam_size);
return get_beam_search_result(prefixes_, vocabulary_, beam_size);
}
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() { vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
return get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size); return GetNBestPath(-1);
} }
string CTCBeamSearch::GetBestPath() { string CTCBeamSearch::GetBestPath() {
@ -167,7 +165,7 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
continue; continue;
} }
min_cutoff = prefixes_[num_prefixes_ - 1]->score + min_cutoff = prefixes_[num_prefixes_ - 1]->score +
std::log(prob[blank_id_]) - std::log(prob[opts_.blank]) -
std::max(0.0, init_ext_scorer_->beta); std::max(0.0, init_ext_scorer_->beta);
full_beam = (num_prefixes_ == beam_size); full_beam = (num_prefixes_ == beam_size);
@ -195,9 +193,9 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
for (size_t i = beam_size; i < prefixes_.size(); ++i) { for (size_t i = beam_size; i < prefixes_.size(); ++i) {
prefixes_[i]->remove(); prefixes_[i]->remove();
} }
} // if } // end if
num_frame_decoded_++; num_frame_decoded_++;
} // for probs_seq } // end for probs_seq
} }
int32 CTCBeamSearch::SearchOneChar( int32 CTCBeamSearch::SearchOneChar(
@ -215,7 +213,7 @@ int32 CTCBeamSearch::SearchOneChar(
break; break;
} }
if (c == blank_id_) { if (c == opts_.blank) {
prefix->log_prob_b_cur = prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue; continue;

@ -12,67 +12,47 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "base/common.h" // used by deepspeech2
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "kaldi/decoder/decodable-itf.h"
#include "util/parse-options.h"
#pragma once #pragma once
namespace ppspeech { #include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "decoder/decoder_itf.h"
struct CTCBeamSearchOptions { namespace ppspeech {
std::string dict_file;
std::string lm_path;
BaseFloat alpha;
BaseFloat beta;
BaseFloat cutoff_prob;
int beam_size;
int cutoff_top_n;
int num_proc_bsearch;
CTCBeamSearchOptions()
: dict_file("vocab.txt"),
lm_path(""),
alpha(1.9f),
beta(5.0),
beam_size(300),
cutoff_prob(0.99f),
cutoff_top_n(40),
num_proc_bsearch(10) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("dict", &dict_file, "dict file ");
opts->Register("lm-path", &lm_path, "language model file");
opts->Register("alpha", &alpha, "alpha");
opts->Register("beta", &beta, "beta");
opts->Register(
"beam-size", &beam_size, "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
opts->Register(
"num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
}
};
class CTCBeamSearch { class CTCBeamSearch : public DecoderBase {
public: public:
explicit CTCBeamSearch(const CTCBeamSearchOptions& opts); explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
~CTCBeamSearch() {} ~CTCBeamSearch() {}
void InitDecoder(); void InitDecoder();
void Reset();
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable); void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath(); std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath(); std::vector<std::pair<double, std::string>> GetNBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath(int n);
std::string GetFinalBestPath(); std::string GetFinalBestPath();
int NumFrameDecoded();
std::string GetPartialResult() {
CHECK(false) << "Not implement.";
return {};
}
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs, int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words); const std::vector<std::string>& nbest_words);
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();
private: private:
void ResetPrefixes(); void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam, int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx, const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff); const BaseFloat& min_cutoff);
@ -83,12 +63,11 @@ class CTCBeamSearch {
CTCBeamSearchOptions opts_; CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
std::vector<std::string> vocabulary_; // todo remove later std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id_;
int space_id_; int space_id_;
std::shared_ptr<PathTrie> root_; std::shared_ptr<PathTrie> root_;
std::vector<PathTrie*> prefixes_; std::vector<PathTrie*> prefixes_;
int num_frame_decoded_;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
}; };
} // namespace basr } // namespace ppspeech

@ -12,29 +12,26 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// todo refactor, repalce with gtest // used by deepspeech2
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "decoder/ctc_tlg_decoder.h" #include "decoder/ctc_beam_search_decoder.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/paddle_nnet.h" #include "nnet/ds2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(graph_path, "TLG", "decoder graph"); DEFINE_string(lm_path, "", "language model");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "decoder graph");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
DEFINE_int32(receptive_field_length, DEFINE_int32(receptive_field_length,
7, 7,
"receptive field of two CNN(kernel=3) downsampling module."); "receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate, DEFINE_int32(subsampling_rate,
4, 4,
"two CNN(kernel=3) module downsampling rate."); "two CNN(kernel=3) module downsampling rate.");
DEFINE_string( DEFINE_string(
@ -48,59 +45,59 @@ DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box", "chunk_state_h_box,chunk_state_c_box",
"model cache names"); "model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector; using std::vector;
// test TLG decoder by feeding speech feature. // test ds2 online decoder by feeding speech feature
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
CHECK_NE(FLAGS_result_wspecifier, "");
CHECK_NE(FLAGS_feature_rspecifier, "");
kaldi::SequentialBaseFloatMatrixReader feature_reader( kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier); FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path; std::string model_path = FLAGS_model_path;
std::string model_params = FLAGS_param_path; std::string model_params = FLAGS_param_path;
std::string word_symbol_table = FLAGS_word_symbol_table; std::string dict_file = FLAGS_dict_file;
std::string graph_path = FLAGS_graph_path; std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "model path: " << model_graph; LOG(INFO) << "model path: " << model_path;
LOG(INFO) << "model param: " << model_params; LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "word symbol path: " << word_symbol_table; LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "graph path: " << graph_path; LOG(INFO) << "lm path: " << lm_path;
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
ppspeech::TLGDecoderOptions opts; ppspeech::CTCBeamSearchOptions opts;
opts.word_symbol_table = word_symbol_table; opts.dict_file = dict_file;
opts.fst_path = graph_path; opts.lm_path = lm_path;
opts.opts.max_active = FLAGS_max_active; ppspeech::CTCBeamSearch decoder(opts);
opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5; ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.param_path = model_params;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache()); std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable( std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); new ppspeech::Decodable(nnet, raw_data));
int32 chunk_size = FLAGS_receptive_field_length + int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length; int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride; LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length; LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder(); decoder.InitDecoder();
kaldi::Timer timer; kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) { for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key(); string utt = feature_reader.Key();
@ -132,6 +129,7 @@ int main(int argc, char* argv[]) {
if (feature_chunk_size < receptive_field_length) break; if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride; int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < chunk_size; ++row_id) { for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start); kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp( kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
@ -161,10 +159,9 @@ int main(int argc, char* argv[]) {
++num_done; ++num_done;
} }
double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s";
KALDI_LOG << "Done " << num_done << " utterances, " << num_err KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors."; << " with errors.";
double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s";
return (num_done != 0 ? 0 : 1); return (num_done != 0 ? 0 : 1);
} }

@ -0,0 +1,78 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "util/parse-options.h"
namespace ppspeech {
struct CTCBeamSearchOptions {
// common
int blank;
// ds2
std::string dict_file;
std::string lm_path;
int beam_size;
BaseFloat alpha;
BaseFloat beta;
BaseFloat cutoff_prob;
int cutoff_top_n;
int num_proc_bsearch;
// u2
int first_beam_size;
int second_beam_size;
CTCBeamSearchOptions()
: blank(0),
dict_file("vocab.txt"),
lm_path(""),
beam_size(300),
alpha(1.9f),
beta(5.0),
cutoff_prob(0.99f),
cutoff_top_n(40),
num_proc_bsearch(10),
first_beam_size(10),
second_beam_size(10) {}
void Register(kaldi::OptionsItf* opts) {
std::string module = "Ds2BeamSearchConfig: ";
opts->Register("dict", &dict_file, module + "vocab file path.");
opts->Register(
"lm-path", &lm_path, module + "ngram language model path.");
opts->Register("alpha", &alpha, module + "alpha");
opts->Register("beta", &beta, module + "beta");
opts->Register("beam-size",
&beam_size,
module + "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff top n");
opts->Register(
"num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch");
opts->Register("blank", &blank, "blank id, default is 0.");
module = "U2BeamSearchConfig: ";
opts->Register(
"first-beam-size", &first_beam_size, module + "first beam size.");
opts->Register("second-beam-size",
&second_beam_size,
module + "second beam size.");
}
};
} // namespace ppspeech

@ -0,0 +1,370 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "absl/strings/str_join.h"
#include "base/common.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "utils/math.h"
#ifdef USE_PROFILING
#include "paddle/fluid/platform/profiler.h"
using paddle::platform::RecordEvent;
using paddle::platform::TracerEventType;
#endif
namespace ppspeech {
CTCPrefixBeamSearch::CTCPrefixBeamSearch(const std::string& vocab_path,
const CTCBeamSearchOptions& opts)
: opts_(opts) {
unit_table_ = std::shared_ptr<fst::SymbolTable>(
fst::SymbolTable::ReadText(vocab_path));
CHECK(unit_table_ != nullptr);
Reset();
}
void CTCPrefixBeamSearch::Reset() {
num_frame_decoded_ = 0;
cur_hyps_.clear();
hypotheses_.clear();
likelihood_.clear();
viterbi_likelihood_.clear();
times_.clear();
outputs_.clear();
// empty hyp with Score
std::vector<int> empty;
PrefixScore prefix_score;
prefix_score.InitEmpty();
cur_hyps_[empty] = prefix_score;
outputs_.emplace_back(empty);
hypotheses_.emplace_back(empty);
likelihood_.emplace_back(prefix_score.TotalScore());
times_.emplace_back(empty);
}
void CTCPrefixBeamSearch::InitDecoder() { Reset(); }
void CTCPrefixBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (1) {
// forward frame by frame
std::vector<kaldi::BaseFloat> frame_prob;
bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
if (flag == false) {
VLOG(1) << "decoder advance decode exit." << frame_prob.size();
break;
}
std::vector<std::vector<kaldi::BaseFloat>> likelihood;
likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood);
VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_;
}
}
static bool PrefixScoreCompare(
const std::pair<std::vector<int>, PrefixScore>& a,
const std::pair<std::vector<int>, PrefixScore>& b) {
// log domain
return a.second.TotalScore() > b.second.TotalScore();
}
void CTCPrefixBeamSearch::AdvanceDecoding(
const std::vector<std::vector<kaldi::BaseFloat>>& logp) {
#ifdef USE_PROFILING
RecordEvent event("CtcPrefixBeamSearch::AdvanceDecoding",
TracerEventType::UserDefined,
1);
#endif
if (logp.size() == 0) return;
int first_beam_size =
std::min(static_cast<int>(logp[0].size()), opts_.first_beam_size);
for (int t = 0; t < logp.size(); ++t, ++num_frame_decoded_) {
const std::vector<kaldi::BaseFloat>& logp_t = logp[t];
std::unordered_map<std::vector<int>, PrefixScore, PrefixScoreHash>
next_hyps;
// 1. first beam prune, only select topk candidates
std::vector<kaldi::BaseFloat> topk_score;
std::vector<int32_t> topk_index;
TopK(logp_t, first_beam_size, &topk_score, &topk_index);
VLOG(2) << "topk: " << num_frame_decoded_ << " "
<< *std::max_element(logp_t.begin(), logp_t.end()) << " "
<< topk_score[0];
for (int i = 0; i < topk_score.size(); i++) {
VLOG(2) << "topk: " << num_frame_decoded_ << " " << topk_score[i];
}
// 2. token passing
for (int i = 0; i < topk_index.size(); ++i) {
int id = topk_index[i];
auto prob = topk_score[i];
for (const auto& it : cur_hyps_) {
const std::vector<int>& prefix = it.first;
const PrefixScore& prefix_score = it.second;
// If prefix doesn't exist in next_hyps, next_hyps[prefix] will
// insert
// PrefixScore(-inf, -inf) by default, since the default
// constructor
// of PrefixScore will set fields b(blank ending Score) and
// nb(none blank ending Score) to -inf, respectively.
if (id == opts_.blank) {
// case 0: *a + <blank> => *a, *a<blank> + <blank> => *a,
// prefix not
// change
PrefixScore& next_score = next_hyps[prefix];
next_score.b =
LogSumExp(next_score.b, prefix_score.Score() + prob);
// timestamp, blank is slince, not effact timestamp
next_score.v_b = prefix_score.ViterbiScore() + prob;
next_score.times_b = prefix_score.Times();
// Prefix not changed, copy the context from pefix
if (context_graph_ && !next_score.has_context) {
next_score.CopyContext(prefix_score);
next_score.has_context = true;
}
} else if (!prefix.empty() && id == prefix.back()) {
// case 1: *a + a => *a, prefix not changed
PrefixScore& next_score1 = next_hyps[prefix];
next_score1.nb =
LogSumExp(next_score1.nb, prefix_score.nb + prob);
// timestamp, non-blank symbol effact timestamp
if (next_score1.v_nb < prefix_score.v_nb + prob) {
// compute viterbi Score
next_score1.v_nb = prefix_score.v_nb + prob;
if (next_score1.cur_token_prob < prob) {
// store max token prob
next_score1.cur_token_prob = prob;
// update this timestamp as token appeared here.
next_score1.times_nb = prefix_score.times_nb;
assert(next_score1.times_nb.size() > 0);
next_score1.times_nb.back() = num_frame_decoded_;
}
}
// Prefix not changed, copy the context from pefix
if (context_graph_ && !next_score1.has_context) {
next_score1.CopyContext(prefix_score);
next_score1.has_context = true;
}
// case 2: *a<blank> + a => *aa, prefix changed.
std::vector<int> new_prefix(prefix);
new_prefix.emplace_back(id);
PrefixScore& next_score2 = next_hyps[new_prefix];
next_score2.nb =
LogSumExp(next_score2.nb, prefix_score.b + prob);
// timestamp, non-blank symbol effact timestamp
if (next_score2.v_nb < prefix_score.v_b + prob) {
// compute viterbi Score
next_score2.v_nb = prefix_score.v_b + prob;
// new token added
next_score2.cur_token_prob = prob;
next_score2.times_nb = prefix_score.times_b;
next_score2.times_nb.emplace_back(num_frame_decoded_);
}
// Prefix changed, calculate the context Score.
if (context_graph_ && !next_score2.has_context) {
next_score2.UpdateContext(
context_graph_, prefix_score, id, prefix.size());
next_score2.has_context = true;
}
} else {
// id != prefix.back()
// case 3: *a + b => *ab, *a<blank> +b => *ab
std::vector<int> new_prefix(prefix);
new_prefix.emplace_back(id);
PrefixScore& next_score = next_hyps[new_prefix];
next_score.nb =
LogSumExp(next_score.nb, prefix_score.Score() + prob);
// timetamp, non-blank symbol effact timestamp
if (next_score.v_nb < prefix_score.ViterbiScore() + prob) {
next_score.v_nb = prefix_score.ViterbiScore() + prob;
next_score.cur_token_prob = prob;
next_score.times_nb = prefix_score.Times();
next_score.times_nb.emplace_back(num_frame_decoded_);
}
// Prefix changed, calculate the context Score.
if (context_graph_ && !next_score.has_context) {
next_score.UpdateContext(
context_graph_, prefix_score, id, prefix.size());
next_score.has_context = true;
}
}
} // end for (const auto& it : cur_hyps_)
} // end for (int i = 0; i < topk_index.size(); ++i)
// 3. second beam prune, only keep top n best paths
std::vector<std::pair<std::vector<int>, PrefixScore>> arr(
next_hyps.begin(), next_hyps.end());
int second_beam_size =
std::min(static_cast<int>(arr.size()), opts_.second_beam_size);
std::nth_element(arr.begin(),
arr.begin() + second_beam_size,
arr.end(),
PrefixScoreCompare);
arr.resize(second_beam_size);
std::sort(arr.begin(), arr.end(), PrefixScoreCompare);
// 4. update cur_hyps by next_hyps, and get new result
UpdateHypotheses(arr);
} // end for (int t = 0; t < logp.size(); ++t, ++num_frame_decoded_)
}
void CTCPrefixBeamSearch::UpdateHypotheses(
const std::vector<std::pair<std::vector<int>, PrefixScore>>& hyps) {
cur_hyps_.clear();
outputs_.clear();
hypotheses_.clear();
likelihood_.clear();
viterbi_likelihood_.clear();
times_.clear();
for (auto& item : hyps) {
cur_hyps_[item.first] = item.second;
UpdateOutputs(item);
hypotheses_.emplace_back(std::move(item.first));
likelihood_.emplace_back(item.second.TotalScore());
viterbi_likelihood_.emplace_back(item.second.ViterbiScore());
times_.emplace_back(item.second.Times());
}
}
void CTCPrefixBeamSearch::UpdateOutputs(
const std::pair<std::vector<int>, PrefixScore>& prefix) {
const std::vector<int>& input = prefix.first;
const std::vector<int>& start_boundaries = prefix.second.start_boundaries;
const std::vector<int>& end_boundaries = prefix.second.end_boundaries;
// add <context> </context> tag
std::vector<int> output;
int s = 0;
int e = 0;
for (int i = 0; i < input.size(); ++i) {
output.emplace_back(input[i]);
}
outputs_.emplace_back(output);
}
void CTCPrefixBeamSearch::FinalizeSearch() {
UpdateFinalContext();
VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_;
int cnt = 0;
for (int i = 0; i < hypotheses_.size(); i++) {
VLOG(2) << "hyp " << cnt << " len: " << hypotheses_[i].size()
<< " ctc score: " << likelihood_[i];
for (int j = 0; j < hypotheses_[i].size(); j++) {
VLOG(2) << hypotheses_[i][j];
}
}
}
void CTCPrefixBeamSearch::UpdateFinalContext() {
if (context_graph_ == nullptr) return;
CHECK(hypotheses_.size() == cur_hyps_.size());
CHECK(hypotheses_.size() == likelihood_.size());
// We should backoff the context Score/state when the context is
// not fully matched at the last time.
for (const auto& prefix : hypotheses_) {
PrefixScore& prefix_score = cur_hyps_[prefix];
if (prefix_score.context_score != 0) {
prefix_score.UpdateContext(
context_graph_, prefix_score, 0, prefix.size());
}
}
std::vector<std::pair<std::vector<int>, PrefixScore>> arr(cur_hyps_.begin(),
cur_hyps_.end());
std::sort(arr.begin(), arr.end(), PrefixScoreCompare);
// Update cur_hyps_ and get new result
UpdateHypotheses(arr);
}
std::string CTCPrefixBeamSearch::GetBestPath(int index) {
int n_hyps = Outputs().size();
CHECK_GT(n_hyps, 0);
CHECK_LT(index, n_hyps);
std::vector<int> one = Outputs()[index];
std::string sentence;
for (int i = 0; i < one.size(); i++) {
sentence += unit_table_->Find(one[i]);
}
return sentence;
}
std::string CTCPrefixBeamSearch::GetBestPath() { return GetBestPath(0); }
std::vector<std::pair<double, std::string>> CTCPrefixBeamSearch::GetNBestPath(
int n) {
int hyps_size = hypotheses_.size();
CHECK_GT(hyps_size, 0);
int min_n = n == -1 ? hypotheses_.size() : std::min(n, hyps_size);
std::vector<std::pair<double, std::string>> n_best;
n_best.reserve(min_n);
for (int i = 0; i < min_n; i++) {
n_best.emplace_back(Likelihood()[i], GetBestPath(i));
}
return n_best;
}
std::vector<std::pair<double, std::string>>
CTCPrefixBeamSearch::GetNBestPath() {
return GetNBestPath(-1);
}
std::string CTCPrefixBeamSearch::GetFinalBestPath() { return GetBestPath(); }
std::string CTCPrefixBeamSearch::GetPartialResult() { return GetBestPath(); }
} // namespace ppspeech

@ -0,0 +1,101 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/ctc_prefix_beam_search.cc
#pragma once
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/decoder_itf.h"
#include "fst/symbol-table.h"
namespace ppspeech {
class ContextGraph;
class CTCPrefixBeamSearch : public DecoderBase {
public:
CTCPrefixBeamSearch(const std::string& vocab_path,
const CTCBeamSearchOptions& opts);
~CTCPrefixBeamSearch() {}
SearchType Type() const { return SearchType::kPrefixBeamSearch; }
void InitDecoder() override;
void Reset() override;
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) override;
std::string GetFinalBestPath() override;
std::string GetPartialResult() override;
void FinalizeSearch();
const std::shared_ptr<fst::SymbolTable> VocabTable() const {
return unit_table_;
}
const std::vector<std::vector<int>>& Inputs() const { return hypotheses_; }
const std::vector<std::vector<int>>& Outputs() const { return outputs_; }
const std::vector<float>& Likelihood() const { return likelihood_; }
const std::vector<float>& ViterbiLikelihood() const {
return viterbi_likelihood_;
}
const std::vector<std::vector<int>>& Times() const { return times_; }
protected:
std::string GetBestPath() override;
std::vector<std::pair<double, std::string>> GetNBestPath() override;
std::vector<std::pair<double, std::string>> GetNBestPath(int n) override;
private:
std::string GetBestPath(int index);
void AdvanceDecoding(
const std::vector<std::vector<kaldi::BaseFloat>>& logp);
void UpdateOutputs(const std::pair<std::vector<int>, PrefixScore>& prefix);
void UpdateHypotheses(
const std::vector<std::pair<std::vector<int>, PrefixScore>>& prefix);
void UpdateFinalContext();
private:
CTCBeamSearchOptions opts_;
std::shared_ptr<fst::SymbolTable> unit_table_{nullptr};
std::unordered_map<std::vector<int>, PrefixScore, PrefixScoreHash>
cur_hyps_;
// n-best list and corresponding likelihood, in sorted order
std::vector<std::vector<int>> hypotheses_;
std::vector<float> likelihood_;
std::vector<std::vector<int>> times_;
std::vector<float> viterbi_likelihood_;
// Outputs contain the hypotheses_ and tags lik: <context> and </context>
std::vector<std::vector<int>> outputs_;
std::shared_ptr<ContextGraph> context_graph_{nullptr};
DISALLOW_COPY_AND_ASSIGN(CTCPrefixBeamSearch);
};
} // namespace ppspeech

@ -12,40 +12,29 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// todo refactor, repalce with gtest #include "absl/strings/str_split.h"
#include "base/common.h"
#include "base/flags.h" #include "decoder/ctc_prefix_beam_search_decoder.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "fst/symbol-table.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/paddle_nnet.h" #include "nnet/u2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(vocab_path, "", "vocab path");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); DEFINE_string(model_path, "", "paddle nnet model");
DEFINE_string(lm_path, "", "language model");
DEFINE_int32(receptive_field_length, DEFINE_int32(receptive_field_length,
7, 7,
"receptive field of two CNN(kernel=3) downsampling module."); "receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate, DEFINE_int32(subsampling_rate,
4, 4,
"two CNN(kernel=3) module downsampling rate."); "two CNN(kernel=3) module downsampling rate.");
DEFINE_string(
model_input_names, DEFINE_int32(nnet_decoder_chunk, 16, "paddle nnet forward chunk");
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box",
"model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
@ -53,117 +42,138 @@ using std::vector;
// test ds2 online decoder by feeding speech feature // test ds2 online decoder by feeding speech feature
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
CHECK(FLAGS_result_wspecifier != ""); int32 num_done = 0, num_err = 0;
CHECK(FLAGS_feature_rspecifier != "");
CHECK_NE(FLAGS_result_wspecifier, "");
CHECK_NE(FLAGS_feature_rspecifier, "");
CHECK_NE(FLAGS_vocab_path, "");
CHECK_NE(FLAGS_model_path, "");
LOG(INFO) << "model path: " << FLAGS_model_path;
LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path;
kaldi::SequentialBaseFloatMatrixReader feature_reader( kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier); FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_path = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "model path: " << model_path;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
int32 num_done = 0, num_err = 0; // nnet
ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
std::shared_ptr<ppspeech::U2Nnet> nnet =
std::make_shared<ppspeech::U2Nnet>(model_opts);
// decodeable
std::shared_ptr<ppspeech::DataCache> raw_data =
std::make_shared<ppspeech::DataCache>();
std::shared_ptr<ppspeech::Decodable> decodable =
std::make_shared<ppspeech::Decodable>(nnet, raw_data);
// decoder
ppspeech::CTCBeamSearchOptions opts; ppspeech::CTCBeamSearchOptions opts;
opts.dict_file = dict_file; opts.blank = 0;
opts.lm_path = lm_path; opts.first_beam_size = 10;
ppspeech::CTCBeamSearch decoder(opts); opts.second_beam_size = 10;
ppspeech::CTCPrefixBeamSearch decoder(FLAGS_vocab_path, opts);
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_path;
model_opts.param_path = model_params;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
int32 chunk_size = FLAGS_receptive_field_length + int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length; int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride; LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length; LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder(); decoder.InitDecoder();
kaldi::Timer timer; kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) { for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key(); string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value(); kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << feature.NumRows();
LOG(INFO) << "cols: " << feature.NumCols();
int32 row_idx = 0; int nframes = feature.NumRows();
int32 padding_len = 0; int feat_dim = feature.NumCols();
raw_data->SetDim(feat_dim);
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "feat shape: " << nframes << ", " << feat_dim;
raw_data->SetDim(feat_dim);
int32 ori_feature_len = feature.NumRows(); int32 ori_feature_len = feature.NumRows();
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { int32 num_chunks = feature.NumRows() / chunk_stride + 1;
padding_len = LOG(INFO) << "num_chunks: " << num_chunks;
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len,
feature.NumCols(),
kaldi::kCopyData);
}
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size * int32 this_chunk_size = 0;
feature.NumCols());
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) { if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min( this_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size); ori_feature_len - chunk_idx * chunk_stride, chunk_size);
} }
if (feature_chunk_size < receptive_field_length) break; if (this_chunk_size < receptive_field_length) {
LOG(WARNING)
<< "utt: " << utt << " skip last " << this_chunk_size
<< " frames, expect is " << receptive_field_length;
break;
}
kaldi::Vector<kaldi::BaseFloat> feature_chunk(this_chunk_size *
feat_dim);
int32 start = chunk_idx * chunk_stride; int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < this_chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> feat_row(feature, start);
kaldi::SubVector<kaldi::BaseFloat> feature_chunk_row(
feature_chunk.Data() + row_id * feat_dim, feat_dim);
for (int row_id = 0; row_id < chunk_size; ++row_id) { feature_chunk_row.CopyFromVec(feat_row);
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
++start; ++start;
} }
// feat to frontend pipeline cache
raw_data->Accept(feature_chunk); raw_data->Accept(feature_chunk);
// send data finish signal
if (chunk_idx == num_chunks - 1) { if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished(); raw_data->SetFinished();
} }
// forward nnet
decoder.AdvanceDecode(decodable); decoder.AdvanceDecode(decodable);
LOG(INFO) << "Partial result: " << decoder.GetPartialResult();
} }
std::string result;
result = decoder.GetFinalBestPath(); decoder.FinalizeSearch();
// get 1-best result
std::string result = decoder.GetFinalBestPath();
// after process one utt, then reset state.
decodable->Reset(); decodable->Reset();
decoder.Reset(); decoder.Reset();
if (result.empty()) { if (result.empty()) {
// the TokenWriter can not write empty string. // the TokenWriter can not write empty string.
++num_err; ++num_err;
KALDI_LOG << " the result of " << utt << " is empty"; LOG(INFO) << " the result of " << utt << " is empty";
continue; continue;
} }
KALDI_LOG << " the result of " << utt << " is " << result;
LOG(INFO) << " the result of " << utt << " is " << result;
result_writer.Write(utt, result); result_writer.Write(utt, result);
++num_done; ++num_done;
} }
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
double elapsed = timer.Elapsed(); double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s"; LOG(INFO) << "Program cost:" << elapsed << " sec";
LOG(INFO) << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1); return (num_done != 0 ? 0 : 1);
} }

@ -0,0 +1,98 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/ctc_prefix_beam_search.h
#pragma once
#include "base/common.h"
#include "utils/math.h"
namespace ppspeech {
class ContextGraph;
struct PrefixScore {
// decoding, unit in log scale
float b = -kBaseFloatMax; // blank ending score
float nb = -kBaseFloatMax; // none-blank ending score
// decoding score, sum
float Score() const { return LogSumExp(b, nb); }
// timestamp, unit in log sclae
float v_b = -kBaseFloatMax; // viterbi blank ending score
float v_nb = -kBaseFloatMax; // niterbi none-blank ending score
float cur_token_prob = -kBaseFloatMax; // prob of current token
std::vector<int> times_b; // times of viterbi blank path
std::vector<int> times_nb; // times of viterbi non-blank path
// timestamp score, max
float ViterbiScore() const { return std::max(v_b, v_nb); }
// get timestamp
const std::vector<int>& Times() const {
return v_b > v_nb ? times_b : times_nb;
}
// context state
bool has_context = false;
int context_state = 0;
float context_score = 0;
std::vector<int> start_boundaries;
std::vector<int> end_boundaries;
// decodign score with context bias
float TotalScore() const { return Score() + context_score; }
void CopyContext(const PrefixScore& prefix_score) {
context_state = prefix_score.context_state;
context_score = prefix_score.context_score;
start_boundaries = prefix_score.start_boundaries;
end_boundaries = prefix_score.end_boundaries;
}
void UpdateContext(const std::shared_ptr<ContextGraph>& constext_graph,
const PrefixScore& prefix_score,
int word_id,
int prefix_len) {
CHECK(false);
}
void InitEmpty() {
b = 0.0f; // log(1)
nb = -kBaseFloatMax; // log(0)
v_b = 0.0f; // log(1)
v_nb = 0.0f; // log(1)
}
};
struct PrefixScoreHash {
// https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector
std::size_t operator()(const std::vector<int>& prefix) const {
std::size_t seed = prefix.size();
for (auto& i : prefix) {
seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};
using PrefixWithScoreType = std::pair<std::vector<int>, PrefixScoreHash>;
} // namespace ppspeech

@ -18,37 +18,38 @@ namespace ppspeech {
TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path)); fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
CHECK(fst_ != nullptr); CHECK(fst_ != nullptr);
word_symbol_table_.reset( word_symbol_table_.reset(
fst::SymbolTable::ReadText(opts.word_symbol_table)); fst::SymbolTable::ReadText(opts.word_symbol_table));
decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts)); decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts));
decoder_->InitDecoding();
frame_decoded_size_ = 0; Reset();
} }
void TLGDecoder::InitDecoder() { void TLGDecoder::Reset() {
decoder_->InitDecoding(); decoder_->InitDecoding();
frame_decoded_size_ = 0; num_frame_decoded_ = 0;
return;
} }
void TLGDecoder::InitDecoder() { Reset(); }
void TLGDecoder::AdvanceDecode( void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) { const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (!decodable->IsLastFrame(frame_decoded_size_)) { while (!decodable->IsLastFrame(num_frame_decoded_)) {
AdvanceDecoding(decodable.get()); AdvanceDecoding(decodable.get());
} }
} }
void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) { void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
decoder_->AdvanceDecoding(decodable, 1); decoder_->AdvanceDecoding(decodable, 1);
frame_decoded_size_++; num_frame_decoded_++;
} }
void TLGDecoder::Reset() {
InitDecoder();
return;
}
std::string TLGDecoder::GetPartialResult() { std::string TLGDecoder::GetPartialResult() {
if (frame_decoded_size_ == 0) { if (num_frame_decoded_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.") // BestPathEnd if no frames were decoded.")
return std::string(""); return std::string("");
@ -68,7 +69,7 @@ std::string TLGDecoder::GetPartialResult() {
} }
std::string TLGDecoder::GetFinalBestPath() { std::string TLGDecoder::GetFinalBestPath() {
if (frame_decoded_size_ == 0) { if (num_frame_decoded_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.") // BestPathEnd if no frames were decoded.")
return std::string(""); return std::string("");
@ -88,4 +89,5 @@ std::string TLGDecoder::GetFinalBestPath() {
} }
return words; return words;
} }
}
} // namespace ppspeech

@ -14,37 +14,78 @@
#pragma once #pragma once
#include "base/basic_types.h" #include "base/common.h"
#include "kaldi/decoder/decodable-itf.h" #include "decoder/decoder_itf.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h" #include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h" #include "util/parse-options.h"
DECLARE_string(graph_path);
DECLARE_string(word_symbol_table);
DECLARE_int32(max_active);
DECLARE_double(beam);
DECLARE_double(lattice_beam);
namespace ppspeech { namespace ppspeech {
struct TLGDecoderOptions { struct TLGDecoderOptions {
kaldi::LatticeFasterDecoderConfig opts; kaldi::LatticeFasterDecoderConfig opts{};
// todo remove later, add into decode resource // todo remove later, add into decode resource
std::string word_symbol_table; std::string word_symbol_table;
std::string fst_path; std::string fst_path;
TLGDecoderOptions() : word_symbol_table(""), fst_path("") {} static TLGDecoderOptions InitFromFlags() {
TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
LOG(INFO) << "fst path: " << decoder_opts.fst_path;
LOG(INFO) << "fst symbole table: " << decoder_opts.word_symbol_table;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
LOG(INFO) << "LatticeFasterDecoder max active: "
<< decoder_opts.opts.max_active;
LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam;
LOG(INFO) << "LatticeFasterDecoder lattice_beam: "
<< decoder_opts.opts.lattice_beam;
return decoder_opts;
}
}; };
class TLGDecoder { class TLGDecoder : public DecoderBase {
public: public:
explicit TLGDecoder(TLGDecoderOptions opts); explicit TLGDecoder(TLGDecoderOptions opts);
~TLGDecoder() = default;
void InitDecoder(); void InitDecoder();
void Decode(); void Reset();
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
std::string GetPartialResult();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode( void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable); const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();
void Decode();
std::string GetFinalBestPath() override;
std::string GetPartialResult() override;
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
const std::vector<std::string>& nbest_words);
protected:
std::string GetBestPath() override {
CHECK(false);
return {};
}
std::vector<std::pair<double, std::string>> GetNBestPath() override {
CHECK(false);
return {};
}
std::vector<std::pair<double, std::string>> GetNBestPath(int n) override {
CHECK(false);
return {};
}
private: private:
void AdvanceDecoding(kaldi::DecodableInterface* decodable); void AdvanceDecoding(kaldi::DecodableInterface* decodable);
@ -52,8 +93,6 @@ class TLGDecoder {
std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_; std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_; std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_; std::shared_ptr<fst::SymbolTable> word_symbol_table_;
// the frame size which have decoded starts from 0.
int32 frame_decoded_size_;
}; };

@ -0,0 +1,137 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/common.h"
#include "decoder/ctc_tlg_decoder.h"
#include "decoder/param.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/ds2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test TLG decoder by feeding speech feature.
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int32 num_done = 0, num_err = 0;
ppspeech::TLGDecoderOptions opts =
ppspeech::TLGDecoderOptions::InitFromFlags();
opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5;
ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder();
kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << feature.NumRows();
LOG(INFO) << "cols: " << feature.NumCols();
int32 row_idx = 0;
int32 padding_len = 0;
int32 ori_feature_len = feature.NumRows();
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
padding_len =
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len,
feature.NumCols(),
kaldi::kCopyData);
}
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
++start;
}
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
decoder.AdvanceDecode(decodable);
}
std::string result;
result = decoder.GetFinalBestPath();
decodable->Reset();
decoder.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}
double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s";
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -0,0 +1,66 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "kaldi/decoder/decodable-itf.h"
namespace ppspeech {
enum SearchType {
kPrefixBeamSearch = 0,
kWfstBeamSearch = 1,
};
class DecoderInterface {
public:
virtual ~DecoderInterface() {}
virtual void InitDecoder() = 0;
virtual void Reset() = 0;
// call AdvanceDecoding
virtual void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) = 0;
// call GetBestPath
virtual std::string GetFinalBestPath() = 0;
virtual std::string GetPartialResult() = 0;
protected:
// virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0;
// virtual void Decode() = 0;
virtual std::string GetBestPath() = 0;
virtual std::vector<std::pair<double, std::string>> GetNBestPath() = 0;
virtual std::vector<std::pair<double, std::string>> GetNBestPath(int n) = 0;
};
class DecoderBase : public DecoderInterface {
protected:
// start from one
int NumFrameDecoded() { return num_frame_decoded_ + 1; }
protected:
// current decoding frame number, abs_time_step_
int32 num_frame_decoded_;
};
} // namespace ppspeech

@ -30,8 +30,11 @@ using std::vector;
// test decoder by feeding nnet posterior probability // test decoder by feeding nnet posterior probability
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialBaseFloatMatrixReader likelihood_reader( kaldi::SequentialBaseFloatMatrixReader likelihood_reader(
FLAGS_nnet_prob_respecifier); FLAGS_nnet_prob_respecifier);

@ -17,23 +17,29 @@
#include "base/common.h" #include "base/common.h"
#include "decoder/ctc_beam_search_decoder.h" #include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h" #include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
// feature // feature
DEFINE_bool(use_fbank, false, "False for fbank; or linear feature"); DEFINE_bool(use_fbank, false, "False for fbank; or linear feature");
DEFINE_bool(fill_zero,
false,
"fill zero at last chunk, when chunk < chunk_size");
// DEFINE_bool(to_float32, true, "audio convert to pcm32. True for linear // DEFINE_bool(to_float32, true, "audio convert to pcm32. True for linear
// feature, or fbank"); // feature, or fbank");
DEFINE_int32(num_bins, 161, "num bins of mel"); DEFINE_int32(num_bins, 161, "num bins of mel");
DEFINE_string(cmvn_file, "", "read cmvn"); DEFINE_string(cmvn_file, "", "read cmvn");
// feature sliding window // feature sliding window
DEFINE_int32(receptive_field_length, DEFINE_int32(receptive_field_length,
7, 7,
"receptive field of two CNN(kernel=3) downsampling module."); "receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate, DEFINE_int32(subsampling_rate,
4, 4,
"two CNN(kernel=3) module downsampling rate."); "two CNN(kernel=3) module downsampling rate.");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
// nnet // nnet
DEFINE_string(vocab_path, "", "nnet vocab path.");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string( DEFINE_string(
@ -48,71 +54,30 @@ DEFINE_string(model_cache_names,
"model cache names"); "model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
// decoder // decoder
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_int32(max_active, 7500, "max active"); DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam"); DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam"); DEFINE_double(lattice_beam, 7.5, "decoder beam");
namespace ppspeech {
// todo refactor later
FeaturePipelineOptions InitFeaturePipelineOptions() {
FeaturePipelineOptions opts;
opts.cmvn_file = FLAGS_cmvn_file;
kaldi::FrameExtractionOptions frame_opts;
frame_opts.dither = 0.0;
frame_opts.frame_shift_ms = 10;
opts.use_fbank = FLAGS_use_fbank;
if (opts.use_fbank) {
opts.to_float32 = false;
frame_opts.window_type = "povey";
frame_opts.frame_length_ms = 25;
opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
opts.fbank_opts.frame_opts = frame_opts;
} else {
opts.to_float32 = true;
frame_opts.remove_dc_offset = false;
frame_opts.frame_length_ms = 20;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
}
opts.assembler_opts.subsampling_rate = FLAGS_downsampling_rate;
opts.assembler_opts.receptive_filed_length = FLAGS_receptive_field_length;
opts.assembler_opts.nnet_decoder_chunk = FLAGS_nnet_decoder_chunk;
return opts;
}
ModelOptions InitModelOptions() {
ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
model_opts.param_path = FLAGS_param_path;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
return model_opts;
}
TLGDecoderOptions InitDecoderOptions() {
TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
return decoder_opts;
}
RecognizerResource InitRecognizerResoure() { // DecodeOptions flags
RecognizerResource resource; // DEFINE_int32(chunk_size, -1, "decoding chunk size");
resource.acoustic_scale = FLAGS_acoustic_scale; DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
resource.feature_pipeline_opts = InitFeaturePipelineOptions(); DEFINE_double(ctc_weight,
resource.model_opts = InitModelOptions(); 0.5,
resource.tlg_opts = InitDecoderOptions(); "ctc weight when combining ctc score and rescoring score");
return resource; DEFINE_double(rescoring_weight,
} 1.0,
} "rescoring weight when combining ctc score and rescoring score");
DEFINE_double(reverse_weight,
0.3,
"used for bitransformer rescoring. it must be 0.0 if decoder is"
"conventional transformer decoder, and only reverse_weight > 0.0"
"dose the right to left decoder will be calculated and used");
DEFINE_int32(nbest, 10, "nbest for ctc wfst or prefix search");
DEFINE_int32(blank, 0, "blank id in vocab");

@ -1,5 +1,3 @@
project(frontend)
add_library(frontend STATIC add_library(frontend STATIC
cmvn.cc cmvn.cc
db_norm.cc db_norm.cc

@ -16,16 +16,18 @@
namespace ppspeech { namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::Vector; using kaldi::Vector;
using kaldi::VectorBase; using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::unique_ptr; using std::unique_ptr;
Assembler::Assembler(AssemblerOptions opts, Assembler::Assembler(AssemblerOptions opts,
unique_ptr<FrontendInterface> base_extractor) { unique_ptr<FrontendInterface> base_extractor) {
fill_zero_ = opts.fill_zero;
frame_chunk_stride_ = opts.subsampling_rate * opts.nnet_decoder_chunk; frame_chunk_stride_ = opts.subsampling_rate * opts.nnet_decoder_chunk;
frame_chunk_size_ = (opts.nnet_decoder_chunk - 1) * opts.subsampling_rate + frame_chunk_size_ = (opts.nnet_decoder_chunk - 1) * opts.subsampling_rate +
opts.receptive_filed_length; opts.receptive_filed_length;
cache_size_ = frame_chunk_size_ - frame_chunk_stride_;
receptive_filed_length_ = opts.receptive_filed_length; receptive_filed_length_ = opts.receptive_filed_length;
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
dim_ = base_extractor_->Dim(); dim_ = base_extractor_->Dim();
@ -38,49 +40,83 @@ void Assembler::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
// pop feature chunk // pop feature chunk
bool Assembler::Read(kaldi::Vector<kaldi::BaseFloat>* feats) { bool Assembler::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
feats->Resize(dim_ * frame_chunk_size_);
bool result = Compute(feats); bool result = Compute(feats);
return result; return result;
} }
// read all data from base_feature_extractor_ into cache_ // read frame by frame from base_feature_extractor_ into cache_
bool Assembler::Compute(Vector<BaseFloat>* feats) { bool Assembler::Compute(Vector<BaseFloat>* feats) {
// compute and feed // compute and feed frame by frame
bool result = false;
while (feature_cache_.size() < frame_chunk_size_) { while (feature_cache_.size() < frame_chunk_size_) {
Vector<BaseFloat> feature; Vector<BaseFloat> feature;
result = base_extractor_->Read(&feature); bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) { if (result == false || feature.Dim() == 0) {
if (IsFinished() == false) return false; VLOG(1) << "result: " << result
<< " feature dim: " << feature.Dim();
if (IsFinished() == false) {
VLOG(1) << "finished reading feature. cache size: "
<< feature_cache_.size();
return false;
} else {
VLOG(1) << "break";
break; break;
} }
}
CHECK(feature.Dim() == dim_);
feature_cache_.push(feature); feature_cache_.push(feature);
nframes_ += 1;
VLOG(1) << "nframes: " << nframes_;
} }
if (feature_cache_.size() < receptive_filed_length_) { if (feature_cache_.size() < receptive_filed_length_) {
VLOG(1) << "feature_cache less than receptive_filed_lenght. "
<< feature_cache_.size() << ": " << receptive_filed_length_;
return false; return false;
} }
if (fill_zero_) {
while (feature_cache_.size() < frame_chunk_size_) { while (feature_cache_.size() < frame_chunk_size_) {
Vector<BaseFloat> feature(dim_, kaldi::kSetZero); Vector<BaseFloat> feature(dim_, kaldi::kSetZero);
nframes_ += 1;
feature_cache_.push(feature); feature_cache_.push(feature);
} }
}
int32 this_chunk_size =
std::min(static_cast<int32>(feature_cache_.size()), frame_chunk_size_);
feats->Resize(dim_ * this_chunk_size);
VLOG(1) << "read " << this_chunk_size << " feat.";
int32 counter = 0; int32 counter = 0;
int32 cache_size = frame_chunk_size_ - frame_chunk_stride_; while (counter < this_chunk_size) {
int32 elem_dim = base_extractor_->Dim();
while (counter < frame_chunk_size_) {
Vector<BaseFloat>& val = feature_cache_.front(); Vector<BaseFloat>& val = feature_cache_.front();
int32 start = counter * elem_dim; CHECK(val.Dim() == dim_) << val.Dim();
feats->Range(start, elem_dim).CopyFromVec(val);
if (frame_chunk_size_ - counter <= cache_size) { int32 start = counter * dim_;
feats->Range(start, dim_).CopyFromVec(val);
if (this_chunk_size - counter <= cache_size_) {
feature_cache_.push(val); feature_cache_.push(val);
} }
// val is reference, so we should pop here
feature_cache_.pop(); feature_cache_.pop();
counter++; counter++;
} }
CHECK(feature_cache_.size() == cache_size_);
return result; return true;
}
void Assembler::Reset() {
std::queue<kaldi::Vector<kaldi::BaseFloat>> empty;
std::swap(feature_cache_, empty);
nframes_ = 0;
base_extractor_->Reset();
} }
} // namespace ppspeech } // namespace ppspeech

@ -22,14 +22,11 @@ namespace ppspeech {
struct AssemblerOptions { struct AssemblerOptions {
// refer:https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/paddlespeech/s2t/exps/deepspeech2/model.py // refer:https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/paddlespeech/s2t/exps/deepspeech2/model.py
// the nnet batch forward // the nnet batch forward
int32 receptive_filed_length; int32 receptive_filed_length{1};
int32 subsampling_rate; int32 subsampling_rate{1};
int32 nnet_decoder_chunk; int32 nnet_decoder_chunk{1};
bool fill_zero{false}; // whether fill zero when last chunk is not equal to
AssemblerOptions() // frame_chunk_size_
: receptive_filed_length(1),
subsampling_rate(1),
nnet_decoder_chunk(1) {}
}; };
class Assembler : public FrontendInterface { class Assembler : public FrontendInterface {
@ -39,29 +36,34 @@ class Assembler : public FrontendInterface {
std::unique_ptr<FrontendInterface> base_extractor = NULL); std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves // Feed feats or waves
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs); void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) override;
// feats size = num_frames * feat_dim // feats size = num_frames * feat_dim
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats); bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) override;
// feat dim // feat dim
virtual size_t Dim() const { return dim_; } size_t Dim() const override { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); } void SetFinished() override { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } bool IsFinished() const override { return base_extractor_->IsFinished(); }
virtual void Reset() { base_extractor_->Reset(); } void Reset() override;
private: private:
bool Compute(kaldi::Vector<kaldi::BaseFloat>* feats); bool Compute(kaldi::Vector<kaldi::BaseFloat>* feats);
int32 dim_; bool fill_zero_{false};
int32 dim_; // feat dim
int32 frame_chunk_size_; // window int32 frame_chunk_size_; // window
int32 frame_chunk_stride_; // stride int32 frame_chunk_stride_; // stride
int32 cache_size_; // window - stride
int32 receptive_filed_length_; int32 receptive_filed_length_;
std::queue<kaldi::Vector<kaldi::BaseFloat>> feature_cache_; std::queue<kaldi::Vector<kaldi::BaseFloat>> feature_cache_;
std::unique_ptr<FrontendInterface> base_extractor_; std::unique_ptr<FrontendInterface> base_extractor_;
int32 nframes_; // num frame computed
DISALLOW_COPY_AND_ASSIGN(Assembler); DISALLOW_COPY_AND_ASSIGN(Assembler);
}; };

@ -13,13 +13,14 @@
// limitations under the License. // limitations under the License.
#include "frontend/audio/audio_cache.h" #include "frontend/audio/audio_cache.h"
#include "kaldi/base/timer.h" #include "kaldi/base/timer.h"
namespace ppspeech { namespace ppspeech {
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::VectorBase;
using kaldi::Vector; using kaldi::Vector;
using kaldi::VectorBase;
AudioCache::AudioCache(int buffer_size, bool to_float32) AudioCache::AudioCache(int buffer_size, bool to_float32)
: finished_(false), : finished_(false),
@ -83,6 +84,10 @@ bool AudioCache::Read(Vector<BaseFloat>* waves) {
} }
size_ -= chunk_size; size_ -= chunk_size;
offset_ = (offset_ + chunk_size) % ring_buffer_.size(); offset_ = (offset_ + chunk_size) % ring_buffer_.size();
nsamples_ += chunk_size;
VLOG(1) << "nsamples readed: " << nsamples_;
ready_feed_condition_.notify_one(); ready_feed_condition_.notify_one();
return true; return true;
} }

@ -41,10 +41,11 @@ class AudioCache : public FrontendInterface {
virtual bool IsFinished() const { return finished_; } virtual bool IsFinished() const { return finished_; }
virtual void Reset() { void Reset() override {
offset_ = 0; offset_ = 0;
size_ = 0; size_ = 0;
finished_ = false; finished_ = false;
nsamples_ = 0;
} }
private: private:
@ -61,6 +62,7 @@ class AudioCache : public FrontendInterface {
kaldi::int32 timeout_; // millisecond kaldi::int32 timeout_; // millisecond
bool to_float32_; // int16 -> float32. used in linear_spectrogram bool to_float32_; // int16 -> float32. used in linear_spectrogram
int32 nsamples_; // number samples readed.
DISALLOW_COPY_AND_ASSIGN(AudioCache); DISALLOW_COPY_AND_ASSIGN(AudioCache);
}; };

@ -14,22 +14,25 @@
#include "frontend/audio/cmvn.h" #include "frontend/audio/cmvn.h"
#include "kaldi/feat/cmvn.h" #include "kaldi/feat/cmvn.h"
#include "kaldi/util/kaldi-io.h" #include "kaldi/util/kaldi-io.h"
namespace ppspeech { namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr; using std::unique_ptr;
using std::vector;
CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor) CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor)
: var_norm_(true) { : var_norm_(true) {
CHECK_NE(cmvn_file, "");
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
bool binary; bool binary;
kaldi::Input ki(cmvn_file, &binary); kaldi::Input ki(cmvn_file, &binary);
stats_.Read(ki.Stream(), binary); stats_.Read(ki.Stream(), binary);
@ -55,11 +58,11 @@ bool CMVN::Read(kaldi::Vector<BaseFloat>* feats) {
// feats contain num_frames feature. // feats contain num_frames feature.
void CMVN::Compute(VectorBase<BaseFloat>* feats) const { void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
KALDI_ASSERT(feats != NULL); KALDI_ASSERT(feats != NULL);
int32 dim = stats_.NumCols() - 1;
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 || if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
feats->Dim() % dim != 0) { feats->Dim() % dim_ != 0) {
KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x' KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << ','
<< stats_.NumCols() << ", feats " << feats->Dim() << 'x'; << stats_.NumCols() - 1 << ", feats " << feats->Dim() << 'x';
} }
if (stats_.NumRows() == 1 && var_norm_) { if (stats_.NumRows() == 1 && var_norm_) {
KALDI_ERR KALDI_ERR
@ -67,7 +70,7 @@ void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
<< "are supplied."; << "are supplied.";
} }
double count = stats_(0, dim); double count = stats_(0, dim_);
// Do not change the threshold of 1.0 here: in the balanced-cmvn code, when // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
// computing an offset and representing it as stats_, we use a count of one. // computing an offset and representing it as stats_, we use a count of one.
if (count < 1.0) if (count < 1.0)
@ -77,14 +80,14 @@ void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
if (!var_norm_) { if (!var_norm_) {
Vector<BaseFloat> offset(feats->Dim()); Vector<BaseFloat> offset(feats->Dim());
SubVector<double> mean_stats(stats_.RowData(0), dim); SubVector<double> mean_stats(stats_.RowData(0), dim_);
Vector<double> mean_stats_apply(feats->Dim()); Vector<double> mean_stats_apply(feats->Dim());
// fill the datat of mean_stats in mean_stats_appy whose dim is equal // fill the datat of mean_stats in mean_stats_appy whose dim_ is equal
// with the dim of feature. // with the dim_ of feature.
// the dim of feats = dim * num_frames; // the dim_ of feats = dim_ * num_frames;
for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) { for (int32 idx = 0; idx < feats->Dim() / dim_; ++idx) {
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim * idx, SubVector<double> stats_tmp(mean_stats_apply.Data() + dim_ * idx,
dim); dim_);
stats_tmp.CopyFromVec(mean_stats); stats_tmp.CopyFromVec(mean_stats);
} }
offset.AddVec(-1.0 / count, mean_stats_apply); offset.AddVec(-1.0 / count, mean_stats_apply);
@ -94,7 +97,7 @@ void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
// norm(0, d) = mean offset; // norm(0, d) = mean offset;
// norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d). // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
kaldi::Matrix<BaseFloat> norm(2, feats->Dim()); kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
for (int32 d = 0; d < dim; d++) { for (int32 d = 0; d < dim_; d++) {
double mean, offset, scale; double mean, offset, scale;
mean = stats_(0, d) / count; mean = stats_(0, d) / count;
double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20; double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20;
@ -111,7 +114,7 @@ void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
for (int32 d_skip = d; d_skip < feats->Dim();) { for (int32 d_skip = d; d_skip < feats->Dim();) {
norm(0, d_skip) = offset; norm(0, d_skip) = offset;
norm(1, d_skip) = scale; norm(1, d_skip) = scale;
d_skip = d_skip + dim; d_skip = d_skip + dim_;
} }
} }
// Apply the normalization. // Apply the normalization.

@ -30,8 +30,11 @@ DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
using namespace boost::json; // from <boost/json.hpp> using namespace boost::json; // from <boost/json.hpp>
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
LOG(INFO) << "cmvn josn path: " << FLAGS_json_file; LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;
@ -44,13 +47,13 @@ int main(int argc, char* argv[]) {
for (auto obj : value.as_object()) { for (auto obj : value.as_object()) {
if (obj.key() == "mean_stat") { if (obj.key() == "mean_stat") {
LOG(INFO) << "mean_stat:" << obj.value(); VLOG(2) << "mean_stat:" << obj.value();
} }
if (obj.key() == "var_stat") { if (obj.key() == "var_stat") {
LOG(INFO) << "var_stat: " << obj.value(); VLOG(2) << "var_stat: " << obj.value();
} }
if (obj.key() == "frame_num") { if (obj.key() == "frame_num") {
LOG(INFO) << "frame_num: " << obj.value(); VLOG(2) << "frame_num: " << obj.value();
} }
} }
@ -76,7 +79,7 @@ int main(int argc, char* argv[]) {
cmvn_stats(1, idx) = var_stat_vec[idx]; cmvn_stats(1, idx) = var_stat_vec[idx];
} }
cmvn_stats(0, mean_size) = frame_num; cmvn_stats(0, mean_size) = frame_num;
LOG(INFO) << cmvn_stats; VLOG(2) << cmvn_stats;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary); kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path; LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;

@ -16,29 +16,36 @@
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
#include "frontend/audio/audio_cache.h" #include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "frontend/audio/fbank.h" #include "frontend/audio/fbank.h"
#include "frontend/audio/feature_cache.h" #include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h" #include "frontend/audio/frontend_itf.h"
#include "frontend/audio/normalizer.h" #include "frontend/audio/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav scp path"); DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "", "read cmvn"); DEFINE_string(cmvn_file, "", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(num_bins, 161, "fbank num bins"); DEFINE_int32(num_bins, 161, "fbank num bins");
DEFINE_int32(sample_rate, 16000, "sampe rate: 16k, 8k.");
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
CHECK_GT(FLAGS_wav_rspecifier.size(), 0);
CHECK_GT(FLAGS_feature_wspecifier.size(), 0);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader( kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier); FLAGS_wav_rspecifier);
kaldi::SequentialTableReader<kaldi::WaveInfoHolder> wav_info_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier); kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
@ -54,6 +61,10 @@ int main(int argc, char* argv[]) {
opt.frame_opts.frame_shift_ms = 10; opt.frame_opts.frame_shift_ms = 10;
opt.mel_opts.num_bins = FLAGS_num_bins; opt.mel_opts.num_bins = FLAGS_num_bins;
opt.frame_opts.dither = 0.0; opt.frame_opts.dither = 0.0;
LOG(INFO) << "frame_length_ms: " << opt.frame_opts.frame_length_ms;
LOG(INFO) << "frame_shift_ms: " << opt.frame_opts.frame_shift_ms;
LOG(INFO) << "num_bins: " << opt.mel_opts.num_bins;
LOG(INFO) << "dither: " << opt.frame_opts.dither;
std::unique_ptr<ppspeech::FrontendInterface> fbank( std::unique_ptr<ppspeech::FrontendInterface> fbank(
new ppspeech::Fbank(opt, std::move(data_source))); new ppspeech::Fbank(opt, std::move(data_source)));
@ -61,53 +72,76 @@ int main(int argc, char* argv[]) {
std::unique_ptr<ppspeech::FrontendInterface> cmvn( std::unique_ptr<ppspeech::FrontendInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank))); new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank)));
ppspeech::FeatureCacheOptions feat_cache_opts;
// the feature cache output feature chunk by chunk. // the feature cache output feature chunk by chunk.
ppspeech::FeatureCacheOptions feat_cache_opts;
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
LOG(INFO) << "fbank: " << true; LOG(INFO) << "fbank: " << true;
LOG(INFO) << "feat dim: " << feature_cache.Dim(); LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000;
float streaming_chunk = FLAGS_streaming_chunk; float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate; int chunk_sample_size = streaming_chunk * FLAGS_sample_rate;
LOG(INFO) << "sr: " << sample_rate; LOG(INFO) << "sr: " << FLAGS_sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk; LOG(INFO) << "chunk size (sec): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size; LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
for (; !wav_reader.Done(); wav_reader.Next()) { for (; !wav_reader.Done() && !wav_info_reader.Done();
std::string utt = wav_reader.Key(); wav_reader.Next(), wav_info_reader.Next()) {
const std::string& utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value(); const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "process utt: " << utt;
const std::string& utt2 = wav_info_reader.Key();
const kaldi::WaveInfo& wave_info = wav_info_reader.Value();
CHECK(utt == utt2)
<< "wav reader and wav info reader using diff rspecifier!!!";
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "samples: " << wave_info.SampleCount();
LOG(INFO) << "dur: " << wave_info.Duration() << " sec";
CHECK(wave_info.SampFreq() == FLAGS_sample_rate)
<< "need " << FLAGS_sample_rate << " get " << wave_info.SampFreq();
// load first channel wav
int32 this_channel = 0; int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(), kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel); this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
// compute feat chunk by chunk
int tot_samples = waveform.Dim();
int sample_offset = 0; int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats; std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0; int feature_rows = 0;
while (sample_offset < tot_samples) { while (sample_offset < tot_samples) {
// cur chunk size
int cur_chunk_size = int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset); std::min(chunk_sample_size, tot_samples - sample_offset);
// get chunk wav
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size); kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) { for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i); wav_chunk(i) = waveform(sample_offset + i);
} }
kaldi::Vector<BaseFloat> features; // compute feat
feature_cache.Accept(wav_chunk); feature_cache.Accept(wav_chunk);
// send finish signal
if (cur_chunk_size < chunk_sample_size) { if (cur_chunk_size < chunk_sample_size) {
feature_cache.SetFinished(); feature_cache.SetFinished();
} }
// read feat
kaldi::Vector<BaseFloat> features;
bool flag = true; bool flag = true;
do { do {
flag = feature_cache.Read(&features); flag = feature_cache.Read(&features);
if (flag && features.Dim() != 0) {
feats.push_back(features); feats.push_back(features);
feature_rows += features.Dim() / feature_cache.Dim(); feature_rows += features.Dim() / feature_cache.Dim();
}
} while (flag == true && features.Dim() != 0); } while (flag == true && features.Dim() != 0);
// forward offset
sample_offset += cur_chunk_size; sample_offset += cur_chunk_size;
} }
@ -125,14 +159,20 @@ int main(int argc, char* argv[]) {
++cur_idx; ++cur_idx;
} }
} }
LOG(INFO) << "feat shape: " << features.NumRows() << " , "
<< features.NumCols();
feat_writer.Write(utt, features); feat_writer.Write(utt, features);
// reset frontend pipeline state
feature_cache.Reset(); feature_cache.Reset();
if (num_done % 50 == 0 && num_done != 0) if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances"; VLOG(2) << "Processed " << num_done << " utterances";
num_done++; num_done++;
} }
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
LOG(INFO) << "Done " << num_done << " utterances, " << num_err
<< " with errors."; << " with errors.";
return (num_done != 0 ? 0 : 1); return (num_done != 0 ? 0 : 1);
} }

@ -14,16 +14,15 @@
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
#include "frontend/audio/audio_cache.h" #include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h" #include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h" #include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h" #include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h" #include "frontend/audio/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav scp path"); DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
@ -31,8 +30,11 @@ DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader( kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier); FLAGS_wav_rspecifier);

@ -27,7 +27,7 @@ namespace ppspeech {
// pre-recorded audio/feature // pre-recorded audio/feature
class DataCache : public FrontendInterface { class DataCache : public FrontendInterface {
public: public:
explicit DataCache() { finished_ = false; } DataCache() { finished_ = false; }
// accept waves/feats // accept waves/feats
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) { virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
@ -56,4 +56,4 @@ class DataCache : public FrontendInterface {
DISALLOW_COPY_AND_ASSIGN(DataCache); DISALLOW_COPY_AND_ASSIGN(DataCache);
}; };
} } // namespace ppspeech

@ -14,17 +14,18 @@
#include "frontend/audio/db_norm.h" #include "frontend/audio/db_norm.h"
#include "kaldi/feat/cmvn.h" #include "kaldi/feat/cmvn.h"
#include "kaldi/util/kaldi-io.h" #include "kaldi/util/kaldi-io.h"
namespace ppspeech { namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr; using std::unique_ptr;
using std::vector;
DecibelNormalizer::DecibelNormalizer( DecibelNormalizer::DecibelNormalizer(
const DecibelNormalizerOptions& opts, const DecibelNormalizerOptions& opts,

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "frontend/audio/fbank.h" #include "frontend/audio/fbank.h"
#include "kaldi/base/kaldi-math.h" #include "kaldi/base/kaldi-math.h"
#include "kaldi/feat/feature-common.h" #include "kaldi/feat/feature-common.h"
#include "kaldi/feat/feature-functions.h" #include "kaldi/feat/feature-functions.h"
@ -20,12 +21,12 @@
namespace ppspeech { namespace ppspeech {
using kaldi::int32;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Vector; using kaldi::int32;
using kaldi::Matrix;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase; using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector; using std::vector;
FbankComputer::FbankComputer(const Options& opts) FbankComputer::FbankComputer(const Options& opts)

@ -16,12 +16,12 @@
namespace ppspeech { namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr; using std::unique_ptr;
using std::vector;
FeatureCache::FeatureCache(FeatureCacheOptions opts, FeatureCache::FeatureCache(FeatureCacheOptions opts,
unique_ptr<FrontendInterface> base_extractor) { unique_ptr<FrontendInterface> base_extractor) {
@ -73,6 +73,9 @@ bool FeatureCache::Compute() {
if (result == false || feature.Dim() == 0) return false; if (result == false || feature.Dim() == 0) return false;
int32 num_chunk = feature.Dim() / dim_; int32 num_chunk = feature.Dim() / dim_;
nframe_ += num_chunk;
VLOG(1) << "nframe computed: " << nframe_;
for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
int32 start = chunk_idx * dim_; int32 start = chunk_idx * dim_;
Vector<BaseFloat> feature_chunk(dim_); Vector<BaseFloat> feature_chunk(dim_);

@ -41,21 +41,24 @@ class FeatureCache : public FrontendInterface {
virtual size_t Dim() const { return dim_; } virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { virtual void SetFinished() {
LOG(INFO) << "set finished";
// std::unique_lock<std::mutex> lock(mutex_); // std::unique_lock<std::mutex> lock(mutex_);
base_extractor_->SetFinished(); base_extractor_->SetFinished();
LOG(INFO) << "set finished";
// read the last chunk data // read the last chunk data
Compute(); Compute();
// ready_feed_condition_.notify_one(); // ready_feed_condition_.notify_one();
LOG(INFO) << "compute last feats done.";
} }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { void Reset() override {
std::queue<kaldi::Vector<BaseFloat>> empty;
std::swap(cache_, empty);
nframe_ = 0;
base_extractor_->Reset(); base_extractor_->Reset();
while (!cache_.empty()) { VLOG(1) << "feature cache reset: cache size: " << cache_.size();
cache_.pop();
}
} }
private: private:
@ -74,6 +77,7 @@ class FeatureCache : public FrontendInterface {
std::condition_variable ready_feed_condition_; std::condition_variable ready_feed_condition_;
std::condition_variable ready_read_condition_; std::condition_variable ready_read_condition_;
int32 nframe_; // num of feature computed
DISALLOW_COPY_AND_ASSIGN(FeatureCache); DISALLOW_COPY_AND_ASSIGN(FeatureCache);
}; };

@ -18,7 +18,8 @@ namespace ppspeech {
using std::unique_ptr; using std::unique_ptr;
FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) { FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts)
: opts_(opts) {
unique_ptr<FrontendInterface> data_source( unique_ptr<FrontendInterface> data_source(
new ppspeech::AudioCache(1000 * kint16max, opts.to_float32)); new ppspeech::AudioCache(1000 * kint16max, opts.to_float32));
@ -32,6 +33,7 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
opts.linear_spectrogram_opts, std::move(data_source))); opts.linear_spectrogram_opts, std::move(data_source)));
} }
CHECK_NE(opts.cmvn_file, "");
unique_ptr<FrontendInterface> cmvn( unique_ptr<FrontendInterface> cmvn(
new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature))); new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature)));
@ -42,4 +44,4 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
new ppspeech::Assembler(opts.assembler_opts, std::move(cache))); new ppspeech::Assembler(opts.assembler_opts, std::move(cache)));
} }
} // ppspeech } // namespace ppspeech

@ -25,27 +25,78 @@
#include "frontend/audio/linear_spectrogram.h" #include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h" #include "frontend/audio/normalizer.h"
// feature
DECLARE_bool(use_fbank);
DECLARE_bool(fill_zero);
DECLARE_int32(num_bins);
DECLARE_string(cmvn_file);
// feature sliding window
DECLARE_int32(receptive_field_length);
DECLARE_int32(subsampling_rate);
DECLARE_int32(nnet_decoder_chunk);
namespace ppspeech { namespace ppspeech {
struct FeaturePipelineOptions { struct FeaturePipelineOptions {
std::string cmvn_file; std::string cmvn_file{};
bool to_float32; // true, only for linear feature bool to_float32{false}; // true, only for linear feature
bool use_fbank; bool use_fbank{true};
LinearSpectrogramOptions linear_spectrogram_opts; LinearSpectrogramOptions linear_spectrogram_opts{};
kaldi::FbankOptions fbank_opts; kaldi::FbankOptions fbank_opts{};
FeatureCacheOptions feature_cache_opts; FeatureCacheOptions feature_cache_opts{};
AssemblerOptions assembler_opts; AssemblerOptions assembler_opts{};
FeaturePipelineOptions() static FeaturePipelineOptions InitFromFlags() {
: cmvn_file(""), FeaturePipelineOptions opts;
to_float32(false), // true, only for linear feature opts.cmvn_file = FLAGS_cmvn_file;
use_fbank(true), LOG(INFO) << "cmvn file: " << opts.cmvn_file;
linear_spectrogram_opts(),
fbank_opts(), // frame options
feature_cache_opts(), kaldi::FrameExtractionOptions frame_opts;
assembler_opts() {} frame_opts.dither = 0.0;
LOG(INFO) << "dither: " << frame_opts.dither;
frame_opts.frame_shift_ms = 10;
LOG(INFO) << "frame shift ms: " << frame_opts.frame_shift_ms;
opts.use_fbank = FLAGS_use_fbank;
LOG(INFO) << "feature type: " << (opts.use_fbank ? "fbank" : "linear");
if (opts.use_fbank) {
opts.to_float32 = false;
frame_opts.window_type = "povey";
frame_opts.frame_length_ms = 25;
opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins;
opts.fbank_opts.frame_opts = frame_opts;
} else {
opts.to_float32 = true;
frame_opts.remove_dc_offset = false;
frame_opts.frame_length_ms = 20;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
}
LOG(INFO) << "frame length ms: " << frame_opts.frame_length_ms;
// assembler opts
opts.assembler_opts.subsampling_rate = FLAGS_subsampling_rate;
opts.assembler_opts.receptive_filed_length =
FLAGS_receptive_field_length;
opts.assembler_opts.nnet_decoder_chunk = FLAGS_nnet_decoder_chunk;
opts.assembler_opts.fill_zero = FLAGS_fill_zero;
LOG(INFO) << "subsampling rate: "
<< opts.assembler_opts.subsampling_rate;
LOG(INFO) << "nnet receptive filed length: "
<< opts.assembler_opts.receptive_filed_length;
LOG(INFO) << "nnet chunk size: "
<< opts.assembler_opts.nnet_decoder_chunk;
LOG(INFO) << "frontend fill zeros: " << opts.assembler_opts.fill_zero;
return opts;
}
}; };
class FeaturePipeline : public FrontendInterface { class FeaturePipeline : public FrontendInterface {
public: public:
explicit FeaturePipeline(const FeaturePipelineOptions& opts); explicit FeaturePipeline(const FeaturePipelineOptions& opts);
@ -60,7 +111,21 @@ class FeaturePipeline : public FrontendInterface {
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { base_extractor_->Reset(); } virtual void Reset() { base_extractor_->Reset(); }
const FeaturePipelineOptions& Config() { return opts_; }
const BaseFloat FrameShift() const {
return opts_.fbank_opts.frame_opts.frame_shift_ms;
}
const BaseFloat FrameLength() const {
return opts_.fbank_opts.frame_opts.frame_length_ms;
}
const BaseFloat SampleRate() const {
return opts_.fbank_opts.frame_opts.samp_freq;
}
private: private:
FeaturePipelineOptions opts_;
std::unique_ptr<FrontendInterface> base_extractor_; std::unique_ptr<FrontendInterface> base_extractor_;
}; };
}
} // namespace ppspeech

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "frontend/audio/linear_spectrogram.h" #include "frontend/audio/linear_spectrogram.h"
#include "kaldi/base/kaldi-math.h" #include "kaldi/base/kaldi-math.h"
#include "kaldi/feat/feature-common.h" #include "kaldi/feat/feature-common.h"
#include "kaldi/feat/feature-functions.h" #include "kaldi/feat/feature-functions.h"
@ -20,12 +21,12 @@
namespace ppspeech { namespace ppspeech {
using kaldi::int32;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Vector; using kaldi::int32;
using kaldi::Matrix;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase; using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector; using std::vector;
LinearSpectrogramComputer::LinearSpectrogramComputer(const Options& opts) LinearSpectrogramComputer::LinearSpectrogramComputer(const Options& opts)

@ -14,6 +14,7 @@
#include "frontend/audio/mfcc.h" #include "frontend/audio/mfcc.h"
#include "kaldi/base/kaldi-math.h" #include "kaldi/base/kaldi-math.h"
#include "kaldi/feat/feature-common.h" #include "kaldi/feat/feature-common.h"
#include "kaldi/feat/feature-functions.h" #include "kaldi/feat/feature-functions.h"
@ -21,12 +22,12 @@
namespace ppspeech { namespace ppspeech {
using kaldi::int32;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Vector; using kaldi::int32;
using kaldi::Matrix;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase; using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector; using std::vector;
Mfcc::Mfcc(const MfccOptions& opts, Mfcc::Mfcc(const MfccOptions& opts,

@ -14,7 +14,6 @@
#pragma once #pragma once
#include "kaldi/feat/feature-mfcc.h"
#include "kaldi/feat/feature-mfcc.h" #include "kaldi/feat/feature-mfcc.h"
#include "kaldi/matrix/kaldi-vector.h" #include "kaldi/matrix/kaldi-vector.h"

@ -101,7 +101,9 @@ namespace kaldi {
*/ */
class DecodableInterface { class DecodableInterface {
public: public:
/// Returns the log likelihood, which will be negated in the decoder. virtual ~DecodableInterface() {}
/// Returns the log likelihood(logprob), which will be negated in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() > /// The "frame" starts from zero. You should verify that NumFramesReady() >
/// frame /// frame
/// before calling this. /// before calling this.
@ -143,11 +145,12 @@ class DecodableInterface {
/// this is for compatibility with OpenFst). /// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0; virtual int32 NumIndices() const = 0;
/// Returns the likelihood(prob), which will be postive in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() >
/// frame
/// before calling this.
virtual bool FrameLikelihood( virtual bool FrameLikelihood(
int32 frame, std::vector<kaldi::BaseFloat>* likelihood) = 0; int32 frame, std::vector<kaldi::BaseFloat>* likelihood) = 0;
virtual ~DecodableInterface() {}
}; };
/// @} /// @}
} // namespace Kaldi } // namespace Kaldi

@ -1,14 +1,39 @@
project(nnet) set(srcs decodable.cc)
add_library(nnet STATIC if(USING_DS2)
decodable.cc list(APPEND srcs ds2_nnet.cc)
paddle_nnet.cc endif()
)
if(USING_U2)
list(APPEND srcs u2_nnet.cc)
endif()
add_library(nnet STATIC ${srcs})
target_link_libraries(nnet absl::strings) target_link_libraries(nnet absl::strings)
set(bin_name nnet_forward_main) if(USING_U2)
target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
endif()
if(USING_DS2)
set(bin_name ds2_nnet_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet ${DEPS}) target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
target_link_libraries(${bin_name} ${DEPS})
endif()
# test bin
if(USING_U2)
set(bin_name u2_nnet_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
endif()

@ -18,10 +18,10 @@ namespace ppspeech {
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector;
using kaldi::Vector; using kaldi::Vector;
using std::vector;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet, Decodable::Decodable(const std::shared_ptr<NnetBase>& nnet,
const std::shared_ptr<FrontendInterface>& frontend, const std::shared_ptr<FrontendInterface>& frontend,
kaldi::BaseFloat acoustic_scale) kaldi::BaseFloat acoustic_scale)
: frontend_(frontend), : frontend_(frontend),
@ -30,17 +30,17 @@ Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
frames_ready_(0), frames_ready_(0),
acoustic_scale_(acoustic_scale) {} acoustic_scale_(acoustic_scale) {}
// for debug
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) { void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
nnet_cache_ = likelihood; nnet_out_cache_ = likelihood;
frames_ready_ += likelihood.NumRows(); frames_ready_ += likelihood.NumRows();
} }
// Decodable::Init(DecodableConfig config) {
//}
// return the size of frame have computed. // return the size of frame have computed.
int32 Decodable::NumFramesReady() const { return frames_ready_; } int32 Decodable::NumFramesReady() const { return frames_ready_; }
// frame idx is from 0 to frame_ready_ -1; // frame idx is from 0 to frame_ready_ -1;
bool Decodable::IsLastFrame(int32 frame) { bool Decodable::IsLastFrame(int32 frame) {
bool flag = EnsureFrameHaveComputed(frame); bool flag = EnsureFrameHaveComputed(frame);
@ -53,18 +53,9 @@ int32 Decodable::NumIndices() const { return 0; }
// id. // id.
int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; } int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; }
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
CHECK_LE(index, nnet_cache_.NumCols());
CHECK_LE(frame, frames_ready_);
int32 frame_idx = frame - frame_offset_;
// the nnet output is prob ranther than log prob
// the index - 1, because the ilabel
return acoustic_scale_ *
std::log(nnet_cache_(frame_idx, TokenId2NnetId(index)) +
std::numeric_limits<float>::min());
}
bool Decodable::EnsureFrameHaveComputed(int32 frame) { bool Decodable::EnsureFrameHaveComputed(int32 frame) {
// decoding frame
if (frame >= frames_ready_) { if (frame >= frames_ready_) {
return AdvanceChunk(); return AdvanceChunk();
} }
@ -72,38 +63,112 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) {
} }
bool Decodable::AdvanceChunk() { bool Decodable::AdvanceChunk() {
kaldi::Timer timer;
// read feats
Vector<BaseFloat> features; Vector<BaseFloat> features;
if (frontend_ == NULL || frontend_->Read(&features) == false) { if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat or frontend_ not init.
VLOG(1) << "decodable exit;";
return false; return false;
} }
int32 nnet_dim = 0; VLOG(2) << "Forward in " << features.Dim() / frontend_->Dim() << " feats.";
Vector<BaseFloat> inferences;
nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim); // forward feats
nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim); NnetOut out;
nnet_cache_.CopyRowsFromVec(inferences); nnet_->FeedForward(features, frontend_->Dim(), &out);
int32& vocab_dim = out.vocab_dim;
Vector<BaseFloat>& logprobs = out.logprobs;
VLOG(2) << "Forward out " << logprobs.Dim() / vocab_dim
<< " decoder frames.";
// cache nnet outupts
nnet_out_cache_.Resize(logprobs.Dim() / vocab_dim, vocab_dim);
nnet_out_cache_.CopyRowsFromVec(logprobs);
// update state, decoding frame.
frame_offset_ = frames_ready_; frame_offset_ = frames_ready_;
frames_ready_ += nnet_cache_.NumRows(); frames_ready_ += nnet_out_cache_.NumRows();
VLOG(2) << "Forward feat chunk cost: " << timer.Elapsed() << " sec.";
return true;
}
bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
int* vocab_dim) {
if (AdvanceChunk() == false) {
return false;
}
int nrows = nnet_out_cache_.NumRows();
CHECK(nrows == (frames_ready_ - frame_offset_));
if (nrows <= 0) {
LOG(WARNING) << "No new nnet out in cache.";
return false;
}
logprobs->Resize(nnet_out_cache_.NumRows() * nnet_out_cache_.NumCols());
logprobs->CopyRowsFromMat(nnet_out_cache_);
*vocab_dim = nnet_out_cache_.NumCols();
return true; return true;
} }
// read one frame likelihood
bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) { bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
std::vector<BaseFloat> result; if (EnsureFrameHaveComputed(frame) == false) {
if (EnsureFrameHaveComputed(frame) == false) return false; VLOG(1) << "framelikehood exit.";
likelihood->resize(nnet_cache_.NumCols()); return false;
for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) { }
int nrows = nnet_out_cache_.NumRows();
CHECK(nrows == (frames_ready_ - frame_offset_));
int vocab_size = nnet_out_cache_.NumCols();
likelihood->resize(vocab_size);
for (int32 idx = 0; idx < vocab_size; ++idx) {
(*likelihood)[idx] = (*likelihood)[idx] =
nnet_cache_(frame - frame_offset_, idx) * acoustic_scale_; nnet_out_cache_(frame - frame_offset_, idx) * acoustic_scale_;
VLOG(4) << "nnet out: " << frame << " offset:" << frame_offset_ << " "
<< nnet_out_cache_.NumRows()
<< " logprob: " << nnet_out_cache_(frame - frame_offset_, idx);
} }
return true; return true;
} }
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
if (EnsureFrameHaveComputed(frame) == false) {
return false;
}
CHECK_LE(index, nnet_out_cache_.NumCols());
CHECK_LE(frame, frames_ready_);
// the nnet output is prob ranther than log prob
// the index - 1, because the ilabel
BaseFloat logprob = 0.0;
int32 frame_idx = frame - frame_offset_;
BaseFloat nnet_out = nnet_out_cache_(frame_idx, TokenId2NnetId(index));
if (nnet_->IsLogProb()) {
logprob = nnet_out;
} else {
logprob = std::log(nnet_out + std::numeric_limits<float>::epsilon());
}
CHECK(!std::isnan(logprob) && !std::isinf(logprob));
return acoustic_scale_ * logprob;
}
void Decodable::Reset() { void Decodable::Reset() {
if (frontend_ != nullptr) frontend_->Reset(); if (frontend_ != nullptr) frontend_->Reset();
if (nnet_ != nullptr) nnet_->Reset(); if (nnet_ != nullptr) nnet_->Reset();
frame_offset_ = 0; frame_offset_ = 0;
frames_ready_ = 0; frames_ready_ = 0;
nnet_cache_.Resize(0, 0); nnet_out_cache_.Resize(0, 0);
}
void Decodable::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
} }
} // namespace ppspeech } // namespace ppspeech

@ -24,38 +24,68 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface { class Decodable : public kaldi::DecodableInterface {
public: public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet, explicit Decodable(const std::shared_ptr<NnetBase>& nnet,
const std::shared_ptr<FrontendInterface>& frontend, const std::shared_ptr<FrontendInterface>& frontend,
kaldi::BaseFloat acoustic_scale = 1.0); kaldi::BaseFloat acoustic_scale = 1.0);
// void Init(DecodableOpts config); // void Init(DecodableOpts config);
// nnet logprob output, used by wfst
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame);
virtual int32 NumIndices() const; // nnet output
// not logprob
virtual bool FrameLikelihood(int32 frame, virtual bool FrameLikelihood(int32 frame,
std::vector<kaldi::BaseFloat>* likelihood); std::vector<kaldi::BaseFloat>* likelihood);
// forward nnet with feats
bool AdvanceChunk();
// forward nnet with feats, and get nnet output
bool AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
int* vocab_dim);
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score);
virtual bool IsLastFrame(int32 frame);
// nnet output dim, e.g. vocab size
virtual int32 NumIndices() const;
virtual int32 NumFramesReady() const; virtual int32 NumFramesReady() const;
// for offline test
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
void Reset(); void Reset();
bool IsInputFinished() const { return frontend_->IsFinished(); } bool IsInputFinished() const { return frontend_->IsFinished(); }
bool EnsureFrameHaveComputed(int32 frame); bool EnsureFrameHaveComputed(int32 frame);
int32 TokenId2NnetId(int32 token_id); int32 TokenId2NnetId(int32 token_id);
std::shared_ptr<NnetBase> Nnet() { return nnet_; }
// for offline test
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
private: private:
bool AdvanceChunk();
std::shared_ptr<FrontendInterface> frontend_; std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_; std::shared_ptr<NnetBase> nnet_;
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
// nnet outputs' cache
kaldi::Matrix<kaldi::BaseFloat> nnet_out_cache_;
// the frame is nnet prob frame rather than audio feature frame // the frame is nnet prob frame rather than audio feature frame
// nnet frame subsample the feature frame // nnet frame subsample the feature frame
// eg: 35 frame features output 8 frame inferences // eg: 35 frame features output 8 frame inferences
int32 frame_offset_; int32 frame_offset_;
int32 frames_ready_; int32 frames_ready_;
// todo: feature frame mismatch with nnet inference frame // todo: feature frame mismatch with nnet inference frame
// so use subsampled_frame // so use subsampled_frame
int32 current_log_post_subsampled_offset_; int32 current_log_post_subsampled_offset_;
int32 num_chunk_computed_; int32 num_chunk_computed_;
kaldi::BaseFloat acoustic_scale_; kaldi::BaseFloat acoustic_scale_;
}; };

@ -12,16 +12,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "nnet/paddle_nnet.h" #include "nnet/ds2_nnet.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
namespace ppspeech { namespace ppspeech {
using std::vector;
using std::string;
using std::shared_ptr;
using kaldi::Matrix; using kaldi::Matrix;
using kaldi::Vector; using kaldi::Vector;
using std::shared_ptr;
using std::string;
using std::vector;
void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
std::vector<std::string> cache_names; std::vector<std::string> cache_names;
@ -48,6 +49,7 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
} }
PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) { PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) {
subsampling_rate_ = opts.subsample_rate;
paddle_infer::Config config; paddle_infer::Config config;
config.SetModel(opts.model_path, opts.param_path); config.SetModel(opts.model_path, opts.param_path);
if (opts.use_gpu) { if (opts.use_gpu) {
@ -143,9 +145,8 @@ shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
} }
void PaddleNnet::FeedForward(const Vector<BaseFloat>& features, void PaddleNnet::FeedForward(const Vector<BaseFloat>& features,
int32 feature_dim, const int32& feature_dim,
Vector<BaseFloat>* inferences, NnetOut* out) {
int32* inference_dim) {
paddle_infer::Predictor* predictor = GetPredictor(); paddle_infer::Predictor* predictor = GetPredictor();
int feat_row = features.Dim() / feature_dim; int feat_row = features.Dim() / feature_dim;
@ -203,9 +204,13 @@ void PaddleNnet::FeedForward(const Vector<BaseFloat>& features,
std::vector<int> output_shape = output_tensor->shape(); std::vector<int> output_shape = output_tensor->shape();
int32 row = output_shape[1]; int32 row = output_shape[1];
int32 col = output_shape[2]; int32 col = output_shape[2];
inferences->Resize(row * col);
*inference_dim = col;
output_tensor->CopyToCpu(inferences->Data()); // inferences->Resize(row * col);
// *inference_dim = col;
out->logprobs.Resize(row * col);
out->vocab_dim = col;
output_tensor->CopyToCpu(out->logprobs.Data());
ReleasePredictor(predictor); ReleasePredictor(predictor);
} }

@ -13,64 +13,20 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <numeric> #include <numeric>
#include "base/common.h" #include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
#include "nnet/nnet_itf.h" #include "nnet/nnet_itf.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h"
namespace ppspeech { namespace ppspeech {
struct ModelOptions {
std::string model_path;
std::string param_path;
int thread_num; // predictor thread pool size
bool use_gpu;
bool switch_ir_optim;
std::string input_names;
std::string output_names;
std::string cache_names;
std::string cache_shape;
bool enable_fc_padding;
bool enable_profile;
ModelOptions()
: model_path(""),
param_path(""),
thread_num(2),
use_gpu(false),
input_names(""),
output_names(""),
cache_names(""),
cache_shape(""),
switch_ir_optim(false),
enable_fc_padding(false),
enable_profile(false) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("model-path", &model_path, "model file path");
opts->Register("model-param", &param_path, "params model file path");
opts->Register("thread-num", &thread_num, "thread num");
opts->Register("use-gpu", &use_gpu, "if use gpu");
opts->Register("input-names", &input_names, "paddle input names");
opts->Register("output-names", &output_names, "paddle output names");
opts->Register("cache-names", &cache_names, "cache names");
opts->Register("cache-shape", &cache_shape, "cache shape");
opts->Register("switch-ir-optiom",
&switch_ir_optim,
"paddle SwitchIrOptim option");
opts->Register("enable-fc-padding",
&enable_fc_padding,
"paddle EnableFCPadding option");
opts->Register(
"enable-profile", &enable_profile, "paddle EnableProfile option");
}
};
template <typename T> template <typename T>
class Tensor { class Tensor {
public: public:
Tensor() {} Tensor() {}
Tensor(const std::vector<int>& shape) : _shape(shape) { explicit Tensor(const std::vector<int>& shape) : _shape(shape) {
int neml = std::accumulate( int neml = std::accumulate(
_shape.begin(), _shape.end(), 1, std::multiplies<int>()); _shape.begin(), _shape.end(), 1, std::multiplies<int>());
LOG(INFO) << "Tensor neml: " << neml; LOG(INFO) << "Tensor neml: " << neml;
@ -92,21 +48,35 @@ class Tensor {
std::vector<T> _data; std::vector<T> _data;
}; };
class PaddleNnet : public NnetInterface { class PaddleNnet : public NnetBase {
public: public:
PaddleNnet(const ModelOptions& opts); explicit PaddleNnet(const ModelOptions& opts);
virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features, void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
int32 feature_dim, const int32& feature_dim,
kaldi::Vector<kaldi::BaseFloat>* inferences, NnetOut* out) override;
int32* inference_dim);
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) override {
VLOG(2) << "deepspeech2 not has AttentionRescoring.";
}
void Dim(); void Dim();
virtual void Reset();
void Reset() override;
bool IsLogProb() override { return false; }
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder( std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
const std::string& name); const std::string& name);
void InitCacheEncouts(const ModelOptions& opts); void InitCacheEncouts(const ModelOptions& opts);
void EncoderOuts(std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out)
const override {}
private: private:
paddle_infer::Predictor* GetPredictor(); paddle_infer::Predictor* GetPredictor();
int ReleasePredictor(paddle_infer::Predictor* predictor); int ReleasePredictor(paddle_infer::Predictor* predictor);
@ -117,6 +87,7 @@ class PaddleNnet : public NnetInterface {
std::map<paddle_infer::Predictor*, int> predictor_to_thread_id; std::map<paddle_infer::Predictor*, int> predictor_to_thread_id;
std::map<std::string, int> cache_names_idx_; std::map<std::string, int> cache_names_idx_;
std::vector<std::shared_ptr<Tensor<kaldi::BaseFloat>>> cache_encouts_; std::vector<std::shared_ptr<Tensor<kaldi::BaseFloat>>> cache_encouts_;
ModelOptions opts_; ModelOptions opts_;
public: public:

@ -12,45 +12,27 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "base/flags.h" #include "base/common.h"
#include "base/log.h" #include "decoder/param.h"
#include "frontend/audio/assembler.h" #include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/paddle_nnet.h" #include "nnet/ds2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box",
"model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector; using std::vector;
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialBaseFloatMatrixReader feature_reader( kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier); FLAGS_feature_rspecifier);
@ -62,13 +44,8 @@ int main(int argc, char* argv[]) {
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
ppspeech::ModelOptions model_opts; ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
model_opts.model_path = model_graph;
model_opts.param_path = model_params;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache()); std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
@ -76,8 +53,8 @@ int main(int argc, char* argv[]) {
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = FLAGS_receptive_field_length + int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length; int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride; LOG(INFO) << "chunk stride (frame): " << chunk_stride;
@ -146,7 +123,7 @@ int main(int argc, char* argv[]) {
} }
kaldi::Matrix<kaldi::BaseFloat> result(prob_vec.size(), kaldi::Matrix<kaldi::BaseFloat> result(prob_vec.size(),
prob_vec[0].Dim()); prob_vec[0].Dim());
for (int32 row_idx = 0; row_idx < prob_vec.size(); ++row_idx) { for (int row_idx = 0; row_idx < prob_vec.size(); ++row_idx) {
for (int32 col_idx = 0; col_idx < prob_vec[0].Dim(); ++col_idx) { for (int32 col_idx = 0; col_idx < prob_vec[0].Dim(); ++col_idx) {
result(row_idx, col_idx) = prob_vec[row_idx](col_idx); result(row_idx, col_idx) = prob_vec[row_idx](col_idx);
} }

@ -11,24 +11,110 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "base/basic_types.h" #include "base/basic_types.h"
#include "kaldi/base/kaldi-types.h" #include "kaldi/base/kaldi-types.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
DECLARE_int32(subsampling_rate);
DECLARE_string(model_path);
DECLARE_string(param_path);
DECLARE_string(model_input_names);
DECLARE_string(model_output_names);
DECLARE_string(model_cache_names);
DECLARE_string(model_cache_shapes);
namespace ppspeech { namespace ppspeech {
struct ModelOptions {
// common
int subsample_rate{1};
int thread_num{1}; // predictor thread pool size for ds2;
bool use_gpu{false};
std::string model_path;
std::string param_path;
// ds2 for inference
std::string input_names{};
std::string output_names{};
std::string cache_names{};
std::string cache_shape{};
bool switch_ir_optim{false};
bool enable_fc_padding{false};
bool enable_profile{false};
static ModelOptions InitFromFlags() {
ModelOptions opts;
opts.subsample_rate = FLAGS_subsampling_rate;
LOG(INFO) << "subsampling rate: " << opts.subsample_rate;
opts.model_path = FLAGS_model_path;
LOG(INFO) << "model path: " << opts.model_path;
opts.param_path = FLAGS_param_path;
LOG(INFO) << "param path: " << opts.param_path;
LOG(INFO) << "DS2 param: ";
opts.cache_names = FLAGS_model_cache_names;
LOG(INFO) << " cache names: " << opts.cache_names;
opts.cache_shape = FLAGS_model_cache_shapes;
LOG(INFO) << " cache shape: " << opts.cache_shape;
opts.input_names = FLAGS_model_input_names;
LOG(INFO) << " input names: " << opts.input_names;
opts.output_names = FLAGS_model_output_names;
LOG(INFO) << " output names: " << opts.output_names;
return opts;
}
};
struct NnetOut {
// nnet out. maybe logprob or prob. Almost time this is logprob.
kaldi::Vector<kaldi::BaseFloat> logprobs;
int32 vocab_dim;
// nnet state. Only using in Attention model.
std::vector<std::vector<kaldi::BaseFloat>> encoder_outs;
NnetOut() : logprobs({}), vocab_dim(-1), encoder_outs({}) {}
};
class NnetInterface { class NnetInterface {
public: public:
virtual ~NnetInterface() {}
// forward feat with nnet.
// nnet do not cache feats, feats cached by frontend.
// nnet cache model state, i.e. encoder_outs, att_cache, cnn_cache,
// frame_offset.
virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features, virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
int32 feature_dim, const int32& feature_dim,
kaldi::Vector<kaldi::BaseFloat>* inferences, NnetOut* out) = 0;
int32* inference_dim) = 0;
virtual void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) = 0;
// reset nnet state, e.g. nnet_logprob_cache_, offset_, encoder_outs_.
virtual void Reset() = 0; virtual void Reset() = 0;
virtual ~NnetInterface() {}
// true, nnet output is logprob; otherwise is prob,
virtual bool IsLogProb() = 0;
// using to get encoder outs. e.g. seq2seq with Attention model.
virtual void EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const = 0;
};
class NnetBase : public NnetInterface {
public:
int SubsamplingRate() const { return subsampling_rate_; }
protected:
int subsampling_rate_{1};
}; };
} // namespace ppspeech } // namespace ppspeech

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save