Merge pull request #3197 from PaddlePaddle/speechx

[engine] merge speechx
pull/3200/head
Hui Zhang 2 years ago committed by GitHub
commit 7cab869d63
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -136,7 +136,7 @@ pull_request_rules:
add: ["Docker"]
- name: "auto add label=Deployment"
conditions:
- files~=^speechx/
- files~=^runtime/
actions:
label:
add: ["Deployment"]

@ -3,8 +3,12 @@ repos:
rev: v0.16.0
hooks:
- id: yapf
files: \.py$
exclude: (?=third_party).*(\.py)$
name: yapf
language: python
entry: yapf
args: [-i, -vv]
types: [python]
exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: a11d9314b22d8f8c7556443875b731ef05965464
@ -31,7 +35,7 @@ repos:
- --ignore=E501,E228,E226,E261,E266,E128,E402,W503
- --builtins=G,request
- --jobs=1
exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$
exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$
- repo : https://github.com/Lucas-C/pre-commit-hooks
rev: v1.0.1
@ -53,16 +57,16 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i
language: system
files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$
exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party/kaldi-native-fbank/csrc|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$
exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders|runtime/engine/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$
- id: cpplint
name: cpplint
description: Static code analysis of C/C++ files
language: python
files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$
exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party/kaldi-native-fbank/csrc|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$
exclude: (?=runtime/engine/kaldi|runtime/engine/common/matrix|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders|runtime/engine/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$
entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
hooks:
- id: reorder-python-imports
exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h\.hpp|\.py)$
exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h\.hpp|\.py)$

@ -193,7 +193,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 👑 2022.11.18: Add [Whisper CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/pull/2640), support multi language recognition and translation.
- 🔥 2022.11.18: Add [Wav2vec2 CLI and Demos](./demos/speech_ssl), Support ASR and Feature Extraction.
- 🎉 2022.11.17: Add [male voice for TTS](https://github.com/PaddlePaddle/PaddleSpeech/pull/2660).
- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](./speechx/examples/u2pp_ol/wenetspeech).
- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/runtime/examples/u2pp_ol/wenetspeech).
- 👑 2022.11.01: Add [Adversarial Loss](https://arxiv.org/pdf/1907.04448.pdf) for [Chinese English mixed TTS](./examples/zh_en_tts/tts3).
- 🔥 2022.10.26: Add [Prosody Prediction](./examples/other/rhy) for TTS.
- 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend.

@ -0,0 +1,7 @@
engine/common/base/flags.h
engine/common/base/log.h
tools/valgrind*
*log
fc_patch/*
test

@ -0,0 +1,211 @@
# >=3.17 support -DCMAKE_FIND_DEBUG_MODE=ON
cmake_minimum_required(VERSION 3.17 FATAL_ERROR)
set(CMAKE_PROJECT_INCLUDE_BEFORE "${CMAKE_CURRENT_SOURCE_DIR}/cmake/EnableCMP0077.cmake")
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
include(system)
project(paddlespeech VERSION 0.1)
set(PPS_VERSION_MAJOR 1)
set(PPS_VERSION_MINOR 0)
set(PPS_VERSION_PATCH 0)
set(PPS_VERSION "${PPS_VERSION_MAJOR}.${PPS_VERSION_MINOR}.${PPS_VERSION_PATCH}")
# compiler option
# Keep the same with openfst, -fPIC or -fpic
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ldl")
SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb")
SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall")
set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_FIND_DEBUG_MODE OFF)
set(PPS_CXX_STANDARD 14)
# set std-14
set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD})
# Ninja Generator will set CMAKE_BUILD_TYPE to Debug
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" FORCE)
endif()
# find_* e.g. find_library work when Cross-Compiling
if(ANDROID)
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
endif()
if(BUILD_IN_MACOS)
add_definitions("-DOS_MACOSX")
endif()
# install dir into `build/install`
set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/install)
include(FetchContent)
include(ExternalProject)
# fc_patch dir
set(FETCHCONTENT_QUIET off)
get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_patch})
###############################################################################
# Option Configurations
###############################################################################
# https://github.com/google/brotli/pull/655
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(WITH_PPS_DEBUG "debug option" OFF)
if (WITH_PPS_DEBUG)
add_definitions("-DPPS_DEBUG")
endif()
option(WITH_ASR "build asr" ON)
option(WITH_CLS "build cls" ON)
option(WITH_VAD "build vad" ON)
option(WITH_GPU "NNet using GPU." OFF)
option(WITH_PROFILING "enable c++ profling" OFF)
option(WITH_TESTING "unit test" ON)
option(WITH_ONNX "u2 support onnx runtime" OFF)
###############################################################################
# Include Third Party
###############################################################################
include(gflags)
include(glog)
include(pybind)
#onnx
if(WITH_ONNX)
add_definitions(-DUSE_ONNX)
endif()
# gtest
if(WITH_TESTING)
include(gtest) # download, build, install gtest
endif()
# fastdeploy
include(fastdeploy)
if(WITH_ASR)
# openfst
include(openfst)
add_dependencies(openfst gflags extern_glog)
endif()
###############################################################################
# Find Package
###############################################################################
# https://github.com/Kitware/CMake/blob/v3.1.0/Modules/FindThreads.cmake#L207
find_package(Threads REQUIRED)
if(WITH_ASR)
# https://cmake.org/cmake/help/latest/module/FindPython3.html#module:FindPython3
find_package(Python3 COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG)
if(Python3_FOUND)
message(STATUS "Python3_FOUND = ${Python3_FOUND}")
message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}")
message(STATUS "Python3_LIBRARIES = ${Python3_LIBRARIES}")
message(STATUS "Python3_INCLUDE_DIRS = ${Python3_INCLUDE_DIRS}")
message(STATUS "Python3_LINK_OPTIONS = ${Python3_LINK_OPTIONS}")
set(PYTHON_LIBRARIES ${Python3_LIBRARIES} CACHE STRING "python lib" FORCE)
set(PYTHON_INCLUDE_DIR ${Python3_INCLUDE_DIRS} CACHE STRING "python inc" FORCE)
endif()
message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
message(STATUS "PYTHON_INCLUDE_DIR = ${PYTHON_INCLUDE_DIR}")
include_directories(${PYTHON_INCLUDE_DIR})
if(pybind11_FOUND)
message(STATUS "pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}")
message(STATUS "pybind11_LIBRARIES=${pybind11_LIBRARIES}")
message(STATUS "pybind11_DEFINITIONS=${pybind11_DEFINITIONS}")
endif()
# paddle libpaddle.so
# paddle include and link option
# -L/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so
set(EXECUTE_COMMAND "import os"
"import paddle"
"include_dir = paddle.sysconfig.get_include()"
"paddle_dir=os.path.split(include_dir)[0]"
"libs_dir=os.path.join(paddle_dir, 'libs')"
"fluid_dir=os.path.join(paddle_dir, 'fluid')"
"out=' '.join([\"-L\" + libs_dir, \"-L\" + fluid_dir])"
"out += \" -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so\"; print(out)"
)
execute_process(
COMMAND python -c "${EXECUTE_COMMAND}"
OUTPUT_VARIABLE PADDLE_LINK_FLAGS
RESULT_VARIABLE SUCESS)
message(STATUS PADDLE_LINK_FLAGS= ${PADDLE_LINK_FLAGS})
string(STRIP ${PADDLE_LINK_FLAGS} PADDLE_LINK_FLAGS)
# paddle compile option
# -I/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/include
set(EXECUTE_COMMAND "import paddle"
"include_dir = paddle.sysconfig.get_include()"
"print(f\"-I{include_dir}\")"
)
execute_process(
COMMAND python -c "${EXECUTE_COMMAND}"
OUTPUT_VARIABLE PADDLE_COMPILE_FLAGS)
message(STATUS PADDLE_COMPILE_FLAGS= ${PADDLE_COMPILE_FLAGS})
string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS)
# for LD_LIBRARY_PATH
# set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/)
set(EXECUTE_COMMAND "import os"
"import paddle"
"include_dir=paddle.sysconfig.get_include()"
"paddle_dir=os.path.split(include_dir)[0]"
"libs_dir=os.path.join(paddle_dir, 'libs')"
"fluid_dir=os.path.join(paddle_dir, 'fluid')"
"out=':'.join([libs_dir, fluid_dir]); print(out)"
)
execute_process(
COMMAND python -c "${EXECUTE_COMMAND}"
OUTPUT_VARIABLE PADDLE_LIB_DIRS)
message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS})
endif()
include(summary)
###############################################################################
# Add local library
###############################################################################
set(ENGINE_ROOT ${CMAKE_SOURCE_DIR}/engine)
add_subdirectory(engine)
###############################################################################
# CPack library
###############################################################################
# build a CPack driven installer package
include (InstallRequiredSystemLibraries)
set(CPACK_PACKAGE_NAME "paddlespeech_library")
set(CPACK_PACKAGE_VENDOR "paddlespeech")
set(CPACK_PACKAGE_VERSION_MAJOR 1)
set(CPACK_PACKAGE_VERSION_MINOR 0)
set(CPACK_PACKAGE_VERSION_PATCH 0)
set(CPACK_PACKAGE_DESCRIPTION "paddlespeech library")
set(CPACK_PACKAGE_CONTACT "paddlespeech@baidu.com")
set(CPACK_SOURCE_GENERATOR "TGZ")
include (CPack)

@ -1,4 +1,3 @@
# SpeechX -- All in One Speech Task Inference
## Environment
@ -9,7 +8,7 @@ We develop under:
* gcc/g++/gfortran - 8.2.0
* cmake - 3.16.0
> Please use `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build speechx.
> Please use `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build engine.
> We make sure all things work fun under docker, and recommend using it to develop and deploy.
@ -33,7 +32,7 @@ docker run --privileged --net=host --ipc=host -it --rm -v /path/to/paddlespeech
bash tools/venv.sh
```
2. Build `speechx` and `examples`.
2. Build `engine` and `examples`.
For now we are using feature under `develop` branch of paddle, so we need to install `paddlepaddle` nightly build version.
For example:
@ -113,3 +112,11 @@ apt-get install gfortran-8
4. `Undefined reference to '_gfortran_concat_string'`
using gcc 8.2, gfortran 8.2.
5. `./boost/python/detail/wrap_python.hpp:57:11: fatal error: pyconfig.h: No such file or directory`
```
apt-get install python3-dev
```
for more info please see [here](https://github.com/okfn/piati/issues/65).

@ -0,0 +1,33 @@
#!/usr/bin/env bash
set -xe
BUILD_ROOT=build/Linux
BUILD_DIR=${BUILD_ROOT}/x86_64
mkdir -p ${BUILD_DIR}
BUILD_TYPE=Release
#BUILD_TYPE=Debug
BUILD_SO=OFF
BUILD_ONNX=ON
BUILD_ASR=ON
BUILD_CLS=ON
BUILD_VAD=ON
PPS_DEBUG=OFF
FASTDEPLOY_INSTALL_DIR=""
# the build script had verified in the paddlepaddle docker image.
# please follow the instruction below to install PaddlePaddle image.
# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html
#cmake -B build -DBUILD_SHARED_LIBS=OFF -DWITH_ASR=OFF -DWITH_CLS=OFF -DWITH_VAD=ON -DFASTDEPLOY_INSTALL_DIR=/workspace/zhanghui/paddle/FastDeploy/build/Android/arm64-v8a-api-21/install
cmake -B ${BUILD_DIR} \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DBUILD_SHARED_LIBS=${BUILD_SO} \
-DWITH_ONNX=${BUILD_ONNX} \
-DWITH_ASR=${BUILD_ASR} \
-DWITH_CLS=${BUILD_CLS} \
-DWITH_VAD=${BUILD_VAD} \
-DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \
-DWITH_PPS_DEBUG=${PPS_DEBUG}
cmake --build ${BUILD_DIR} -j

@ -0,0 +1,39 @@
#!/bin/bash
set -ex
ANDROID_NDK=/mnt/masimeng/workspace/software/android-ndk-r25b/
# Setting up Android toolchanin
ANDROID_ABI=arm64-v8a # 'arm64-v8a', 'armeabi-v7a'
ANDROID_PLATFORM="android-21" # API >= 21
ANDROID_STL=c++_shared # 'c++_shared', 'c++_static'
ANDROID_TOOLCHAIN=clang # 'clang' only
TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
# Create build directory
BUILD_ROOT=build/Android
BUILD_DIR=${BUILD_ROOT}/${ANDROID_ABI}-api-21
FASTDEPLOY_INSTALL_DIR="/mnt/masimeng/workspace/FastDeploy/build/Android/arm64-v8a-api-21/install"
mkdir -p ${BUILD_DIR}
cd ${BUILD_DIR}
# CMake configuration with Android toolchain
cmake -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DANDROID_ABI=${ANDROID_ABI} \
-DANDROID_NDK=${ANDROID_NDK} \
-DANDROID_PLATFORM=${ANDROID_PLATFORM} \
-DANDROID_STL=${ANDROID_STL} \
-DANDROID_TOOLCHAIN=${ANDROID_TOOLCHAIN} \
-DBUILD_SHARED_LIBS=OFF \
-DWITH_ASR=OFF \
-DWITH_CLS=OFF \
-DWITH_VAD=ON \
-DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \
-DCMAKE_FIND_DEBUG_MODE=OFF \
-Wno-dev ../../..
# Build FastDeploy Android C++ SDK
make

@ -0,0 +1,91 @@
# https://www.jianshu.com/p/33672fb819f5
PATH="/Applications/CMake.app/Contents/bin":"$PATH"
tools_dir=$1
ios_toolchain_cmake=${tools_dir}/"/ios-cmake-4.2.0/ios.toolchain.cmake"
fastdeploy_dir=${tools_dir}"/fastdeploy-ort-mac-build/"
build_targets=("OS64")
build_type_array=("Release")
#static_name="libocr"
#lib_name="libocr"
# Switch to workpath
current_path=`cd $(dirname $0);pwd`
work_path=${current_path}/
build_path=${current_path}/build/
output_path=${current_path}/output/
cd ${work_path}
# Clean
rm -rf ${build_path}
rm -rf ${output_path}
if [ "$1"x = "clean"x ]; then
exit 0
fi
# Build Every Target
for target in "${build_targets[@]}"
do
for build_type in "${build_type_array[@]}"
do
echo -e "\033[1;36;40mBuilding ${build_type} ${target} ... \033[0m"
target_build_path=${build_path}/${target}/${build_type}/
mkdir -p ${target_build_path}
cd ${target_build_path}
if [ $? -ne 0 ];then
echo -e "\033[1;31;40mcd ${target_build_path} failed \033[0m"
exit -1
fi
if [ ${target} == "OS64" ];then
fastdeploy_install_dir=${fastdeploy_dir}/arm64
else
fastdeploy_install_dir=""
echo "fastdeploy_install_dir is null"
exit -1
fi
cmake -DCMAKE_TOOLCHAIN_FILE=${ios_toolchain_cmake} \
-DBUILD_IN_MACOS=ON \
-DBUILD_SHARED_LIBS=OFF \
-DWITH_ASR=OFF \
-DWITH_CLS=OFF \
-DWITH_VAD=ON \
-DFASTDEPLOY_INSTALL_DIR=${fastdeploy_install_dir} \
-DPLATFORM=${target} ../../../
cmake --build . --config ${build_type}
mkdir output
cp engine/vad/interface/libpps_vad_interface.a output
cp engine/vad/interface/vad_interface_main.app/vad_interface_main output
cp ${fastdeploy_install_dir}/lib/libfastdeploy.dylib output
cp ${fastdeploy_install_dir}/third_libs/install/onnxruntime/lib/libonnxruntime.dylib output
done
done
## combine all ios libraries
#DEVROOT=/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/
#LIPO_TOOL=${DEVROOT}/usr/bin/lipo
#LIBRARY_PATH=${build_path}
#LIBRARY_OUTPUT_PATH=${output_path}/IOS
#mkdir -p ${LIBRARY_OUTPUT_PATH}
#
#${LIPO_TOOL} \
# -arch i386 ${LIBRARY_PATH}/ios_x86/Release/${lib_name}.a \
# -arch x86_64 ${LIBRARY_PATH}/ios_x86_64/Release/${lib_name}.a \
# -arch armv7 ${LIBRARY_PATH}/ios_armv7/Release/${lib_name}.a \
# -arch armv7s ${LIBRARY_PATH}/ios_armv7s/Release/${lib_name}.a \
# -arch arm64 ${LIBRARY_PATH}/ios_armv8/Release/${lib_name}.a \
# -output ${LIBRARY_OUTPUT_PATH}/${lib_name}.a -create
#
#cp ${work_path}/lib/houyi/lib/ios/libhouyi_score.a ${LIBRARY_OUTPUT_PATH}/
#cp ${work_path}/interface/ocr-interface.h ${output_path}
#cp ${work_path}/version/release.v ${output_path}
#
#echo -e "\033[1;36;40mBuild All Target Success At:\n${output_path}\033[0m"
#exit 0

@ -0,0 +1 @@
cmake_policy(SET CMP0077 NEW)

@ -0,0 +1,116 @@
include(FetchContent)
set(EXTERNAL_PROJECT_LOG_ARGS
LOG_DOWNLOAD 1 # Wrap download in script to log output
LOG_UPDATE 1 # Wrap update in script to log output
LOG_PATCH 1
LOG_CONFIGURE 1# Wrap configure in script to log output
LOG_BUILD 1 # Wrap build in script to log output
LOG_INSTALL 1
LOG_TEST 1 # Wrap test in script to log output
LOG_MERGED_STDOUTERR 1
LOG_OUTPUT_ON_FAILURE 1
)
if(NOT FASTDEPLOY_INSTALL_DIR)
if(ANDROID)
FetchContent_Declare(
fastdeploy
URL https://bj.bcebos.com/fastdeploy/release/android/fastdeploy-android-1.0.4-shared.tgz
URL_HASH MD5=2a15301158e9eb157a4f11283689e7ba
${EXTERNAL_PROJECT_LOG_ARGS}
)
add_definitions("-DUSE_PADDLE_LITE_BAKEND")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -g0 -O3 -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE")
else() # Linux
FetchContent_Declare(
fastdeploy
URL https://paddlespeech.bj.bcebos.com/speechx/fastdeploy/fastdeploy-1.0.5-x86_64-onnx.tar.gz
URL_HASH MD5=33900d986ea71aa78635e52f0733227c
${EXTERNAL_PROJECT_LOG_ARGS}
)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -msse -msse2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -msse -msse2 -mavx -O3")
endif()
FetchContent_MakeAvailable(fastdeploy)
set(FASTDEPLOY_INSTALL_DIR ${fc_patch}/fastdeploy-src)
endif()
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# fix compiler flags conflict, since fastdeploy using c++11 for project
# this line must after `include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)`
set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD})
include_directories(${FASTDEPLOY_INCS})
# install fastdeploy and dependents lib
# install_fastdeploy_libraries(${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
# No dynamic libs need to install while using
# FastDeploy static lib.
if(ANDROID AND WITH_ANDROID_STATIC_LIB)
return()
endif()
set(DYN_LIB_SUFFIX "*.so*")
if(WIN32)
set(DYN_LIB_SUFFIX "*.dll")
elseif(APPLE)
set(DYN_LIB_SUFFIX "*.dylib*")
endif()
if(FastDeploy_DIR)
set(DYN_SEARCH_DIR ${FastDeploy_DIR})
elseif(FASTDEPLOY_INSTALL_DIR)
set(DYN_SEARCH_DIR ${FASTDEPLOY_INSTALL_DIR})
else()
message(FATAL_ERROR "Please set FastDeploy_DIR/FASTDEPLOY_INSTALL_DIR before call install_fastdeploy_libraries.")
endif()
file(GLOB_RECURSE ALL_NEED_DYN_LIBS ${DYN_SEARCH_DIR}/lib/${DYN_LIB_SUFFIX})
file(GLOB_RECURSE ALL_DEPS_DYN_LIBS ${DYN_SEARCH_DIR}/third_libs/${DYN_LIB_SUFFIX})
if(ENABLE_VISION)
# OpenCV
if(ANDROID)
file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${DYN_LIB_SUFFIX})
else()
file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_DIR}/../../${DYN_LIB_SUFFIX})
endif()
list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_OPENCV_DYN_LIBS})
if(WIN32)
file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/x64/vc15/bin/${DYN_LIB_SUFFIX})
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
elseif(ANDROID AND (NOT WITH_ANDROID_OPENCV_STATIC))
file(GLOB OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${ANDROID_ABI}/${DYN_LIB_SUFFIX})
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
else() # linux/mac
file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/lib/${DYN_LIB_SUFFIX})
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
endif()
# FlyCV
if(ENABLE_FLYCV)
file(GLOB_RECURSE ALL_FLYCV_DYN_LIBS ${FLYCV_LIB_DIR}/${DYN_LIB_SUFFIX})
list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_FLYCV_DYN_LIBS})
if(ANDROID AND (NOT WITH_ANDROID_FLYCV_STATIC))
install(FILES ${ALL_FLYCV_DYN_LIBS} DESTINATION lib)
endif()
endif()
endif()
if(ENABLE_OPENVINO_BACKEND)
# need plugins.xml for openvino backend
set(OPENVINO_RUNTIME_BIN_DIR ${OPENVINO_DIR}/bin)
file(GLOB OPENVINO_PLUGIN_XML ${OPENVINO_RUNTIME_BIN_DIR}/*.xml)
install(FILES ${OPENVINO_PLUGIN_XML} DESTINATION lib)
endif()
# Install other libraries
install(FILES ${ALL_NEED_DYN_LIBS} DESTINATION lib)
install(FILES ${ALL_DEPS_DYN_LIBS} DESTINATION lib)

@ -0,0 +1,14 @@
include(FetchContent)
FetchContent_Declare(
gflags
URL https://paddleaudio.bj.bcebos.com/build/gflag-2.2.2.zip
URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5
)
FetchContent_MakeAvailable(gflags)
# openfst need
include_directories(${gflags_BINARY_DIR}/include)
link_directories(${gflags_BINARY_DIR})
#install(FILES ${gflags_BINARY_DIR}/libgflags_nothreads.a DESTINATION lib)

@ -0,0 +1,35 @@
include(FetchContent)
if(ANDROID)
else() # UNIX
add_definitions(-DWITH_GLOG)
FetchContent_Declare(
glog
URL https://paddleaudio.bj.bcebos.com/build/glog-0.4.0.zip
URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${GLOG_CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DWITH_GFLAGS=OFF
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS}
)
set(BUILD_TESTING OFF)
FetchContent_MakeAvailable(glog)
include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src)
endif()
if(ANDROID)
add_library(extern_glog INTERFACE)
add_dependencies(extern_glog gflags)
else() # UNIX
add_library(extern_glog ALIAS glog)
add_dependencies(glog gflags)
endif()

@ -0,0 +1,27 @@
include(FetchContent)
if(ANDROID)
else() # UNIX
FetchContent_Declare(
gtest
URL https://paddleaudio.bj.bcebos.com/build/gtest-release-1.11.0.zip
URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a
)
FetchContent_MakeAvailable(gtest)
include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src)
endif()
if(ANDROID)
add_library(extern_gtest INTERFACE)
else() # UNIX
add_dependencies(gtest gflags extern_glog)
add_library(extern_gtest ALIAS gtest)
endif()
if(WITH_TESTING)
enable_testing()
endif()

@ -1,8 +1,8 @@
include(FetchContent)
set(openfst_PREFIX_DIR ${fc_patch}/openfst)
set(openfst_SOURCE_DIR ${fc_patch}/openfst-src)
set(openfst_BINARY_DIR ${fc_patch}/openfst-build)
include(FetchContent)
# openfst Acknowledgments:
#Cyril Allauzen, Michael Riley, Johan Schalkwyk, Wojciech Skut and Mehryar Mohri,
#"OpenFst: A General and Efficient Weighted Finite-State Transducer Library",
@ -10,18 +10,33 @@ set(openfst_BINARY_DIR ${fc_patch}/openfst-build)
#Application of Automata, (CIAA 2007), volume 4783 of Lecture Notes in
#Computer Science, pages 11-23. Springer, 2007. http://www.openfst.org.
set(EXTERNAL_PROJECT_LOG_ARGS
LOG_DOWNLOAD 1 # Wrap download in script to log output
LOG_UPDATE 1 # Wrap update in script to log output
LOG_CONFIGURE 1# Wrap configure in script to log output
LOG_BUILD 1 # Wrap build in script to log output
LOG_TEST 1 # Wrap test in script to log output
LOG_INSTALL 1 # Wrap install in script to log output
)
ExternalProject_Add(openfst
URL https://paddleaudio.bj.bcebos.com/build/openfst_1.7.2.zip
URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${openfst_PREFIX_DIR}
SOURCE_DIR ${openfst_SOURCE_DIR}
BINARY_DIR ${openfst_BINARY_DIR}
BUILD_ALWAYS 0
CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR}
"CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}"
"LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}"
"LIBS=-lgflags_nothreads -lglog -lpthread"
"LIBS=-lgflags_nothreads -lglog -lpthread -fPIC"
COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
BUILD_COMMAND make -j 4
)
link_directories(${openfst_PREFIX_DIR}/lib)
include_directories(${openfst_PREFIX_DIR}/include)
message(STATUS "OpenFST inc dir: ${openfst_PREFIX_DIR}/include")
message(STATUS "OpenFST lib dir: ${openfst_PREFIX_DIR}/lib")

@ -0,0 +1,42 @@
#the pybind11 is from:https://github.com/pybind/pybind11
# Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>, All rights reserved.
SET(PYBIND_ZIP "v2.10.0.zip")
SET(LOCAL_PYBIND_ZIP ${FETCHCONTENT_BASE_DIR}/${PYBIND_ZIP})
SET(PYBIND_SRC ${FETCHCONTENT_BASE_DIR}/pybind11)
SET(DOWNLOAD_URL "https://paddleaudio.bj.bcebos.com/build/v2.10.0.zip")
SET(PYBIND_TIMEOUT 600 CACHE STRING "Timeout in seconds when downloading pybind.")
IF(NOT EXISTS ${LOCAL_PYBIND_ZIP})
FILE(DOWNLOAD ${DOWNLOAD_URL}
${LOCAL_PYBIND_ZIP}
TIMEOUT ${PYBIND_TIMEOUT}
STATUS ERR
SHOW_PROGRESS
)
IF(ERR EQUAL 0)
MESSAGE(STATUS "download pybind success")
ELSE()
MESSAGE(FATAL_ERROR "download pybind fail")
ENDIF()
ENDIF()
IF(NOT EXISTS ${PYBIND_SRC})
EXECUTE_PROCESS(
COMMAND ${CMAKE_COMMAND} -E tar xfz ${LOCAL_PYBIND_ZIP}
WORKING_DIRECTORY ${FETCHCONTENT_BASE_DIR}
RESULT_VARIABLE tar_result
)
file(RENAME ${FETCHCONTENT_BASE_DIR}/pybind11-2.10.0 ${PYBIND_SRC})
IF (tar_result MATCHES 0)
MESSAGE(STATUS "unzip pybind success")
ELSE()
MESSAGE(FATAL_ERROR "unzip pybind fail")
ENDIF()
ENDIF()
include_directories(${PYBIND_SRC}/include)

@ -0,0 +1,64 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
function(pps_summary)
message(STATUS "")
message(STATUS "*************PaddleSpeech Building Summary**********")
message(STATUS " PPS_VERSION : ${PPS_VERSION}")
message(STATUS " CMake version : ${CMAKE_VERSION}")
message(STATUS " CMake command : ${CMAKE_COMMAND}")
message(STATUS " UNIX : ${UNIX}")
message(STATUS " ANDROID : ${ANDROID}")
message(STATUS " System : ${CMAKE_SYSTEM_NAME}")
message(STATUS " C++ compiler : ${CMAKE_CXX_COMPILER}")
message(STATUS " C++ compiler version : ${CMAKE_CXX_COMPILER_VERSION}")
message(STATUS " CXX flags : ${CMAKE_CXX_FLAGS}")
message(STATUS " Build type : ${CMAKE_BUILD_TYPE}")
message(STATUS " BUILD_SHARED_LIBS : ${BUILD_SHARED_LIBS}")
get_directory_property(tmp DIRECTORY ${PROJECT_SOURCE_DIR} COMPILE_DEFINITIONS)
message(STATUS " Compile definitions : ${tmp}")
message(STATUS " CMAKE_PREFIX_PATH : ${CMAKE_PREFIX_PATH}")
message(STATUS " CMAKE_CURRENT_BINARY_DIR : ${CMAKE_CURRENT_BINARY_DIR}")
message(STATUS " CMAKE_INSTALL_PREFIX : ${CMAKE_INSTALL_PREFIX}")
message(STATUS " CMAKE_INSTALL_LIBDIR : ${CMAKE_INSTALL_LIBDIR}")
message(STATUS " CMAKE_MODULE_PATH : ${CMAKE_MODULE_PATH}")
message(STATUS " CMAKE_SYSTEM_NAME : ${CMAKE_SYSTEM_NAME}")
message(STATUS "")
message(STATUS " WITH_ASR : ${WITH_ASR}")
message(STATUS " WITH_CLS : ${WITH_CLS}")
message(STATUS " WITH_VAD : ${WITH_VAD}")
message(STATUS " WITH_GPU : ${WITH_GPU}")
message(STATUS " WITH_TESTING : ${WITH_TESTING}")
message(STATUS " WITH_PROFILING : ${WITH_PROFILING}")
message(STATUS " FASTDEPLOY_INSTALL_DIR : ${FASTDEPLOY_INSTALL_DIR}")
message(STATUS " FASTDEPLOY_INCS : ${FASTDEPLOY_INCS}")
message(STATUS " FASTDEPLOY_LIBS : ${FASTDEPLOY_LIBS}")
if(WITH_GPU)
message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}")
endif()
if(ANDROID)
message(STATUS " ANDROID_ABI : ${ANDROID_ABI}")
message(STATUS " ANDROID_PLATFORM : ${ANDROID_PLATFORM}")
message(STATUS " ANDROID_NDK : ${ANDROID_NDK}")
message(STATUS " ANDROID_NDK_VERSION : ${CMAKE_ANDROID_NDK_VERSION}")
endif()
if (WITH_ASR)
message(STATUS " Python executable : ${PYTHON_EXECUTABLE}")
message(STATUS " Python includes : ${PYTHON_INCLUDE_DIR}")
endif()
endfunction()
pps_summary()

@ -0,0 +1,106 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Detects the OS and sets appropriate variables.
# CMAKE_SYSTEM_NAME only give us a coarse-grained name of the OS CMake is
# building for, but the host processor name like centos is necessary
# in some scenes to distinguish system for customization.
#
# for instance, protobuf libs path is <install_dir>/lib64
# on CentOS, but <install_dir>/lib on other systems.
if(UNIX AND NOT APPLE)
# except apple from nix*Os family
set(LINUX TRUE)
endif()
if(WIN32)
set(HOST_SYSTEM "win32")
else()
if(APPLE)
set(HOST_SYSTEM "macosx")
exec_program(
sw_vers ARGS
-productVersion
OUTPUT_VARIABLE HOST_SYSTEM_VERSION)
string(REGEX MATCH "[0-9]+.[0-9]+" MACOS_VERSION "${HOST_SYSTEM_VERSION}")
if(NOT DEFINED $ENV{MACOSX_DEPLOYMENT_TARGET})
# Set cache variable - end user may change this during ccmake or cmake-gui configure.
set(CMAKE_OSX_DEPLOYMENT_TARGET
${MACOS_VERSION}
CACHE
STRING
"Minimum OS X version to target for deployment (at runtime); newer APIs weak linked. Set to empty string for default value."
)
endif()
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
else()
if(EXISTS "/etc/issue")
file(READ "/etc/issue" LINUX_ISSUE)
if(LINUX_ISSUE MATCHES "CentOS")
set(HOST_SYSTEM "centos")
elseif(LINUX_ISSUE MATCHES "Debian")
set(HOST_SYSTEM "debian")
elseif(LINUX_ISSUE MATCHES "Ubuntu")
set(HOST_SYSTEM "ubuntu")
elseif(LINUX_ISSUE MATCHES "Red Hat")
set(HOST_SYSTEM "redhat")
elseif(LINUX_ISSUE MATCHES "Fedora")
set(HOST_SYSTEM "fedora")
endif()
string(REGEX MATCH "(([0-9]+)\\.)+([0-9]+)" HOST_SYSTEM_VERSION
"${LINUX_ISSUE}")
endif()
if(EXISTS "/etc/redhat-release")
file(READ "/etc/redhat-release" LINUX_ISSUE)
if(LINUX_ISSUE MATCHES "CentOS")
set(HOST_SYSTEM "centos")
endif()
endif()
if(NOT HOST_SYSTEM)
set(HOST_SYSTEM ${CMAKE_SYSTEM_NAME})
endif()
endif()
endif()
# query number of logical cores
cmake_host_system_information(RESULT CPU_CORES QUERY NUMBER_OF_LOGICAL_CORES)
mark_as_advanced(HOST_SYSTEM CPU_CORES)
message(
STATUS
"Found Paddle host system: ${HOST_SYSTEM}, version: ${HOST_SYSTEM_VERSION}")
message(STATUS "Found Paddle host system's CPU: ${CPU_CORES} cores")
# external dependencies log output
set(EXTERNAL_PROJECT_LOG_ARGS
LOG_DOWNLOAD
0 # Wrap download in script to log output
LOG_UPDATE
1 # Wrap update in script to log output
LOG_CONFIGURE
1 # Wrap configure in script to log output
LOG_BUILD
0 # Wrap build in script to log output
LOG_TEST
1 # Wrap test in script to log output
LOG_INSTALL
0 # Wrap install in script to log output
)

@ -0,0 +1,22 @@
project(speechx LANGUAGES CXX)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/kaldi)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/common)
add_subdirectory(kaldi)
add_subdirectory(common)
if(WITH_ASR)
add_subdirectory(asr)
endif()
if(WITH_CLS)
add_subdirectory(audio_classification)
endif()
if(WITH_VAD)
add_subdirectory(vad)
endif()
add_subdirectory(codelab)

@ -0,0 +1,11 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(ASR LANGUAGES CXX)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/server)
add_subdirectory(decoder)
add_subdirectory(recognizer)
add_subdirectory(nnet)
add_subdirectory(server)

@ -0,0 +1,24 @@
set(srcs)
list(APPEND srcs
ctc_prefix_beam_search_decoder.cc
ctc_tlg_decoder.cc
)
add_library(decoder STATIC ${srcs})
target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder)
# test
set(TEST_BINS
ctc_prefix_beam_search_decoder_main
ctc_tlg_decoder_main
)
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl)
endforeach()

@ -22,51 +22,22 @@ namespace ppspeech {
struct CTCBeamSearchOptions {
// common
int blank;
// ds2
std::string dict_file;
std::string lm_path;
int beam_size;
BaseFloat alpha;
BaseFloat beta;
BaseFloat cutoff_prob;
int cutoff_top_n;
int num_proc_bsearch;
std::string word_symbol_table;
// u2
int first_beam_size;
int second_beam_size;
CTCBeamSearchOptions()
: blank(0),
dict_file("vocab.txt"),
lm_path(""),
beam_size(300),
alpha(1.9f),
beta(5.0),
cutoff_prob(0.99f),
cutoff_top_n(40),
num_proc_bsearch(10),
word_symbol_table("vocab.txt"),
first_beam_size(10),
second_beam_size(10) {}
void Register(kaldi::OptionsItf* opts) {
std::string module = "Ds2BeamSearchConfig: ";
opts->Register("dict", &dict_file, module + "vocab file path.");
opts->Register(
"lm-path", &lm_path, module + "ngram language model path.");
opts->Register("alpha", &alpha, module + "alpha");
opts->Register("beta", &beta, module + "beta");
opts->Register("beam-size",
&beam_size,
module + "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff top n");
opts->Register(
"num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch");
std::string module = "CTCBeamSearchOptions: ";
opts->Register("word_symbol_table", &word_symbol_table, module + "vocab file path.");
opts->Register("blank", &blank, "blank id, default is 0.");
module = "U2BeamSearchConfig: ";
opts->Register(
"first-beam-size", &first_beam_size, module + "first beam size.");
opts->Register("second-beam-size",

@ -17,13 +17,12 @@
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "absl/strings/str_join.h"
#include "base/common.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "utils/math.h"
#ifdef USE_PROFILING
#ifdef WITH_PROFILING
#include "paddle/fluid/platform/profiler.h"
using paddle::platform::RecordEvent;
using paddle::platform::TracerEventType;
@ -31,11 +30,10 @@ using paddle::platform::TracerEventType;
namespace ppspeech {
CTCPrefixBeamSearch::CTCPrefixBeamSearch(const std::string& vocab_path,
const CTCBeamSearchOptions& opts)
CTCPrefixBeamSearch::CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts) {
unit_table_ = std::shared_ptr<fst::SymbolTable>(
fst::SymbolTable::ReadText(vocab_path));
fst::SymbolTable::ReadText(opts.word_symbol_table));
CHECK(unit_table_ != nullptr);
Reset();
@ -66,7 +64,6 @@ void CTCPrefixBeamSearch::Reset() {
void CTCPrefixBeamSearch::InitDecoder() { Reset(); }
void CTCPrefixBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
double search_cost = 0.0;
@ -78,21 +75,21 @@ void CTCPrefixBeamSearch::AdvanceDecode(
bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
feat_nnet_cost += timer.Elapsed();
if (flag == false) {
VLOG(3) << "decoder advance decode exit." << frame_prob.size();
VLOG(2) << "decoder advance decode exit." << frame_prob.size();
break;
}
timer.Reset();
std::vector<std::vector<kaldi::BaseFloat>> likelihood;
likelihood.push_back(frame_prob);
likelihood.push_back(std::move(frame_prob));
AdvanceDecoding(likelihood);
search_cost += timer.Elapsed();
VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_;
VLOG(1) << "num_frame_decoded_: " << num_frame_decoded_;
}
VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
VLOG(2) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
<< " sec.";
VLOG(1) << "AdvanceDecode search cost: " << search_cost << " sec.";
VLOG(2) << "AdvanceDecode search cost: " << search_cost << " sec.";
}
static bool PrefixScoreCompare(
@ -105,7 +102,7 @@ static bool PrefixScoreCompare(
void CTCPrefixBeamSearch::AdvanceDecoding(
const std::vector<std::vector<kaldi::BaseFloat>>& logp) {
#ifdef USE_PROFILING
#ifdef WITH_PROFILING
RecordEvent event("CtcPrefixBeamSearch::AdvanceDecoding",
TracerEventType::UserDefined,
1);

@ -27,8 +27,7 @@ namespace ppspeech {
class ContextGraph;
class CTCPrefixBeamSearch : public DecoderBase {
public:
CTCPrefixBeamSearch(const std::string& vocab_path,
const CTCBeamSearchOptions& opts);
CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts);
~CTCPrefixBeamSearch() {}
SearchType Type() const { return SearchType::kPrefixBeamSearch; }
@ -45,7 +44,7 @@ class CTCPrefixBeamSearch : public DecoderBase {
void FinalizeSearch();
const std::shared_ptr<fst::SymbolTable> VocabTable() const {
const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const override {
return unit_table_;
}
@ -57,7 +56,6 @@ class CTCPrefixBeamSearch : public DecoderBase {
}
const std::vector<std::vector<int>>& Times() const { return times_; }
protected:
std::string GetBestPath() override;
std::vector<std::pair<double, std::string>> GetNBestPath() override;

@ -12,18 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/strings/str_split.h"
#include "base/common.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "frontend/audio/data_cache.h"
#include "frontend/data_cache.h"
#include "fst/symbol-table.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/nnet_producer.h"
#include "nnet/u2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(vocab_path, "", "vocab path");
DEFINE_string(word_symbol_table, "", "vocab path");
DEFINE_string(model_path, "", "paddle nnet model");
@ -40,7 +40,7 @@ using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test ds2 online decoder by feeding speech feature
// test u2 online decoder by feeding speech feature
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
@ -52,10 +52,10 @@ int main(int argc, char* argv[]) {
CHECK_NE(FLAGS_result_wspecifier, "");
CHECK_NE(FLAGS_feature_rspecifier, "");
CHECK_NE(FLAGS_vocab_path, "");
CHECK_NE(FLAGS_word_symbol_table, "");
CHECK_NE(FLAGS_model_path, "");
LOG(INFO) << "model path: " << FLAGS_model_path;
LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path;
LOG(INFO) << "Reading vocab table " << FLAGS_word_symbol_table;
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
@ -70,15 +70,18 @@ int main(int argc, char* argv[]) {
// decodeable
std::shared_ptr<ppspeech::DataCache> raw_data =
std::make_shared<ppspeech::DataCache>();
std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
std::make_shared<ppspeech::NnetProducer>(nnet, raw_data, 1.0);
std::shared_ptr<ppspeech::Decodable> decodable =
std::make_shared<ppspeech::Decodable>(nnet, raw_data);
std::make_shared<ppspeech::Decodable>(nnet_producer);
// decoder
ppspeech::CTCBeamSearchOptions opts;
opts.blank = 0;
opts.first_beam_size = 10;
opts.second_beam_size = 10;
ppspeech::CTCPrefixBeamSearch decoder(FLAGS_vocab_path, opts);
opts.word_symbol_table = FLAGS_word_symbol_table;
ppspeech::CTCPrefixBeamSearch decoder(opts);
int32 chunk_size = FLAGS_receptive_field_length +
@ -122,15 +125,14 @@ int main(int argc, char* argv[]) {
}
kaldi::Vector<kaldi::BaseFloat> feature_chunk(this_chunk_size *
std::vector<kaldi::BaseFloat> feature_chunk(this_chunk_size *
feat_dim);
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < this_chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> feat_row(feature, start);
kaldi::SubVector<kaldi::BaseFloat> feature_chunk_row(
feature_chunk.Data() + row_id * feat_dim, feat_dim);
feature_chunk_row.CopyFromVec(feat_row);
std::memcpy(feature_chunk.data() + row_id * feat_dim,
feat_row.Data(),
feat_dim * sizeof(kaldi::BaseFloat));
++start;
}

@ -13,12 +13,14 @@
// limitations under the License.
#include "decoder/ctc_tlg_decoder.h"
namespace ppspeech {
TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
TLGDecoder::TLGDecoder(TLGDecoderOptions opts) : opts_(opts) {
fst_ = opts.fst_ptr;
CHECK(fst_ != nullptr);
CHECK(!opts.word_symbol_table.empty());
word_symbol_table_.reset(
fst::SymbolTable::ReadText(opts.word_symbol_table));
@ -29,6 +31,11 @@ TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
void TLGDecoder::Reset() {
decoder_->InitDecoding();
hypotheses_.clear();
likelihood_.clear();
olabels_.clear();
times_.clear();
num_frame_decoded_ = 0;
return;
}
@ -68,14 +75,52 @@ std::string TLGDecoder::GetPartialResult() {
return words;
}
void TLGDecoder::FinalizeSearch() {
decoder_->FinalizeDecoding();
kaldi::CompactLattice clat;
decoder_->GetLattice(&clat, true);
kaldi::Lattice lat, nbest_lat;
fst::ConvertLattice(clat, &lat);
fst::ShortestPath(lat, &nbest_lat, opts_.nbest);
std::vector<kaldi::Lattice> nbest_lats;
fst::ConvertNbestToVector(nbest_lat, &nbest_lats);
hypotheses_.clear();
hypotheses_.reserve(nbest_lats.size());
likelihood_.clear();
likelihood_.reserve(nbest_lats.size());
times_.clear();
times_.reserve(nbest_lats.size());
for (auto lat : nbest_lats) {
kaldi::LatticeWeight weight;
std::vector<int> hypothese;
std::vector<int> time;
std::vector<int> alignment;
std::vector<int> words_id;
fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);
int idx = 0;
for (; idx < alignment.size() - 1; ++idx) {
if (alignment[idx] == 0) continue;
if (alignment[idx] != alignment[idx + 1]) {
hypothese.push_back(alignment[idx] - 1);
time.push_back(idx); // fake time, todo later
}
}
hypothese.push_back(alignment[idx] - 1);
time.push_back(idx); // fake time, todo later
hypotheses_.push_back(hypothese);
times_.push_back(time);
olabels_.push_back(words_id);
likelihood_.push_back(-(weight.Value2() + weight.Value1()));
}
}
std::string TLGDecoder::GetFinalBestPath() {
if (num_frame_decoded_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.")
return std::string("");
}
decoder_->FinalizeDecoding();
kaldi::Lattice lat;
kaldi::LatticeWeight weight;
std::vector<int> alignment;

@ -18,13 +18,14 @@
#include "decoder/decoder_itf.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h"
#include "utils/file_utils.h"
DECLARE_string(graph_path);
DECLARE_string(word_symbol_table);
DECLARE_string(graph_path);
DECLARE_int32(max_active);
DECLARE_double(beam);
DECLARE_double(lattice_beam);
DECLARE_int32(nbest);
namespace ppspeech {
@ -33,17 +34,27 @@ struct TLGDecoderOptions {
// todo remove later, add into decode resource
std::string word_symbol_table;
std::string fst_path;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_ptr;
int nbest;
TLGDecoderOptions() : word_symbol_table(""), fst_path(""), fst_ptr(nullptr), nbest(10) {}
static TLGDecoderOptions InitFromFlags() {
TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
LOG(INFO) << "fst path: " << decoder_opts.fst_path;
LOG(INFO) << "fst symbole table: " << decoder_opts.word_symbol_table;
LOG(INFO) << "symbole table: " << decoder_opts.word_symbol_table;
if (!decoder_opts.fst_path.empty()) {
CHECK(FileExists(decoder_opts.fst_path));
decoder_opts.fst_ptr.reset(fst::Fst<fst::StdArc>::Read(FLAGS_graph_path));
}
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
decoder_opts.nbest = FLAGS_nbest;
LOG(INFO) << "LatticeFasterDecoder max active: "
<< decoder_opts.opts.max_active;
LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam;
@ -59,20 +70,38 @@ class TLGDecoder : public DecoderBase {
explicit TLGDecoder(TLGDecoderOptions opts);
~TLGDecoder() = default;
void InitDecoder();
void Reset();
void InitDecoder() override;
void Reset() override;
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
const std::shared_ptr<kaldi::DecodableInterface>& decodable) override;
void Decode();
std::string GetFinalBestPath() override;
std::string GetPartialResult() override;
const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const override {
return word_symbol_table_;
}
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
const std::vector<std::string>& nbest_words);
void FinalizeSearch() override;
const std::vector<std::vector<int>>& Inputs() const override {
return hypotheses_;
}
const std::vector<std::vector<int>>& Outputs() const override {
return olabels_;
} // outputs_; }
const std::vector<float>& Likelihood() const override {
return likelihood_;
}
const std::vector<std::vector<int>>& Times() const override {
return times_;
}
protected:
std::string GetBestPath() override {
CHECK(false);
@ -90,9 +119,16 @@ class TLGDecoder : public DecoderBase {
private:
void AdvanceDecoding(kaldi::DecodableInterface* decodable);
int num_frame_decoded_;
std::vector<std::vector<int>> hypotheses_;
std::vector<std::vector<int>> olabels_;
std::vector<float> likelihood_;
std::vector<std::vector<int>> times_;
std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_;
TLGDecoderOptions opts_;
};

@ -14,21 +14,24 @@
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "base/common.h"
#include "decoder/ctc_tlg_decoder.h"
#include "decoder/param.h"
#include "frontend/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/nnet_producer.h"
DEFINE_string(nnet_prob_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test decoder by feeding nnet posterior probability
// test TLG decoder by feeding speech feature.
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
@ -36,41 +39,51 @@ int main(int argc, char* argv[]) {
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialBaseFloatMatrixReader likelihood_reader(
FLAGS_nnet_prob_respecifier);
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
kaldi::SequentialBaseFloatMatrixReader nnet_prob_reader(
FLAGS_nnet_prob_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts;
opts.dict_file = dict_file;
opts.lm_path = lm_path;
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::TLGDecoderOptions opts =
ppspeech::TLGDecoderOptions::InitFromFlags();
opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5;
ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
std::make_shared<ppspeech::NnetProducer>(nullptr, nullptr, 1.0);
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nullptr, nullptr));
new ppspeech::Decodable(nnet_producer, FLAGS_acoustic_scale));
decoder.InitDecoder();
kaldi::Timer timer;
for (; !likelihood_reader.Done(); likelihood_reader.Next()) {
string utt = likelihood_reader.Key();
const kaldi::Matrix<BaseFloat> likelihood = likelihood_reader.Value();
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << likelihood.NumRows();
LOG(INFO) << "cols: " << likelihood.NumCols();
decodable->Acceptlikelihood(likelihood);
for (; !nnet_prob_reader.Done(); nnet_prob_reader.Next()) {
string utt = nnet_prob_reader.Key();
kaldi::Matrix<BaseFloat> prob = nnet_prob_reader.Value();
decodable->Acceptlikelihood(prob);
decoder.AdvanceDecode(decodable);
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
decoder.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}
double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s";
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);

@ -1,4 +1,3 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
@ -16,6 +15,7 @@
#pragma once
#include "base/common.h"
#include "fst/symbol-table.h"
#include "kaldi/decoder/decodable-itf.h"
namespace ppspeech {
@ -41,6 +41,14 @@ class DecoderInterface {
virtual std::string GetPartialResult() = 0;
virtual const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const = 0;
virtual void FinalizeSearch() = 0;
virtual const std::vector<std::vector<int>>& Inputs() const = 0;
virtual const std::vector<std::vector<int>>& Outputs() const = 0;
virtual const std::vector<float>& Likelihood() const = 0;
virtual const std::vector<std::vector<int>>& Times() const = 0;
protected:
// virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0;

@ -15,8 +15,6 @@
#pragma once
#include "base/common.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
// feature
DEFINE_bool(use_fbank, false, "False for fbank; or linear feature");
@ -37,36 +35,22 @@ DEFINE_int32(subsampling_rate,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
// nnet
DEFINE_string(vocab_path, "", "nnet vocab path.");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box",
"model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
#ifdef USE_ONNX
DEFINE_bool(with_onnx_model, false, "True mean the model path is onnx model path");
#endif
// decoder
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "", "decoder graph");
DEFINE_string(word_symbol_table, "", "word symbol table");
DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam");
DEFINE_double(blank_threshold, 0.98, "blank skip threshold");
// DecodeOptions flags
// DEFINE_int32(chunk_size, -1, "decoding chunk size");
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
DEFINE_double(ctc_weight,
0.5,

@ -0,0 +1,21 @@
set(srcs decodable.cc nnet_producer.cc)
list(APPEND srcs u2_nnet.cc)
if(WITH_ONNX)
list(APPEND srcs u2_onnx_nnet.cc)
endif()
add_library(nnet STATIC ${srcs})
target_link_libraries(nnet utils)
if(WITH_ONNX)
target_link_libraries(nnet ${FASTDEPLOY_LIBS})
endif()
target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
# test bin
#set(bin_name u2_nnet_main)
#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})

@ -21,29 +21,25 @@ using kaldi::Matrix;
using kaldi::Vector;
using std::vector;
Decodable::Decodable(const std::shared_ptr<NnetBase>& nnet,
const std::shared_ptr<FrontendInterface>& frontend,
Decodable::Decodable(const std::shared_ptr<NnetProducer>& nnet_producer,
kaldi::BaseFloat acoustic_scale)
: frontend_(frontend),
nnet_(nnet),
: nnet_producer_(nnet_producer),
frame_offset_(0),
frames_ready_(0),
acoustic_scale_(acoustic_scale) {}
// for debug
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
nnet_out_cache_ = likelihood;
frames_ready_ += likelihood.NumRows();
nnet_producer_->Acceptlikelihood(likelihood);
}
// return the size of frame have computed.
int32 Decodable::NumFramesReady() const { return frames_ready_; }
// frame idx is from 0 to frame_ready_ -1;
bool Decodable::IsLastFrame(int32 frame) {
bool flag = EnsureFrameHaveComputed(frame);
EnsureFrameHaveComputed(frame);
return frame >= frames_ready_;
}
@ -64,32 +60,10 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) {
bool Decodable::AdvanceChunk() {
kaldi::Timer timer;
// read feats
Vector<BaseFloat> features;
if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat or frontend_ not init.
VLOG(3) << "decodable exit;";
return false;
}
CHECK_GE(frontend_->Dim(), 0);
VLOG(1) << "AdvanceChunk feat cost: " << timer.Elapsed() << " sec.";
VLOG(2) << "Forward in " << features.Dim() / frontend_->Dim() << " feats.";
// forward feats
NnetOut out;
nnet_->FeedForward(features, frontend_->Dim(), &out);
int32& vocab_dim = out.vocab_dim;
Vector<BaseFloat>& logprobs = out.logprobs;
VLOG(2) << "Forward out " << logprobs.Dim() / vocab_dim
<< " decoder frames.";
// cache nnet outupts
nnet_out_cache_.Resize(logprobs.Dim() / vocab_dim, vocab_dim);
nnet_out_cache_.CopyRowsFromVec(logprobs);
// update state, decoding frame.
bool flag = nnet_producer_->Read(&framelikelihood_);
if (flag == false) return false;
frame_offset_ = frames_ready_;
frames_ready_ += nnet_out_cache_.NumRows();
frames_ready_ += 1;
VLOG(1) << "AdvanceChunk feat + forward cost: " << timer.Elapsed()
<< " sec.";
return true;
@ -101,17 +75,17 @@ bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
return false;
}
int nrows = nnet_out_cache_.NumRows();
CHECK(nrows == (frames_ready_ - frame_offset_));
if (nrows <= 0) {
if (framelikelihood_.empty()) {
LOG(WARNING) << "No new nnet out in cache.";
return false;
}
logprobs->Resize(nnet_out_cache_.NumRows() * nnet_out_cache_.NumCols());
logprobs->CopyRowsFromMat(nnet_out_cache_);
*vocab_dim = nnet_out_cache_.NumCols();
size_t dim = framelikelihood_.size();
logprobs->Resize(framelikelihood_.size());
std::memcpy(logprobs->Data(),
framelikelihood_.data(),
dim * sizeof(kaldi::BaseFloat));
*vocab_dim = framelikelihood_.size();
return true;
}
@ -122,19 +96,8 @@ bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
return false;
}
int nrows = nnet_out_cache_.NumRows();
CHECK(nrows == (frames_ready_ - frame_offset_));
int vocab_size = nnet_out_cache_.NumCols();
likelihood->resize(vocab_size);
for (int32 idx = 0; idx < vocab_size; ++idx) {
(*likelihood)[idx] =
nnet_out_cache_(frame - frame_offset_, idx) * acoustic_scale_;
VLOG(4) << "nnet out: " << frame << " offset:" << frame_offset_ << " "
<< nnet_out_cache_.NumRows()
<< " logprob: " << nnet_out_cache_(frame - frame_offset_, idx);
}
CHECK_EQ(1, (frames_ready_ - frame_offset_));
*likelihood = framelikelihood_;
return true;
}
@ -143,36 +106,30 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
return false;
}
CHECK_LE(index, nnet_out_cache_.NumCols());
CHECK_LE(index, framelikelihood_.size());
CHECK_LE(frame, frames_ready_);
// the nnet output is prob ranther than log prob
// the index - 1, because the ilabel
BaseFloat logprob = 0.0;
int32 frame_idx = frame - frame_offset_;
BaseFloat nnet_out = nnet_out_cache_(frame_idx, TokenId2NnetId(index));
if (nnet_->IsLogProb()) {
logprob = nnet_out;
} else {
logprob = std::log(nnet_out + std::numeric_limits<float>::epsilon());
}
CHECK(!std::isnan(logprob) && !std::isinf(logprob));
CHECK_EQ(frame_idx, 0);
logprob = framelikelihood_[TokenId2NnetId(index)];
return acoustic_scale_ * logprob;
}
void Decodable::Reset() {
if (frontend_ != nullptr) frontend_->Reset();
if (nnet_ != nullptr) nnet_->Reset();
if (nnet_producer_ != nullptr) nnet_producer_->Reset();
frame_offset_ = 0;
frames_ready_ = 0;
nnet_out_cache_.Resize(0, 0);
framelikelihood_.clear();
}
void Decodable::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
kaldi::Timer timer;
nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
nnet_producer_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
VLOG(1) << "Attention Rescoring cost: " << timer.Elapsed() << " sec.";
}

@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/audio/frontend_itf.h"
#include "kaldi/decoder/decodable-itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h"
#include "nnet/nnet_producer.h"
namespace ppspeech {
@ -24,12 +26,9 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface {
public:
explicit Decodable(const std::shared_ptr<NnetBase>& nnet,
const std::shared_ptr<FrontendInterface>& frontend,
explicit Decodable(const std::shared_ptr<NnetProducer>& nnet_producer,
kaldi::BaseFloat acoustic_scale = 1.0);
// void Init(DecodableOpts config);
// nnet logprob output, used by wfst
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
@ -57,23 +56,17 @@ class Decodable : public kaldi::DecodableInterface {
void Reset();
bool IsInputFinished() const { return frontend_->IsFinished(); }
bool IsInputFinished() const { return nnet_producer_->IsFinished(); }
bool EnsureFrameHaveComputed(int32 frame);
int32 TokenId2NnetId(int32 token_id);
std::shared_ptr<NnetBase> Nnet() { return nnet_; }
// for offline test
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
private:
std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetBase> nnet_;
// nnet outputs' cache
kaldi::Matrix<kaldi::BaseFloat> nnet_out_cache_;
std::shared_ptr<NnetProducer> nnet_producer_;
// the frame is nnet prob frame rather than audio feature frame
// nnet frame subsample the feature frame
@ -85,6 +78,7 @@ class Decodable : public kaldi::DecodableInterface {
// so use subsampled_frame
int32 current_log_post_subsampled_offset_;
int32 num_chunk_computed_;
std::vector<kaldi::BaseFloat> framelikelihood_;
kaldi::BaseFloat acoustic_scale_;
};

@ -15,7 +15,6 @@
#include "base/basic_types.h"
#include "kaldi/base/kaldi-types.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
DECLARE_int32(subsampling_rate);
@ -25,26 +24,20 @@ DECLARE_string(model_input_names);
DECLARE_string(model_output_names);
DECLARE_string(model_cache_names);
DECLARE_string(model_cache_shapes);
#ifdef USE_ONNX
DECLARE_bool(with_onnx_model);
#endif
namespace ppspeech {
struct ModelOptions {
// common
int subsample_rate{1};
int thread_num{1}; // predictor thread pool size for ds2;
bool use_gpu{false};
std::string model_path;
std::string param_path;
// ds2 for inference
std::string input_names{};
std::string output_names{};
std::string cache_names{};
std::string cache_shape{};
bool switch_ir_optim{false};
bool enable_fc_padding{false};
bool enable_profile{false};
#ifdef USE_ONNX
bool with_onnx_model{false};
#endif
static ModelOptions InitFromFlags() {
ModelOptions opts;
@ -52,26 +45,17 @@ struct ModelOptions {
LOG(INFO) << "subsampling rate: " << opts.subsample_rate;
opts.model_path = FLAGS_model_path;
LOG(INFO) << "model path: " << opts.model_path;
opts.param_path = FLAGS_param_path;
LOG(INFO) << "param path: " << opts.param_path;
LOG(INFO) << "DS2 param: ";
opts.cache_names = FLAGS_model_cache_names;
LOG(INFO) << " cache names: " << opts.cache_names;
opts.cache_shape = FLAGS_model_cache_shapes;
LOG(INFO) << " cache shape: " << opts.cache_shape;
opts.input_names = FLAGS_model_input_names;
LOG(INFO) << " input names: " << opts.input_names;
opts.output_names = FLAGS_model_output_names;
LOG(INFO) << " output names: " << opts.output_names;
#ifdef USE_ONNX
opts.with_onnx_model = FLAGS_with_onnx_model;
LOG(INFO) << "with onnx model: " << opts.with_onnx_model;
#endif
return opts;
}
};
struct NnetOut {
// nnet out. maybe logprob or prob. Almost time this is logprob.
kaldi::Vector<kaldi::BaseFloat> logprobs;
std::vector<kaldi::BaseFloat> logprobs;
int32 vocab_dim;
// nnet state. Only using in Attention model.
@ -89,7 +73,7 @@ class NnetInterface {
// nnet do not cache feats, feats cached by frontend.
// nnet cache model state, i.e. encoder_outs, att_cache, cnn_cache,
// frame_offset.
virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
virtual void FeedForward(const std::vector<kaldi::BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) = 0;
@ -105,14 +89,14 @@ class NnetInterface {
// using to get encoder outs. e.g. seq2seq with Attention model.
virtual void EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const = 0;
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const = 0;
};
class NnetBase : public NnetInterface {
public:
int SubsamplingRate() const { return subsampling_rate_; }
virtual std::shared_ptr<NnetBase> Clone() const = 0;
protected:
int subsampling_rate_{1};
};

@ -0,0 +1,99 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/nnet_producer.h"
#include "matrix/kaldi-matrix.h"
namespace ppspeech {
using kaldi::BaseFloat;
using std::vector;
NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend,
float blank_threshold)
: nnet_(nnet), frontend_(frontend), blank_threshold_(blank_threshold) {
Reset();
}
void NnetProducer::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
frontend_->Accept(inputs);
}
void NnetProducer::Acceptlikelihood(
const kaldi::Matrix<BaseFloat>& likelihood) {
std::vector<BaseFloat> prob;
prob.resize(likelihood.NumCols());
for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) {
for (size_t col = 0; col < likelihood.NumCols(); ++col) {
prob[col] = likelihood(idx, col);
}
cache_.push_back(prob);
}
}
bool NnetProducer::Read(std::vector<kaldi::BaseFloat>* nnet_prob) {
bool flag = cache_.pop(nnet_prob);
return flag;
}
bool NnetProducer::Compute() {
vector<BaseFloat> features;
if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat or frontend_ not init.
if (frontend_->IsFinished() == true) {
finished_ = true;
}
return false;
}
CHECK_GE(frontend_->Dim(), 0);
VLOG(1) << "Forward in " << features.size() / frontend_->Dim() << " feats.";
NnetOut out;
nnet_->FeedForward(features, frontend_->Dim(), &out);
int32& vocab_dim = out.vocab_dim;
size_t nframes = out.logprobs.size() / vocab_dim;
VLOG(1) << "Forward out " << nframes << " decoder frames.";
for (size_t idx = 0; idx < nframes; ++idx) {
std::vector<BaseFloat> logprob(
out.logprobs.data() + idx * vocab_dim,
out.logprobs.data() + (idx + 1) * vocab_dim);
// process blank prob
float blank_prob = std::exp(logprob[0]);
if (blank_prob > blank_threshold_) {
last_frame_logprob_ = logprob;
is_last_frame_skip_ = true;
continue;
} else {
int cur_max = std::max(logprob.begin(), logprob.end()) - logprob.begin();
if (cur_max == last_max_elem_ && cur_max != 0 && is_last_frame_skip_) {
cache_.push_back(last_frame_logprob_);
last_max_elem_ = cur_max;
}
last_max_elem_ = cur_max;
is_last_frame_skip_ = false;
cache_.push_back(logprob);
}
}
return true;
}
void NnetProducer::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
}
} // namespace ppspeech

@ -0,0 +1,77 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "base/safe_queue.h"
#include "frontend/frontend_itf.h"
#include "nnet/nnet_itf.h"
namespace ppspeech {
class NnetProducer {
public:
explicit NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend,
float blank_threshold);
// Feed feats or waves
void Accept(const std::vector<kaldi::BaseFloat>& inputs);
void Acceptlikelihood(const kaldi::Matrix<BaseFloat>& likelihood);
// nnet
bool Read(std::vector<kaldi::BaseFloat>* nnet_prob);
bool Empty() const { return cache_.empty(); }
void SetInputFinished() {
LOG(INFO) << "set finished";
frontend_->SetFinished();
}
// the compute thread exit
bool IsFinished() const {
return (frontend_->IsFinished() && finished_);
}
~NnetProducer() {}
void Reset() {
if (frontend_ != NULL) frontend_->Reset();
if (nnet_ != NULL) nnet_->Reset();
cache_.clear();
finished_ = false;
}
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score);
bool Compute();
private:
std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetBase> nnet_;
SafeQueue<std::vector<kaldi::BaseFloat>> cache_;
std::vector<BaseFloat> last_frame_logprob_;
bool is_last_frame_skip_ = false;
int last_max_elem_ = -1;
float blank_threshold_ = 0.0;
bool finished_;
DISALLOW_COPY_AND_ASSIGN(NnetProducer);
};
} // namespace ppspeech

@ -17,12 +17,13 @@
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/asr_model.cc
#include "nnet/u2_nnet.h"
#include <type_traits>
#ifdef USE_PROFILING
#ifdef WITH_PROFILING
#include "paddle/fluid/platform/profiler.h"
using paddle::platform::RecordEvent;
using paddle::platform::TracerEventType;
#endif // end USE_PROFILING
#endif // end WITH_PROFILING
namespace ppspeech {
@ -30,7 +31,7 @@ namespace ppspeech {
void U2Nnet::LoadModel(const std::string& model_path_w_prefix) {
paddle::jit::utils::InitKernelSignatureMap();
#ifdef USE_GPU
#ifdef WITH_GPU
dev_ = phi::GPUPlace();
#else
dev_ = phi::CPUPlace();
@ -62,12 +63,12 @@ void U2Nnet::LoadModel(const std::string& model_path_w_prefix) {
}
void U2Nnet::Warmup() {
#ifdef USE_PROFILING
#ifdef WITH_PROFILING
RecordEvent event("warmup", TracerEventType::UserDefined, 1);
#endif
{
#ifdef USE_PROFILING
#ifdef WITH_PROFILING
RecordEvent event(
"warmup-encoder-ctc", TracerEventType::UserDefined, 1);
#endif
@ -91,7 +92,7 @@ void U2Nnet::Warmup() {
}
{
#ifdef USE_PROFILING
#ifdef WITH_PROFILING
RecordEvent event("warmup-decoder", TracerEventType::UserDefined, 1);
#endif
auto hyps =
@ -101,10 +102,10 @@ void U2Nnet::Warmup() {
auto encoder_out = paddle::ones(
{1, 20, 512}, paddle::DataType::FLOAT32, phi::CPUPlace());
std::vector<paddle::experimental::Tensor> inputs{
std::vector<paddle::Tensor> inputs{
hyps, hyps_lens, encoder_out};
std::vector<paddle::experimental::Tensor> outputs =
std::vector<paddle::Tensor> outputs =
forward_attention_decoder_(inputs);
}
@ -118,27 +119,46 @@ U2Nnet::U2Nnet(const ModelOptions& opts) : opts_(opts) {
// shallow copy
U2Nnet::U2Nnet(const U2Nnet& other) {
// copy meta
right_context_ = other.right_context_;
subsampling_rate_ = other.subsampling_rate_;
sos_ = other.sos_;
eos_ = other.eos_;
is_bidecoder_ = other.is_bidecoder_;
chunk_size_ = other.chunk_size_;
num_left_chunks_ = other.num_left_chunks_;
forward_encoder_chunk_ = other.forward_encoder_chunk_;
forward_attention_decoder_ = other.forward_attention_decoder_;
ctc_activation_ = other.ctc_activation_;
offset_ = other.offset_;
// copy model ptr
model_ = other.model_;
// model_ = other.model_->Clone();
// hack, fix later
#ifdef WITH_GPU
dev_ = phi::GPUPlace();
#else
dev_ = phi::CPUPlace();
#endif
paddle::jit::Layer model = paddle::jit::Load(other.opts_.model_path, dev_);
model_ = std::make_shared<paddle::jit::Layer>(std::move(model));
ctc_activation_ = model_->Function("ctc_activation");
subsampling_rate_ = model_->Attribute<int>("subsampling_rate");
right_context_ = model_->Attribute<int>("right_context");
sos_ = model_->Attribute<int>("sos_symbol");
eos_ = model_->Attribute<int>("eos_symbol");
is_bidecoder_ = model_->Attribute<int>("is_bidirectional_decoder");
forward_encoder_chunk_ = model_->Function("forward_encoder_chunk");
forward_attention_decoder_ = model_->Function("forward_attention_decoder");
ctc_activation_ = model_->Function("ctc_activation");
CHECK(forward_encoder_chunk_.IsValid());
CHECK(forward_attention_decoder_.IsValid());
CHECK(ctc_activation_.IsValid());
LOG(INFO) << "Paddle Model Info: ";
LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_;
LOG(INFO) << "\tright context " << right_context_;
LOG(INFO) << "\tsos " << sos_;
LOG(INFO) << "\teos " << eos_;
LOG(INFO) << "\tis bidecoder " << is_bidecoder_ << std::endl;
// ignore inner states
}
std::shared_ptr<NnetBase> U2Nnet::Copy() const {
std::shared_ptr<NnetBase> U2Nnet::Clone() const {
auto asr_model = std::make_shared<U2Nnet>(*this);
// reset inner state for new decoding
asr_model->Reset();
@ -154,6 +174,7 @@ void U2Nnet::Reset() {
std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32));
encoder_outs_.clear();
VLOG(1) << "FeedForward cost: " << cost_time_ << " sec. ";
VLOG(3) << "u2nnet reset";
}
@ -165,23 +186,18 @@ void U2Nnet::FeedEncoderOuts(const paddle::Tensor& encoder_out) {
}
void U2Nnet::FeedForward(const kaldi::Vector<BaseFloat>& features,
void U2Nnet::FeedForward(const std::vector<BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) {
kaldi::Timer timer;
std::vector<kaldi::BaseFloat> chunk_feats(features.Data(),
features.Data() + features.Dim());
std::vector<kaldi::BaseFloat> ctc_probs;
ForwardEncoderChunkImpl(
chunk_feats, feature_dim, &ctc_probs, &out->vocab_dim);
out->logprobs.Resize(ctc_probs.size(), kaldi::kSetZero);
std::memcpy(out->logprobs.Data(),
ctc_probs.data(),
ctc_probs.size() * sizeof(kaldi::BaseFloat));
VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. "
<< chunk_feats.size() / feature_dim << " frames.";
features, feature_dim, &out->logprobs, &out->vocab_dim);
float forward_chunk_time = timer.Elapsed();
VLOG(1) << "FeedForward cost: " << forward_chunk_time << " sec. "
<< features.size() / feature_dim << " frames.";
cost_time_ += forward_chunk_time;
}
@ -190,7 +206,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
const int32& feat_dim,
std::vector<kaldi::BaseFloat>* out_prob,
int32* vocab_dim) {
#ifdef USE_PROFILING
#ifdef WITH_PROFILING
RecordEvent event(
"ForwardEncoderChunkImpl", TracerEventType::UserDefined, 1);
#endif
@ -210,7 +226,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
// not cache feature in nnet
CHECK_EQ(cached_feats_.size(), 0);
// CHECK_EQ(std::is_same<float, kaldi::BaseFloat>::value, true);
CHECK_EQ((std::is_same<float, kaldi::BaseFloat>::value), true);
std::memcpy(feats_ptr,
chunk_feats.data(),
chunk_feats.size() * sizeof(kaldi::BaseFloat));
@ -218,7 +234,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
VLOG(3) << "feats shape: " << feats.shape()[0] << ", " << feats.shape()[1]
<< ", " << feats.shape()[2];
#ifdef TEST_DEBUG
#ifdef PPS_DEBUG
{
std::stringstream path("feat", std::ios_base::app | std::ios_base::out);
path << offset_;
@ -237,7 +253,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
#endif
// Endocer chunk forward
#ifdef USE_GPU
#ifdef WITH_GPU
feats = feats.copy_to(paddle::GPUPlace(), /*blocking*/ false);
att_cache_ = att_cache_.copy_to(paddle::GPUPlace()), /*blocking*/ false;
cnn_cache_ = cnn_cache_.copy_to(Paddle::GPUPlace(), /*blocking*/ false);
@ -254,7 +270,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
std::vector<paddle::Tensor> outputs = forward_encoder_chunk_(inputs);
CHECK_EQ(outputs.size(), 3);
#ifdef USE_GPU
#ifdef WITH_GPU
paddle::Tensor chunk_out = outputs[0].copy_to(paddle::CPUPlace());
att_cache_ = outputs[1].copy_to(paddle::CPUPlace());
cnn_cache_ = outputs[2].copy_to(paddle::CPUPlace());
@ -264,7 +280,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
cnn_cache_ = outputs[2];
#endif
#ifdef TEST_DEBUG
#ifdef PPS_DEBUG
{
std::stringstream path("encoder_logits",
std::ios_base::app | std::ios_base::out);
@ -294,7 +310,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
encoder_outs_.push_back(chunk_out);
VLOG(2) << "encoder_outs_ size: " << encoder_outs_.size();
#ifdef TEST_DEBUG
#ifdef PPS_DEBUG
{
std::stringstream path("encoder_logits_list",
std::ios_base::app | std::ios_base::out);
@ -313,7 +329,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
}
#endif // end TEST_DEBUG
#ifdef USE_GPU
#ifdef WITH_GPU
#error "Not implementation."
@ -327,7 +343,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
CHECK_EQ(outputs.size(), 1);
paddle::Tensor ctc_log_probs = outputs[0];
#ifdef TEST_DEBUG
#ifdef PPS_DEBUG
{
std::stringstream path("encoder_logprob",
std::ios_base::app | std::ios_base::out);
@ -349,7 +365,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
}
#endif // end TEST_DEBUG
#endif // end USE_GPU
#endif // end WITH_GPU
// Copy to output, (B=1,T,D)
std::vector<int64_t> ctc_log_probs_shape = ctc_log_probs.shape();
@ -366,7 +382,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
std::memcpy(
out_prob->data(), ctc_log_probs_ptr, T * D * sizeof(kaldi::BaseFloat));
#ifdef TEST_DEBUG
#ifdef PPS_DEBUG
{
std::stringstream path("encoder_logits_list_ctc",
std::ios_base::app | std::ios_base::out);
@ -415,7 +431,7 @@ float U2Nnet::ComputePathScore(const paddle::Tensor& prob,
void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
#ifdef USE_PROFILING
#ifdef WITH_PROFILING
RecordEvent event("AttentionRescoring", TracerEventType::UserDefined, 1);
#endif
CHECK(rescoring_score != nullptr);
@ -457,7 +473,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
}
}
#ifdef TEST_DEBUG
#ifdef PPS_DEBUG
{
std::stringstream path("encoder_logits_concat",
std::ios_base::app | std::ios_base::out);
@ -481,7 +497,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
paddle::Tensor encoder_out = paddle::concat(encoder_outs_, 1);
VLOG(2) << "encoder_outs_ size: " << encoder_outs_.size();
#ifdef TEST_DEBUG
#ifdef PPS_DEBUG
{
std::stringstream path("encoder_out0",
std::ios_base::app | std::ios_base::out);
@ -500,7 +516,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
}
#endif // end TEST_DEBUG
#ifdef TEST_DEBUG
#ifdef PPS_DEBUG
{
std::stringstream path("encoder_out",
std::ios_base::app | std::ios_base::out);
@ -519,7 +535,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
}
#endif // end TEST_DEBUG
std::vector<paddle::experimental::Tensor> inputs{
std::vector<paddle::Tensor> inputs{
hyps_tensor, hyps_lens, encoder_out};
std::vector<paddle::Tensor> outputs = forward_attention_decoder_(inputs);
CHECK_EQ(outputs.size(), 2);
@ -531,7 +547,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
CHECK_EQ(probs_shape[0], num_hyps);
CHECK_EQ(probs_shape[1], max_hyps_len);
#ifdef TEST_DEBUG
#ifdef PPS_DEBUG
{
std::stringstream path("decoder_logprob",
std::ios_base::app | std::ios_base::out);
@ -549,7 +565,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
}
#endif // end TEST_DEBUG
#ifdef TEST_DEBUG
#ifdef PPS_DEBUG
{
std::stringstream path("hyps_lens",
std::ios_base::app | std::ios_base::out);
@ -565,7 +581,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
}
#endif // end TEST_DEBUG
#ifdef TEST_DEBUG
#ifdef PPS_DEBUG
{
std::stringstream path("hyps_tensor",
std::ios_base::app | std::ios_base::out);
@ -590,7 +606,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
} else {
// dump r_probs
CHECK_EQ(r_probs_shape.size(), 1);
CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0];
//CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0];
}
// compute rescoring score
@ -600,15 +616,15 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
VLOG(2) << "split prob: " << probs_v.size() << " "
<< probs_v[0].shape().size() << " 0: " << probs_v[0].shape()[0]
<< ", " << probs_v[0].shape()[1] << ", " << probs_v[0].shape()[2];
CHECK(static_cast<int>(probs_v.size()) == num_hyps)
<< ": is " << probs_v.size() << " expect: " << num_hyps;
//CHECK(static_cast<int>(probs_v.size()) == num_hyps)
// << ": is " << probs_v.size() << " expect: " << num_hyps;
std::vector<paddle::Tensor> r_probs_v;
if (is_bidecoder_ && reverse_weight > 0) {
r_probs_v = paddle::experimental::split_with_num(r_probs, num_hyps, 0);
CHECK(static_cast<int>(r_probs_v.size()) == num_hyps)
<< "r_probs_v size: is " << r_probs_v.size()
<< " expect: " << num_hyps;
//CHECK(static_cast<int>(r_probs_v.size()) == num_hyps)
// << "r_probs_v size: is " << r_probs_v.size()
// << " expect: " << num_hyps;
}
for (int i = 0; i < num_hyps; ++i) {
@ -638,7 +654,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
void U2Nnet::EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const {
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const {
// list of (B=1,T,D)
int size = encoder_outs_.size();
VLOG(3) << "encoder_outs_ size: " << size;
@ -650,15 +666,15 @@ void U2Nnet::EncoderOuts(
const int& B = shape[0];
const int& T = shape[1];
const int& D = shape[2];
CHECK(B == 1) << "Only support batch one.";
//CHECK(B == 1) << "Only support batch one.";
VLOG(3) << "encoder out " << i << " shape: (" << B << "," << T << ","
<< D << ")";
const float* this_tensor_ptr = item.data<float>();
for (int j = 0; j < T; j++) {
const float* cur = this_tensor_ptr + j * D;
kaldi::Vector<kaldi::BaseFloat> out(D);
std::memcpy(out.Data(), cur, D * sizeof(kaldi::BaseFloat));
std::vector<kaldi::BaseFloat> out(D);
std::memcpy(out.data(), cur, D * sizeof(kaldi::BaseFloat));
encoder_out->emplace_back(out);
}
}

@ -18,7 +18,7 @@
#pragma once
#include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h"
#include "paddle/extension.h"
#include "paddle/jit/all.h"
@ -42,7 +42,7 @@ class U2NnetBase : public NnetBase {
num_left_chunks_ = num_left_chunks;
}
virtual std::shared_ptr<NnetBase> Copy() const = 0;
virtual std::shared_ptr<NnetBase> Clone() const = 0;
protected:
virtual void ForwardEncoderChunkImpl(
@ -76,7 +76,7 @@ class U2Nnet : public U2NnetBase {
explicit U2Nnet(const ModelOptions& opts);
U2Nnet(const U2Nnet& other);
void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
void FeedForward(const std::vector<kaldi::BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) override;
@ -91,7 +91,7 @@ class U2Nnet : public U2NnetBase {
std::shared_ptr<paddle::jit::Layer> model() const { return model_; }
std::shared_ptr<NnetBase> Copy() const override;
std::shared_ptr<NnetBase> Clone() const override;
void ForwardEncoderChunkImpl(
const std::vector<kaldi::BaseFloat>& chunk_feats,
@ -111,10 +111,10 @@ class U2Nnet : public U2NnetBase {
void FeedEncoderOuts(const paddle::Tensor& encoder_out);
void EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const;
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const;
ModelOptions opts_; // hack, fix later
private:
ModelOptions opts_;
phi::Place dev_;
std::shared_ptr<paddle::jit::Layer> model_{nullptr};
@ -127,6 +127,7 @@ class U2Nnet : public U2NnetBase {
paddle::jit::Function forward_encoder_chunk_;
paddle::jit::Function forward_attention_decoder_;
paddle::jit::Function ctc_activation_;
float cost_time_ = 0.0;
};
} // namespace ppspeech

@ -15,8 +15,8 @@
#include "base/common.h"
#include "decoder/param.h"
#include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h"
#include "frontend/assembler.h"
#include "frontend/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/u2_nnet.h"

@ -0,0 +1,145 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef USE_ONNX
#include "nnet/u2_nnet.h"
#else
#include "nnet/u2_onnx_nnet.h"
#endif
#include "base/common.h"
#include "decoder/param.h"
#include "frontend/feature_pipeline.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/nnet_producer.h"
#include "nnet/u2_nnet.h"
DEFINE_string(wav_rspecifier, "", "test wav rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int32 num_done = 0, num_err = 0;
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
CHECK_GT(FLAGS_wav_rspecifier.size(), 0);
CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0);
CHECK_GT(FLAGS_model_path.size(), 0);
LOG(INFO) << "input rspecifier: " << FLAGS_wav_rspecifier;
LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
LOG(INFO) << "model path: " << FLAGS_model_path;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
ppspeech::FeaturePipelineOptions feature_opts =
ppspeech::FeaturePipelineOptions::InitFromFlags();
feature_opts.assembler_opts.fill_zero = false;
#ifndef USE_ONNX
std::shared_ptr<ppspeech::U2Nnet> nnet(new ppspeech::U2Nnet(model_opts));
#else
std::shared_ptr<ppspeech::U2OnnxNnet> nnet(new ppspeech::U2OnnxNnet(model_opts));
#endif
std::shared_ptr<ppspeech::FeaturePipeline> feature_pipeline(
new ppspeech::FeaturePipeline(feature_opts));
std::shared_ptr<ppspeech::NnetProducer> nnet_producer(
new ppspeech::NnetProducer(nnet, feature_pipeline));
kaldi::Timer timer;
float tot_wav_duration = 0;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
tot_wav_duration += dur;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
kaldi::Timer timer;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
nnet_producer->Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
nnet_producer->SetInputFinished();
}
// no overlap
sample_offset += cur_chunk_size;
}
CHECK(sample_offset == tot_samples);
std::vector<std::vector<kaldi::BaseFloat>> prob_vec;
while (1) {
std::vector<kaldi::BaseFloat> logprobs;
bool isok = nnet_producer->Read(&logprobs);
if (nnet_producer->IsFinished()) break;
if (isok == false) continue;
prob_vec.push_back(logprobs);
}
{
// writer nnet output
kaldi::MatrixIndexT nrow = prob_vec.size();
kaldi::MatrixIndexT ncol = prob_vec[0].size();
LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol;
kaldi::Matrix<kaldi::BaseFloat> nnet_out(nrow, ncol);
for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
nnet_out(row_idx, col_idx) = prob_vec[row_idx][col_idx];
}
}
nnet_out_writer.Write(utt, nnet_out);
}
nnet_producer->Reset();
}
nnet_producer->Wait();
double elapsed = timer.Elapsed();
LOG(INFO) << "Program cost:" << elapsed << " sec";
LOG(INFO) << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -0,0 +1,414 @@
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.cc
#include "nnet/u2_onnx_nnet.h"
#include "common/base/config.h"
namespace ppspeech {
void U2OnnxNnet::LoadModel(const std::string& model_dir) {
std::string encoder_onnx_path = model_dir + "/encoder.onnx";
std::string rescore_onnx_path = model_dir + "/decoder.onnx";
std::string ctc_onnx_path = model_dir + "/ctc.onnx";
std::string param_path = model_dir + "/param.onnx";
// 1. Load sessions
try {
encoder_ = std::make_shared<fastdeploy::Runtime>();
ctc_ = std::make_shared<fastdeploy::Runtime>();
rescore_ = std::make_shared<fastdeploy::Runtime>();
fastdeploy::RuntimeOption runtime_option;
runtime_option.UseOrtBackend();
runtime_option.UseCpu();
runtime_option.SetCpuThreadNum(1);
runtime_option.SetModelPath(encoder_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
assert(encoder_->Init(runtime_option));
runtime_option.SetModelPath(rescore_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
assert(rescore_->Init(runtime_option));
runtime_option.SetModelPath(ctc_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
assert(ctc_->Init(runtime_option));
} catch (std::exception const& e) {
LOG(ERROR) << "error when load onnx model: " << e.what();
exit(0);
}
Config conf(param_path);
encoder_output_size_ = conf.Read("output_size", encoder_output_size_);
num_blocks_ = conf.Read("num_blocks", num_blocks_);
head_ = conf.Read("head", head_);
cnn_module_kernel_ = conf.Read("cnn_module_kernel", cnn_module_kernel_);
subsampling_rate_ = conf.Read("subsampling_rate", subsampling_rate_);
right_context_ = conf.Read("right_context", right_context_);
sos_= conf.Read("sos_symbol", sos_);
eos_= conf.Read("eos_symbol", eos_);
is_bidecoder_= conf.Read("is_bidirectional_decoder", is_bidecoder_);
chunk_size_= conf.Read("chunk_size", chunk_size_);
num_left_chunks_ = conf.Read("left_chunks", num_left_chunks_);
LOG(INFO) << "Onnx Model Info:";
LOG(INFO) << "\tencoder_output_size " << encoder_output_size_;
LOG(INFO) << "\tnum_blocks " << num_blocks_;
LOG(INFO) << "\thead " << head_;
LOG(INFO) << "\tcnn_module_kernel " << cnn_module_kernel_;
LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_;
LOG(INFO) << "\tright_context " << right_context_;
LOG(INFO) << "\tsos " << sos_;
LOG(INFO) << "\teos " << eos_;
LOG(INFO) << "\tis bidirectional decoder " << is_bidecoder_;
LOG(INFO) << "\tchunk_size " << chunk_size_;
LOG(INFO) << "\tnum_left_chunks " << num_left_chunks_;
// 3. Read model nodes
LOG(INFO) << "Onnx Encoder:";
GetInputOutputInfo(encoder_, &encoder_in_names_, &encoder_out_names_);
LOG(INFO) << "Onnx CTC:";
GetInputOutputInfo(ctc_, &ctc_in_names_, &ctc_out_names_);
LOG(INFO) << "Onnx Rescore:";
GetInputOutputInfo(rescore_, &rescore_in_names_, &rescore_out_names_);
}
U2OnnxNnet::U2OnnxNnet(const ModelOptions& opts) : opts_(opts) {
LoadModel(opts_.model_path);
}
// shallow copy
U2OnnxNnet::U2OnnxNnet(const U2OnnxNnet& other) {
// metadatas
encoder_output_size_ = other.encoder_output_size_;
num_blocks_ = other.num_blocks_;
head_ = other.head_;
cnn_module_kernel_ = other.cnn_module_kernel_;
right_context_ = other.right_context_;
subsampling_rate_ = other.subsampling_rate_;
sos_ = other.sos_;
eos_ = other.eos_;
is_bidecoder_ = other.is_bidecoder_;
chunk_size_ = other.chunk_size_;
num_left_chunks_ = other.num_left_chunks_;
offset_ = other.offset_;
// session
encoder_ = other.encoder_;
ctc_ = other.ctc_;
rescore_ = other.rescore_;
// node names
encoder_in_names_ = other.encoder_in_names_;
encoder_out_names_ = other.encoder_out_names_;
ctc_in_names_ = other.ctc_in_names_;
ctc_out_names_ = other.ctc_out_names_;
rescore_in_names_ = other.rescore_in_names_;
rescore_out_names_ = other.rescore_out_names_;
}
void U2OnnxNnet::GetInputOutputInfo(const std::shared_ptr<fastdeploy::Runtime>& runtime,
std::vector<std::string>* in_names, std::vector<std::string>* out_names) {
std::vector<fastdeploy::TensorInfo> inputs_info = runtime->GetInputInfos();
(*in_names).resize(inputs_info.size());
for (int i = 0; i < inputs_info.size(); ++i){
fastdeploy::TensorInfo info = inputs_info[i];
std::stringstream shape;
for(int j = 0; j < info.shape.size(); ++j){
shape << info.shape[j];
shape << " ";
}
LOG(INFO) << "\tInput " << i << " : name=" << info.name << " type=" << info.dtype
<< " dims=" << shape.str();
(*in_names)[i] = info.name;
}
std::vector<fastdeploy::TensorInfo> outputs_info = runtime->GetOutputInfos();
(*out_names).resize(outputs_info.size());
for (int i = 0; i < outputs_info.size(); ++i){
fastdeploy::TensorInfo info = outputs_info[i];
std::stringstream shape;
for(int j = 0; j < info.shape.size(); ++j){
shape << info.shape[j];
shape << " ";
}
LOG(INFO) << "\tOutput " << i << " : name=" << info.name << " type=" << info.dtype
<< " dims=" << shape.str();
(*out_names)[i] = info.name;
}
}
std::shared_ptr<NnetBase> U2OnnxNnet::Clone() const {
auto asr_model = std::make_shared<U2OnnxNnet>(*this);
// reset inner state for new decoding
asr_model->Reset();
return asr_model;
}
void U2OnnxNnet::Reset() {
offset_ = 0;
encoder_outs_.clear();
cached_feats_.clear();
// Reset att_cache
if (num_left_chunks_ > 0) {
int required_cache_size = chunk_size_ * num_left_chunks_;
offset_ = required_cache_size;
att_cache_.resize(num_blocks_ * head_ * required_cache_size *
encoder_output_size_ / head_ * 2,
0.0);
const std::vector<int64_t> att_cache_shape = {num_blocks_, head_, required_cache_size,
encoder_output_size_ / head_ * 2};
att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data());
} else {
att_cache_.resize(0, 0.0);
const std::vector<int64_t> att_cache_shape = {num_blocks_, head_, 0,
encoder_output_size_ / head_ * 2};
att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data());
}
// Reset cnn_cache
cnn_cache_.resize(
num_blocks_ * encoder_output_size_ * (cnn_module_kernel_ - 1), 0.0);
const std::vector<int64_t> cnn_cache_shape = {num_blocks_, 1, encoder_output_size_,
cnn_module_kernel_ - 1};
cnn_cache_ort_.SetExternalData(cnn_cache_shape, fastdeploy::FDDataType::FP32, cnn_cache_.data());
}
void U2OnnxNnet::FeedForward(const std::vector<BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) {
kaldi::Timer timer;
std::vector<kaldi::BaseFloat> ctc_probs;
ForwardEncoderChunkImpl(
features, feature_dim, &out->logprobs, &out->vocab_dim);
VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. "
<< features.size() / feature_dim << " frames.";
}
void U2OnnxNnet::ForwardEncoderChunkImpl(
const std::vector<kaldi::BaseFloat>& chunk_feats,
const int32& feat_dim,
std::vector<kaldi::BaseFloat>* out_prob,
int32* vocab_dim) {
// 1. Prepare onnx required data, splice cached_feature_ and chunk_feats
// chunk
int num_frames = chunk_feats.size() / feat_dim;
VLOG(3) << "num_frames: " << num_frames;
VLOG(3) << "feat_dim: " << feat_dim;
const int feature_dim = feat_dim;
std::vector<float> feats;
feats.insert(feats.end(), chunk_feats.begin(), chunk_feats.end());
fastdeploy::FDTensor feats_ort;
const std::vector<int64_t> feats_shape = {1, num_frames, feature_dim};
feats_ort.SetExternalData(feats_shape, fastdeploy::FDDataType::FP32, feats.data());
// offset
int64_t offset_int64 = static_cast<int64_t>(offset_);
fastdeploy::FDTensor offset_ort;
offset_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &offset_int64);
// required_cache_size
int64_t required_cache_size = chunk_size_ * num_left_chunks_;
fastdeploy::FDTensor required_cache_size_ort("");
required_cache_size_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &required_cache_size);
// att_mask
fastdeploy::FDTensor att_mask_ort;
std::vector<uint8_t> att_mask(required_cache_size + chunk_size_, 1);
if (num_left_chunks_ > 0) {
int chunk_idx = offset_ / chunk_size_ - num_left_chunks_;
if (chunk_idx < num_left_chunks_) {
for (int i = 0; i < (num_left_chunks_ - chunk_idx) * chunk_size_; ++i) {
att_mask[i] = 0;
}
}
const std::vector<int64_t> att_mask_shape = {1, 1, required_cache_size + chunk_size_};
att_mask_ort.SetExternalData(att_mask_shape, fastdeploy::FDDataType::BOOL, reinterpret_cast<bool*>(att_mask.data()));
}
// 2. Encoder chunk forward
std::vector<fastdeploy::FDTensor> inputs(encoder_in_names_.size());
for (int i = 0; i < encoder_in_names_.size(); ++i) {
std::string name = encoder_in_names_[i];
if (!strcmp(name.data(), "chunk")) {
inputs[i] = std::move(feats_ort);
inputs[i].name = "chunk";
} else if (!strcmp(name.data(), "offset")) {
inputs[i] = std::move(offset_ort);
inputs[i].name = "offset";
} else if (!strcmp(name.data(), "required_cache_size")) {
inputs[i] = std::move(required_cache_size_ort);
inputs[i].name = "required_cache_size";
} else if (!strcmp(name.data(), "att_cache")) {
inputs[i] = std::move(att_cache_ort_);
inputs[i].name = "att_cache";
} else if (!strcmp(name.data(), "cnn_cache")) {
inputs[i] = std::move(cnn_cache_ort_);
inputs[i].name = "cnn_cache";
} else if (!strcmp(name.data(), "att_mask")) {
inputs[i] = std::move(att_mask_ort);
inputs[i].name = "att_mask";
}
}
std::vector<fastdeploy::FDTensor> ort_outputs;
assert(encoder_->Infer(inputs, &ort_outputs));
offset_ += static_cast<int>(ort_outputs[0].shape[1]);
att_cache_ort_ = std::move(ort_outputs[1]);
cnn_cache_ort_ = std::move(ort_outputs[2]);
std::vector<fastdeploy::FDTensor> ctc_inputs;
ctc_inputs.emplace_back(std::move(ort_outputs[0]));
// ctc_inputs[0] = std::move(ort_outputs[0]);
ctc_inputs[0].name = ctc_in_names_[0];
std::vector<fastdeploy::FDTensor> ctc_ort_outputs;
assert(ctc_->Infer(ctc_inputs, &ctc_ort_outputs));
encoder_outs_.emplace_back(std::move(ctc_inputs[0])); // *****
float* logp_data = reinterpret_cast<float*>(ctc_ort_outputs[0].Data());
// Copy to output, (B=1,T,D)
std::vector<int64_t> ctc_log_probs_shape = ctc_ort_outputs[0].shape;
CHECK_EQ(ctc_log_probs_shape.size(), 3);
int B = ctc_log_probs_shape[0];
CHECK_EQ(B, 1);
int T = ctc_log_probs_shape[1];
int D = ctc_log_probs_shape[2];
*vocab_dim = D;
out_prob->resize(T * D);
std::memcpy(
out_prob->data(), logp_data, T * D * sizeof(kaldi::BaseFloat));
return;
}
float U2OnnxNnet::ComputeAttentionScore(const float* prob,
const std::vector<int>& hyp, int eos,
int decode_out_len) {
float score = 0.0f;
for (size_t j = 0; j < hyp.size(); ++j) {
score += *(prob + j * decode_out_len + hyp[j]);
}
score += *(prob + hyp.size() * decode_out_len + eos);
return score;
}
void U2OnnxNnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
CHECK(rescoring_score != nullptr);
int num_hyps = hyps.size();
rescoring_score->resize(num_hyps, 0.0f);
if (num_hyps == 0) {
return;
}
// No encoder output
if (encoder_outs_.size() == 0) {
return;
}
std::vector<int64_t> hyps_lens;
int max_hyps_len = 0;
for (size_t i = 0; i < num_hyps; ++i) {
int length = hyps[i].size() + 1;
max_hyps_len = std::max(length, max_hyps_len);
hyps_lens.emplace_back(static_cast<int64_t>(length));
}
std::vector<float> rescore_input;
int encoder_len = 0;
for (int i = 0; i < encoder_outs_.size(); i++) {
float* encoder_outs_data = reinterpret_cast<float*>(encoder_outs_[i].Data());
for (int j = 0; j < encoder_outs_[i].Numel(); j++) {
rescore_input.emplace_back(encoder_outs_data[j]);
}
encoder_len += encoder_outs_[i].shape[1];
}
std::vector<int64_t> hyps_pad;
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];
hyps_pad.emplace_back(sos_);
size_t j = 0;
for (; j < hyp.size(); ++j) {
hyps_pad.emplace_back(hyp[j]);
}
if (j == max_hyps_len - 1) {
continue;
}
for (; j < max_hyps_len - 1; ++j) {
hyps_pad.emplace_back(0);
}
}
const std::vector<int64_t> hyps_pad_shape = {num_hyps, max_hyps_len};
const std::vector<int64_t> hyps_lens_shape = {num_hyps};
const std::vector<int64_t> decode_input_shape = {1, encoder_len, encoder_output_size_};
fastdeploy::FDTensor hyps_pad_tensor_;
hyps_pad_tensor_.SetExternalData(hyps_pad_shape, fastdeploy::FDDataType::INT64, hyps_pad.data());
fastdeploy::FDTensor hyps_lens_tensor_;
hyps_lens_tensor_.SetExternalData(hyps_lens_shape, fastdeploy::FDDataType::INT64, hyps_lens.data());
fastdeploy::FDTensor decode_input_tensor_;
decode_input_tensor_.SetExternalData(decode_input_shape, fastdeploy::FDDataType::FP32, rescore_input.data());
std::vector<fastdeploy::FDTensor> rescore_inputs(3);
rescore_inputs[0] = std::move(hyps_pad_tensor_);
rescore_inputs[0].name = rescore_in_names_[0];
rescore_inputs[1] = std::move(hyps_lens_tensor_);
rescore_inputs[1].name = rescore_in_names_[1];
rescore_inputs[2] = std::move(decode_input_tensor_);
rescore_inputs[2].name = rescore_in_names_[2];
std::vector<fastdeploy::FDTensor> rescore_outputs;
assert(rescore_->Infer(rescore_inputs, &rescore_outputs));
float* decoder_outs_data = reinterpret_cast<float*>(rescore_outputs[0].Data());
float* r_decoder_outs_data = reinterpret_cast<float*>(rescore_outputs[1].Data());
int decode_out_len = rescore_outputs[0].shape[2];
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];
float score = 0.0f;
// left to right decoder score
score = ComputeAttentionScore(
decoder_outs_data + max_hyps_len * decode_out_len * i, hyp, eos_,
decode_out_len);
// Optional: Used for right to left score
float r_score = 0.0f;
if (is_bidecoder_ && reverse_weight > 0) {
std::vector<int> r_hyp(hyp.size());
std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
// right to left decoder score
r_score = ComputeAttentionScore(
r_decoder_outs_data + max_hyps_len * decode_out_len * i, r_hyp, eos_,
decode_out_len);
}
// combined left-to-right and right-to-left score
(*rescoring_score)[i] =
score * (1 - reverse_weight) + r_score * reverse_weight;
}
}
void U2OnnxNnet::EncoderOuts(
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const {
}
} //namepace ppspeech

@ -0,0 +1,97 @@
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.h
#pragma once
#include "base/common.h"
#include "matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h"
#include "nnet/u2_nnet.h"
#include "fastdeploy/runtime.h"
namespace ppspeech {
class U2OnnxNnet : public U2NnetBase {
public:
explicit U2OnnxNnet(const ModelOptions& opts);
U2OnnxNnet(const U2OnnxNnet& other);
void FeedForward(const std::vector<kaldi::BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) override;
void Reset() override;
bool IsLogProb() override { return true; }
void Dim();
void LoadModel(const std::string& model_dir);
std::shared_ptr<NnetBase> Clone() const override;
void ForwardEncoderChunkImpl(
const std::vector<kaldi::BaseFloat>& chunk_feats,
const int32& feat_dim,
std::vector<kaldi::BaseFloat>* ctc_probs,
int32* vocab_dim) override;
float ComputeAttentionScore(const float* prob, const std::vector<int>& hyp,
int eos, int decode_out_len);
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) override;
void EncoderOuts(
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const;
void GetInputOutputInfo(const std::shared_ptr<fastdeploy::Runtime>& runtime,
std::vector<std::string>* in_names,
std::vector<std::string>* out_names);
private:
ModelOptions opts_;
int encoder_output_size_ = 0;
int num_blocks_ = 0;
int cnn_module_kernel_ = 0;
int head_ = 0;
// sessions
std::shared_ptr<fastdeploy::Runtime> encoder_ = nullptr;
std::shared_ptr<fastdeploy::Runtime> rescore_ = nullptr;
std::shared_ptr<fastdeploy::Runtime> ctc_ = nullptr;
// node names
std::vector<std::string> encoder_in_names_, encoder_out_names_;
std::vector<std::string> ctc_in_names_, ctc_out_names_;
std::vector<std::string> rescore_in_names_, rescore_out_names_;
// caches
fastdeploy::FDTensor att_cache_ort_;
fastdeploy::FDTensor cnn_cache_ort_;
std::vector<fastdeploy::FDTensor> encoder_outs_;
std::vector<float> att_cache_;
std::vector<float> cnn_cache_;
};
} // namespace ppspeech

@ -0,0 +1,26 @@
set(srcs)
list(APPEND srcs
recognizer_controller.cc
recognizer_controller_impl.cc
recognizer_instance.cc
recognizer.cc
)
add_library(recognizer STATIC ${srcs})
target_link_libraries(recognizer PUBLIC decoder)
set(TEST_BINS
recognizer_batch_main
recognizer_batch_main2
recognizer_main
)
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl)
endforeach()

@ -0,0 +1,46 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/recognizer.h"
#include "recognizer/recognizer_instance.h"
bool InitRecognizer(const std::string& model_file,
const std::string& word_symbol_table_file,
const std::string& fst_file,
int num_instance) {
return ppspeech::RecognizerInstance::GetInstance().Init(model_file,
word_symbol_table_file,
fst_file,
num_instance);
}
int GetRecognizerInstanceId() {
return ppspeech::RecognizerInstance::GetInstance().GetRecognizerInstanceId();
}
void InitDecoder(int instance_id) {
return ppspeech::RecognizerInstance::GetInstance().InitDecoder(instance_id);
}
void AcceptData(const std::vector<float>& waves, int instance_id) {
return ppspeech::RecognizerInstance::GetInstance().Accept(waves, instance_id);
}
void SetInputFinished(int instance_id) {
return ppspeech::RecognizerInstance::GetInstance().SetInputFinished(instance_id);
}
std::string GetFinalResult(int instance_id) {
return ppspeech::RecognizerInstance::GetInstance().GetResult(instance_id);
}

@ -0,0 +1,28 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
bool InitRecognizer(const std::string& model_file,
const std::string& word_symbol_table_file,
const std::string& fst_file,
int num_instance);
int GetRecognizerInstanceId();
void InitDecoder(int instance_id);
void AcceptData(const std::vector<float>& waves, int instance_id);
void SetInputFinished(int instance_id);
std::string GetFinalResult(int instance_id);

@ -0,0 +1,172 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/base/thread_pool.h"
#include "common/utils/file_utils.h"
#include "common/utils/strings.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
#include "recognizer/recognizer_controller.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
DEFINE_int32(njob, 3, "njob");
using std::string;
using std::vector;
void SplitUtt(string wavlist_file,
vector<vector<string>>* uttlists,
vector<vector<string>>* wavlists,
int njob) {
vector<string> wavlist;
wavlists->resize(njob);
uttlists->resize(njob);
ppspeech::ReadFileToVector(wavlist_file, &wavlist);
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
string utt_str = wavlist[idx];
vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
LOG(INFO) << utt_wav[0];
CHECK_EQ(utt_wav.size(), size_t(2));
uttlists->at(idx % njob).push_back(utt_wav[0]);
wavlists->at(idx % njob).push_back(utt_wav[1]);
}
}
void recognizer_func(ppspeech::RecognizerController* recognizer_controller,
std::vector<string> wavlist,
std::vector<string> uttlist,
std::vector<string>* results) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0;
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
if (wavlist.empty()) return;
results->reserve(wavlist.size());
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
std::string utt = uttlist[idx];
std::string wav_file = wavlist[idx];
std::ifstream infile;
infile.open(wav_file, std::ifstream::in);
kaldi::WaveData wave_data;
wave_data.Read(infile);
int32 recog_id = -1;
while (recog_id == -1) {
recog_id = recognizer_controller->GetRecognizerInstanceId();
}
recognizer_controller->InitDecoder(recog_id);
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
tot_wav_duration += dur;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
kaldi::Timer local_timer;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
recognizer_controller->Accept(wav_chunk, recog_id);
// no overlap
sample_offset += cur_chunk_size;
}
recognizer_controller->SetInputFinished(recog_id);
CHECK(sample_offset == tot_samples);
std::string result = recognizer_controller->GetFinalResult(recog_id);
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
LOG(INFO) << " the result of " << utt << " is empty";
result = " ";
}
tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed();
results->push_back(result);
++num_done;
}
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int njob = FLAGS_njob;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
ppspeech::RecognizerResource resource =
ppspeech::RecognizerResource::InitFromFlags();
ppspeech::RecognizerController recognizer_controller(njob, resource);
ThreadPool threadpool(njob);
vector<vector<string>> wavlist;
vector<vector<string>> uttlist;
vector<vector<string>> resultlist(njob);
vector<std::future<void>> futurelist;
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
for (size_t i = 0; i < njob; ++i) {
std::future<void> f = threadpool.enqueue(recognizer_func,
&recognizer_controller,
wavlist[i],
uttlist[i],
&resultlist[i]);
futurelist.push_back(std::move(f));
}
for (size_t i = 0; i < njob; ++i) {
futurelist[i].get();
}
for (size_t idx = 0; idx < njob; ++idx) {
for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
string utt = uttlist[idx][utt_idx];
string result = resultlist[idx][utt_idx];
result_writer.Write(utt, result);
}
}
return 0;
}

@ -0,0 +1,168 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/base/thread_pool.h"
#include "common/utils/file_utils.h"
#include "common/utils/strings.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
#include "recognizer/recognizer.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
DEFINE_int32(njob, 3, "njob");
using std::string;
using std::vector;
void SplitUtt(string wavlist_file,
vector<vector<string>>* uttlists,
vector<vector<string>>* wavlists,
int njob) {
vector<string> wavlist;
wavlists->resize(njob);
uttlists->resize(njob);
ppspeech::ReadFileToVector(wavlist_file, &wavlist);
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
string utt_str = wavlist[idx];
vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
LOG(INFO) << utt_wav[0];
CHECK_EQ(utt_wav.size(), size_t(2));
uttlists->at(idx % njob).push_back(utt_wav[0]);
wavlists->at(idx % njob).push_back(utt_wav[1]);
}
}
void recognizer_func(std::vector<string> wavlist,
std::vector<string> uttlist,
std::vector<string>* results) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0;
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
if (wavlist.empty()) return;
results->reserve(wavlist.size());
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
std::string utt = uttlist[idx];
std::string wav_file = wavlist[idx];
std::ifstream infile;
infile.open(wav_file, std::ifstream::in);
kaldi::WaveData wave_data;
wave_data.Read(infile);
int32 recog_id = -1;
while (recog_id == -1) {
recog_id = GetRecognizerInstanceId();
}
InitDecoder(recog_id);
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
tot_wav_duration += dur;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
kaldi::Timer local_timer;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
AcceptData(wav_chunk, recog_id);
// no overlap
sample_offset += cur_chunk_size;
}
SetInputFinished(recog_id);
CHECK(sample_offset == tot_samples);
std::string result = GetFinalResult(recog_id);
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
LOG(INFO) << " the result of " << utt << " is empty";
result = " ";
}
tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed();
results->push_back(result);
++num_done;
}
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int njob = FLAGS_njob;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
InitRecognizer(FLAGS_model_path, FLAGS_word_symbol_table, FLAGS_graph_path, njob);
ThreadPool threadpool(njob);
vector<vector<string>> wavlist;
vector<vector<string>> uttlist;
vector<vector<string>> resultlist(njob);
vector<std::future<void>> futurelist;
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
for (size_t i = 0; i < njob; ++i) {
std::future<void> f = threadpool.enqueue(recognizer_func,
wavlist[i],
uttlist[i],
&resultlist[i]);
futurelist.push_back(std::move(f));
}
for (size_t i = 0; i < njob; ++i) {
futurelist[i].get();
}
for (size_t idx = 0; idx < njob; ++idx) {
for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
string utt = uttlist[idx][utt_idx];
string result = resultlist[idx][utt_idx];
result_writer.Write(utt, result);
}
}
return 0;
}

@ -0,0 +1,70 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/recognizer_controller.h"
#include "nnet/u2_nnet.h"
namespace ppspeech {
RecognizerController::RecognizerController(int num_worker, RecognizerResource resource) {
recognizer_workers.resize(num_worker);
for (size_t i = 0; i < num_worker; ++i) {
recognizer_workers[i].reset(new ppspeech::RecognizerControllerImpl(resource));
waiting_workers.push(i);
}
}
int RecognizerController::GetRecognizerInstanceId() {
if (waiting_workers.empty()) {
return -1;
}
int idx = -1;
{
std::unique_lock<std::mutex> lock(mutex_);
idx = waiting_workers.front();
waiting_workers.pop();
}
return idx;
}
RecognizerController::~RecognizerController() {
for (size_t i = 0; i < recognizer_workers.size(); ++i) {
recognizer_workers[i]->WaitFinished();
}
}
void RecognizerController::InitDecoder(int idx) {
recognizer_workers[idx]->InitDecoder();
}
std::string RecognizerController::GetFinalResult(int idx) {
recognizer_workers[idx]->WaitDecoderFinished();
recognizer_workers[idx]->AttentionRescoring();
std::string result = recognizer_workers[idx]->GetFinalResult();
{
std::unique_lock<std::mutex> lock(mutex_);
waiting_workers.push(idx);
}
return result;
}
void RecognizerController::Accept(std::vector<float> data, int idx) {
recognizer_workers[idx]->Accept(data);
}
void RecognizerController::SetInputFinished(int idx) {
recognizer_workers[idx]->SetInputFinished();
}
}

@ -0,0 +1,42 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <queue>
#include <memory>
#include "recognizer/recognizer_controller_impl.h"
namespace ppspeech {
class RecognizerController {
public:
explicit RecognizerController(int num_worker, RecognizerResource resource);
~RecognizerController();
int GetRecognizerInstanceId();
void InitDecoder(int idx);
void Accept(std::vector<float> data, int idx);
void SetInputFinished(int idx);
std::string GetFinalResult(int idx);
private:
std::queue<int> waiting_workers;
std::mutex mutex_;
std::vector<std::unique_ptr<ppspeech::RecognizerControllerImpl>> recognizer_workers;
DISALLOW_COPY_AND_ASSIGN(RecognizerController);
};
}

@ -1,4 +1,4 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -12,86 +12,180 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/u2_recognizer.h"
#include "nnet/u2_nnet.h"
#include "recognizer/recognizer_controller_impl.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "common/utils/strings.h"
namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr;
using std::vector;
U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& resource)
: opts_(resource) {
BaseFloat am_scale = resource.acoustic_scale;
BaseFloat blank_threshold = resource.blank_threshold;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
feature_pipeline_.reset(new FeaturePipeline(feature_opts));
std::shared_ptr<FeaturePipeline> feature_pipeline(
new FeaturePipeline(feature_opts));
std::shared_ptr<NnetBase> nnet;
#ifndef USE_ONNX
nnet = resource.nnet->Clone();
#else
if (resource.model_opts.with_onnx_model){
nnet.reset(new U2OnnxNnet(resource.model_opts));
} else {
nnet = resource.nnet->Clone();
}
#endif
nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline, blank_threshold));
nnet_thread_ = std::thread(RunNnetEvaluation, this);
std::shared_ptr<NnetBase> nnet(new U2Nnet(resource.model_opts));
decodable_.reset(new Decodable(nnet_producer_, am_scale));
if (resource.decoder_opts.tlg_decoder_opts.fst_path.empty()) {
LOG(INFO) << "Init PrefixBeamSearch Decoder";
decoder_ = std::make_unique<CTCPrefixBeamSearch>(
resource.decoder_opts.ctc_prefix_search_opts);
} else {
LOG(INFO) << "Init TLGDecoder";
decoder_ = std::make_unique<TLGDecoder>(
resource.decoder_opts.tlg_decoder_opts);
}
BaseFloat am_scale = resource.acoustic_scale;
decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale));
symbol_table_ = decoder_->WordSymbolTable();
global_frame_offset_ = 0;
input_finished_ = false;
num_frames_ = 0;
result_.clear();
}
CHECK_NE(resource.vocab_path, "");
decoder_.reset(new CTCPrefixBeamSearch(
resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
RecognizerControllerImpl::~RecognizerControllerImpl() {
WaitFinished();
}
unit_table_ = decoder_->VocabTable();
symbol_table_ = unit_table_;
void RecognizerControllerImpl::Reset() {
nnet_producer_->Reset();
}
input_finished_ = false;
void RecognizerControllerImpl::RunDecoder(RecognizerControllerImpl* me) {
me->RunDecoderInternal();
}
Reset();
void RecognizerControllerImpl::RunDecoderInternal() {
LOG(INFO) << "DecoderInternal begin";
while (!nnet_producer_->IsFinished()) {
nnet_condition_.notify_one();
decoder_->AdvanceDecode(decodable_);
}
decoder_->AdvanceDecode(decodable_);
UpdateResult(false);
LOG(INFO) << "DecoderInternal exit";
}
void U2Recognizer::Reset() {
global_frame_offset_ = 0;
num_frames_ = 0;
result_.clear();
void RecognizerControllerImpl::WaitDecoderFinished() {
if (decoder_thread_.joinable()) decoder_thread_.join();
}
decodable_->Reset();
decoder_->Reset();
void RecognizerControllerImpl::RunNnetEvaluation(RecognizerControllerImpl* me) {
me->RunNnetEvaluationInternal();
}
void RecognizerControllerImpl::SetInputFinished() {
nnet_producer_->SetInputFinished();
nnet_condition_.notify_one();
LOG(INFO) << "Set Input Finished";
}
void RecognizerControllerImpl::WaitFinished() {
abort_ = true;
LOG(INFO) << "nnet wait finished";
nnet_condition_.notify_one();
if (nnet_thread_.joinable()) {
nnet_thread_.join();
}
}
void RecognizerControllerImpl::RunNnetEvaluationInternal() {
bool result = false;
LOG(INFO) << "NnetEvaluationInteral begin";
while (!abort_) {
std::unique_lock<std::mutex> lock(nnet_mutex_);
nnet_condition_.wait(lock);
do {
result = nnet_producer_->Compute();
decoder_condition_.notify_one();
} while (result);
}
LOG(INFO) << "NnetEvaluationInteral exit";
}
void U2Recognizer::ResetContinuousDecoding() {
global_frame_offset_ = num_frames_;
void RecognizerControllerImpl::Accept(std::vector<float> data) {
nnet_producer_->Accept(data);
nnet_condition_.notify_one();
}
void RecognizerControllerImpl::InitDecoder() {
global_frame_offset_ = 0;
input_finished_ = false;
num_frames_ = 0;
result_.clear();
decodable_->Reset();
decoder_->Reset();
decoder_thread_ = std::thread(RunDecoder, this);
}
void RecognizerControllerImpl::AttentionRescoring() {
decoder_->FinalizeSearch();
UpdateResult(false);
void U2Recognizer::Accept(const VectorBase<BaseFloat>& waves) {
kaldi::Timer timer;
feature_pipeline_->Accept(waves);
VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " << waves.Dim()
<< " samples.";
// No need to do rescoring
if (0.0 == opts_.decoder_opts.rescoring_weight) {
LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
return;
}
LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";
// Inputs() returns N-best input ids, which is the basic unit for rescoring
// In CtcPrefixBeamSearch, inputs are the same to outputs
const auto& hypotheses = decoder_->Inputs();
int num_hyps = hypotheses.size();
if (num_hyps <= 0) {
return;
}
void U2Recognizer::Decode() {
decoder_->AdvanceDecode(decodable_);
UpdateResult(false);
std::vector<float> rescoring_score;
decodable_->AttentionRescoring(
hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
// combine ctc score and rescoring score
for (size_t i = 0; i < num_hyps; i++) {
VLOG(3) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
<< " ctc_score: " << result_[i].score
<< " rescoring_weight: " << opts_.decoder_opts.rescoring_weight
<< " ctc_weight: " << opts_.decoder_opts.ctc_weight;
result_[i].score =
opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
opts_.decoder_opts.ctc_weight * result_[i].score;
VLOG(3) << "hyp: " << result_[0].sentence
<< " score: " << result_[0].score;
}
void U2Recognizer::Rescoring() {
// Do attention Rescoring
AttentionRescoring();
std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
VLOG(3) << "result: " << result_[0].sentence
<< " score: " << result_[0].score;
}
void U2Recognizer::UpdateResult(bool finish) {
std::string RecognizerControllerImpl::GetFinalResult() { return result_[0].sentence; }
std::string RecognizerControllerImpl::GetPartialResult() { return result_[0].sentence; }
void RecognizerControllerImpl::UpdateResult(bool finish) {
const auto& hypotheses = decoder_->Outputs();
const auto& inputs = decoder_->Inputs();
const auto& likelihood = decoder_->Likelihood();
const auto& times = decoder_->Times();
result_.clear();
CHECK_EQ(hypotheses.size(), likelihood.size());
CHECK_EQ(inputs.size(), likelihood.size());
for (size_t i = 0; i < hypotheses.size(); i++) {
const std::vector<int>& hypothesis = hypotheses[i];
@ -99,21 +193,16 @@ void U2Recognizer::UpdateResult(bool finish) {
path.score = likelihood[i];
for (size_t j = 0; j < hypothesis.size(); j++) {
std::string word = symbol_table_->Find(hypothesis[j]);
// A detailed explanation of this if-else branch can be found in
// https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
if (decoder_->Type() == kWfstBeamSearch) {
path.sentence += (" " + word);
} else {
path.sentence += (word);
}
}
path.sentence = DelBlank(path.sentence);
// TimeStamp is only supported in final result
// TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to
// various FST operations when building the decoding graph. So here we
// use time stamp of the input(e2e model unit), which is more accurate,
// and it requires the symbol table of the e2e model used in training.
if (unit_table_ != nullptr && finish) {
if (symbol_table_ != nullptr && finish) {
int offset = global_frame_offset_ * FrameShiftInMs();
const std::vector<int>& input = inputs[i];
@ -121,7 +210,7 @@ void U2Recognizer::UpdateResult(bool finish) {
CHECK_EQ(input.size(), time_stamp.size());
for (size_t j = 0; j < input.size(); j++) {
std::string word = unit_table_->Find(input[j]);
std::string word = symbol_table_->Find(input[j]);
int start =
time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0
@ -163,56 +252,4 @@ void U2Recognizer::UpdateResult(bool finish) {
}
}
void U2Recognizer::AttentionRescoring() {
decoder_->FinalizeSearch();
UpdateResult(true);
// No need to do rescoring
if (0.0 == opts_.decoder_opts.rescoring_weight) {
LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
return;
}
LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";
// Inputs() returns N-best input ids, which is the basic unit for rescoring
// In CtcPrefixBeamSearch, inputs are the same to outputs
const auto& hypotheses = decoder_->Inputs();
int num_hyps = hypotheses.size();
if (num_hyps <= 0) {
return;
}
std::vector<float> rescoring_score;
decodable_->AttentionRescoring(
hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
// combine ctc score and rescoring score
for (size_t i = 0; i < num_hyps; i++) {
VLOG(3) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
<< " ctc_score: " << result_[i].score
<< " rescoring_weight: " << opts_.decoder_opts.rescoring_weight
<< " ctc_weight: " << opts_.decoder_opts.ctc_weight;
result_[i].score =
opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
opts_.decoder_opts.ctc_weight * result_[i].score;
VLOG(3) << "hyp: " << result_[0].sentence
<< " score: " << result_[0].score;
}
std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
VLOG(3) << "result: " << result_[0].sentence
<< " score: " << result_[0].score;
}
std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; }
std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; }
void U2Recognizer::SetFinished() {
feature_pipeline_->SetFinished();
input_finished_ = true;
}
} // namespace ppspeech

@ -0,0 +1,89 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "decoder/common.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "nnet/u2_nnet.h"
#include "nnet/nnet_producer.h"
#ifdef USE_ONNX
#include "nnet/u2_onnx_nnet.h"
#endif
#include "nnet/decodable.h"
#include "recognizer/recognizer_resource.h"
#include <memory>
namespace ppspeech {
class RecognizerControllerImpl {
public:
explicit RecognizerControllerImpl(const RecognizerResource& resource);
~RecognizerControllerImpl();
void Accept(std::vector<float> data);
void InitDecoder();
void SetInputFinished();
std::string GetFinalResult();
std::string GetPartialResult();
void Rescoring();
void Reset();
void WaitDecoderFinished();
void WaitFinished();
void AttentionRescoring();
bool DecodedSomething() const {
return !result_.empty() && !result_[0].sentence.empty();
}
int FrameShiftInMs() const {
return 1; //todo
}
private:
static void RunNnetEvaluation(RecognizerControllerImpl* me);
void RunNnetEvaluationInternal();
static void RunDecoder(RecognizerControllerImpl* me);
void RunDecoderInternal();
void UpdateResult(bool finish = false);
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<DecoderBase> decoder_;
std::shared_ptr<NnetProducer> nnet_producer_;
// e2e unit symbol table
std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
std::vector<DecodeResult> result_;
RecognizerResource opts_;
bool abort_ = false;
// global decoded frame offset
int global_frame_offset_;
// cur decoded frame num
int num_frames_;
// timestamp gap between words in a sentence
const int time_stamp_gap_ = 100;
bool input_finished_;
std::mutex nnet_mutex_;
std::mutex decoder_mutex_;
std::condition_variable nnet_condition_;
std::condition_variable decoder_condition_;
std::thread nnet_thread_;
std::thread decoder_thread_;
DISALLOW_COPY_AND_ASSIGN(RecognizerControllerImpl);
};
}

@ -0,0 +1,66 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/recognizer_instance.h"
namespace ppspeech {
RecognizerInstance& RecognizerInstance::GetInstance() {
static RecognizerInstance instance;
return instance;
}
bool RecognizerInstance::Init(const std::string& model_file,
const std::string& word_symbol_table_file,
const std::string& fst_file,
int num_instance) {
RecognizerResource resource = RecognizerResource::InitFromFlags();
resource.model_opts.model_path = model_file;
//resource.vocab_path = word_symbol_table_file;
if (!fst_file.empty()) {
resource.decoder_opts.tlg_decoder_opts.fst_path = fst_file;
resource.decoder_opts.tlg_decoder_opts.fst_path = word_symbol_table_file;
} else {
resource.decoder_opts.ctc_prefix_search_opts.word_symbol_table =
word_symbol_table_file;
}
recognizer_controller_ = std::make_unique<RecognizerController>(num_instance, resource);
return true;
}
void RecognizerInstance::InitDecoder(int idx) {
recognizer_controller_->InitDecoder(idx);
return;
}
int RecognizerInstance::GetRecognizerInstanceId() {
return recognizer_controller_->GetRecognizerInstanceId();
}
void RecognizerInstance::Accept(const std::vector<float>& waves, int idx) const {
recognizer_controller_->Accept(waves, idx);
return;
}
void RecognizerInstance::SetInputFinished(int idx) const {
recognizer_controller_->SetInputFinished(idx);
return;
}
std::string RecognizerInstance::GetResult(int idx) const {
return recognizer_controller_->GetFinalResult(idx);
}
}

@ -0,0 +1,42 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "recognizer/recognizer_controller.h"
namespace ppspeech {
class RecognizerInstance {
public:
static RecognizerInstance& GetInstance();
RecognizerInstance() {}
~RecognizerInstance() {}
bool Init(const std::string& model_file,
const std::string& word_symbol_table_file,
const std::string& fst_file,
int num_instance);
int GetRecognizerInstanceId();
void InitDecoder(int idx);
void Accept(const std::vector<float>& waves, int idx) const;
void SetInputFinished(int idx) const;
std::string GetResult(int idx) const;
private:
std::unique_ptr<RecognizerController> recognizer_controller_;
};
} // namespace ppspeech

@ -13,9 +13,9 @@
// limitations under the License.
#include "decoder/param.h"
#include "kaldi/feat/wave-reader.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "recognizer/u2_recognizer.h"
#include "recognizer/recognizer_controller.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
@ -31,6 +31,7 @@ int main(int argc, char* argv[]) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
@ -44,11 +45,13 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
ppspeech::U2RecognizerResource resource =
ppspeech::U2RecognizerResource::InitFromFlags();
ppspeech::U2Recognizer recognizer(resource);
ppspeech::RecognizerResource resource =
ppspeech::RecognizerResource::InitFromFlags();
std::shared_ptr<ppspeech::RecognizerControllerImpl> recognizer_ptr(
new ppspeech::RecognizerControllerImpl(resource));
for (; !wav_reader.Done(); wav_reader.Next()) {
recognizer_ptr->InitDecoder();
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt;
@ -63,45 +66,32 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
int cnt = 0;
kaldi::Timer timer;
kaldi::Timer local_timer;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
wav_chunk[i] = waveform(sample_offset + i);
}
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
recognizer.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
recognizer.SetFinished();
}
recognizer.Decode();
if (recognizer.DecodedSomething()) {
LOG(INFO) << "Pratial result: " << cnt << " "
<< recognizer.GetPartialResult();
}
recognizer_ptr->Accept(wav_chunk);
// no overlap
sample_offset += cur_chunk_size;
cnt++;
}
CHECK(sample_offset == tot_samples);
recognizer_ptr->SetInputFinished();
recognizer_ptr->WaitDecoderFinished();
// second pass decoding
recognizer.Rescoring();
tot_decode_time += timer.Elapsed();
std::string result = recognizer.GetFinalResult();
recognizer.Reset();
kaldi::Timer timer;
recognizer_ptr->AttentionRescoring();
float rescore_time = timer.Elapsed();
tot_attention_rescore_time += rescore_time;
std::string result = recognizer_ptr->GetFinalResult();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
@ -109,17 +99,20 @@ int main(int argc, char* argv[]) {
continue;
}
tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed();
<< " cost: " << local_timer.Elapsed() << " rescore:" << rescore_time;
result_writer.Write(utt, result);
++num_done;
}
recognizer_ptr->WaitFinished();
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}

@ -1,27 +1,8 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "decoder/common.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "decoder/decoder_itf.h"
#include "frontend/audio/feature_pipeline.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "nnet/decodable.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/feature_pipeline.h"
DECLARE_int32(nnet_decoder_chunk);
DECLARE_int32(num_left_chunks);
@ -30,9 +11,9 @@ DECLARE_double(rescoring_weight);
DECLARE_double(reverse_weight);
DECLARE_int32(nbest);
DECLARE_int32(blank);
DECLARE_double(acoustic_scale);
DECLARE_string(vocab_path);
DECLARE_double(blank_threshold);
DECLARE_string(word_symbol_table);
namespace ppspeech {
@ -59,6 +40,7 @@ struct DecodeOptions {
// CtcEndpointConfig ctc_endpoint_opts;
CTCBeamSearchOptions ctc_prefix_search_opts{};
TLGDecoderOptions tlg_decoder_opts{};
static DecodeOptions InitFromFlags() {
DecodeOptions decoder_opts;
@ -70,6 +52,11 @@ struct DecodeOptions {
decoder_opts.ctc_prefix_search_opts.blank = FLAGS_blank;
decoder_opts.ctc_prefix_search_opts.first_beam_size = FLAGS_nbest;
decoder_opts.ctc_prefix_search_opts.second_beam_size = FLAGS_nbest;
decoder_opts.ctc_prefix_search_opts.word_symbol_table =
FLAGS_word_symbol_table;
decoder_opts.tlg_decoder_opts =
ppspeech::TLGDecoderOptions::InitFromFlags();
LOG(INFO) << "chunk_size: " << decoder_opts.chunk_size;
LOG(INFO) << "num_left_chunks: " << decoder_opts.num_left_chunks;
LOG(INFO) << "ctc_weight: " << decoder_opts.ctc_weight;
@ -82,19 +69,20 @@ struct DecodeOptions {
}
};
struct U2RecognizerResource {
struct RecognizerResource {
// decodable opt
kaldi::BaseFloat acoustic_scale{1.0};
std::string vocab_path{};
kaldi::BaseFloat blank_threshold{0.98};
FeaturePipelineOptions feature_pipeline_opts{};
ModelOptions model_opts{};
DecodeOptions decoder_opts{};
std::shared_ptr<NnetBase> nnet;
static U2RecognizerResource InitFromFlags() {
U2RecognizerResource resource;
resource.vocab_path = FLAGS_vocab_path;
static RecognizerResource InitFromFlags() {
RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
LOG(INFO) << "vocab path: " << resource.vocab_path;
resource.blank_threshold = FLAGS_blank_threshold;
LOG(INFO) << "acoustic_scale: " << resource.acoustic_scale;
resource.feature_pipeline_opts =
@ -104,69 +92,17 @@ struct U2RecognizerResource {
<< resource.feature_pipeline_opts.assembler_opts.fill_zero;
resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
resource.decoder_opts = ppspeech::DecodeOptions::InitFromFlags();
return resource;
}
};
class U2Recognizer {
public:
explicit U2Recognizer(const U2RecognizerResource& resouce);
void Reset();
void ResetContinuousDecoding();
void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
void Decode();
void Rescoring();
std::string GetFinalResult();
std::string GetPartialResult();
void SetFinished();
bool IsFinished() { return input_finished_; }
bool DecodedSomething() const {
return !result_.empty() && !result_[0].sentence.empty();
#ifndef USE_ONNX
resource.nnet.reset(new U2Nnet(resource.model_opts));
#else
if (resource.model_opts.with_onnx_model){
resource.nnet.reset(new U2OnnxNnet(resource.model_opts));
} else {
resource.nnet.reset(new U2Nnet(resource.model_opts));
}
int FrameShiftInMs() const {
// one decoder frame length in ms
return decodable_->Nnet()->SubsamplingRate() *
feature_pipeline_->FrameShift();
#endif
return resource;
}
const std::vector<DecodeResult>& Result() const { return result_; }
private:
void AttentionRescoring();
void UpdateResult(bool finish = false);
private:
U2RecognizerResource opts_;
// std::shared_ptr<U2RecognizerResource> resource_;
// U2RecognizerResource resource_;
std::shared_ptr<FeaturePipeline> feature_pipeline_;
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<CTCPrefixBeamSearch> decoder_;
// e2e unit symbol table
std::shared_ptr<fst::SymbolTable> unit_table_ = nullptr;
std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
std::vector<DecodeResult> result_;
// global decoded frame offset
int global_frame_offset_;
// cur decoded frame num
int num_frames_;
// timestamp gap between words in a sentence
const int time_stamp_gap_ = 100;
bool input_finished_;
};
} //namespace ppspeech

@ -0,0 +1 @@
#add_subdirectory(websocket)

@ -0,0 +1,3 @@
# add_definitions("-DUSE_PADDLE_INFERENCE_BACKEND")
add_definitions("-DUSE_ORT_BACKEND")
add_subdirectory(nnet)

@ -0,0 +1,11 @@
set(srcs
panns_nnet.cc
panns_interface.cc
)
add_library(cls SHARED ${srcs})
target_link_libraries(cls PRIVATE ${FASTDEPLOY_LIBS} kaldi-matrix kaldi-base frontend utils )
set(bin_name panns_nnet_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_link_libraries(${bin_name} gflags glog cls)

@ -0,0 +1,79 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "audio_classification/nnet/panns_interface.h"
#include "audio_classification/nnet/panns_nnet.h"
#include "common/base/config.h"
namespace ppspeech {
void* ClsCreateInstance(const char* conf_path) {
Config conf(conf_path);
// cls init
ppspeech::ClsNnetConf cls_nnet_conf;
cls_nnet_conf.wav_normal_ = conf.Read("wav_normal", true);
cls_nnet_conf.wav_normal_type_ =
conf.Read("wav_normal_type", std::string("linear"));
cls_nnet_conf.wav_norm_mul_factor_ = conf.Read("wav_norm_mul_factor", 1.0);
cls_nnet_conf.model_file_path_ = conf.Read("model_path", std::string(""));
cls_nnet_conf.param_file_path_ = conf.Read("param_path", std::string(""));
cls_nnet_conf.dict_file_path_ = conf.Read("dict_path", std::string(""));
cls_nnet_conf.num_cpu_thread_ = conf.Read("num_cpu_thread", 12);
cls_nnet_conf.samp_freq = conf.Read("samp_freq", 32000);
cls_nnet_conf.frame_length_ms = conf.Read("frame_length_ms", 32);
cls_nnet_conf.frame_shift_ms = conf.Read("frame_shift_ms", 10);
cls_nnet_conf.num_bins = conf.Read("num_bins", 64);
cls_nnet_conf.low_freq = conf.Read("low_freq", 50);
cls_nnet_conf.high_freq = conf.Read("high_freq", 14000);
cls_nnet_conf.dither = conf.Read("dither", 0.0);
ppspeech::ClsNnet* cls_model = new ppspeech::ClsNnet();
int ret = cls_model->Init(cls_nnet_conf);
return static_cast<void*>(cls_model);
}
int ClsDestroyInstance(void* instance) {
ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
if (cls_model != NULL) {
delete cls_model;
cls_model = NULL;
}
return 0;
}
int ClsFeedForward(void* instance,
const char* wav_path,
int topk,
char* result,
int result_max_len) {
ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
if (cls_model == NULL) {
printf("instance is null\n");
return -1;
}
int ret = cls_model->Forward(wav_path, topk, result, result_max_len);
return 0;
}
int ClsReset(void* instance) {
ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
if (cls_model == NULL) {
printf("instance is null\n");
return -1;
}
cls_model->Reset();
return 0;
}
} // namespace ppspeech

@ -0,0 +1,27 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace ppspeech {
void* ClsCreateInstance(const char* conf_path);
int ClsDestroyInstance(void* instance);
int ClsFeedForward(void* instance,
const char* wav_path,
int topk,
char* result,
int result_max_len);
int ClsReset(void* instance);
} // namespace ppspeech

@ -0,0 +1,227 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "audio_classification/nnet/panns_nnet.h"
#ifdef WITH_PROFILING
#include "kaldi/base/timer.h"
#endif
namespace ppspeech {
ClsNnet::ClsNnet() {
// wav_reader_ = NULL;
runtime_ = NULL;
}
void ClsNnet::Reset() {
// wav_reader_->Clear();
ss_.str("");
}
int ClsNnet::Init(const ClsNnetConf& conf) {
conf_ = conf;
// init fbank opts
fbank_opts_.frame_opts.samp_freq = conf.samp_freq;
fbank_opts_.frame_opts.frame_length_ms = conf.frame_length_ms;
fbank_opts_.frame_opts.frame_shift_ms = conf.frame_shift_ms;
fbank_opts_.mel_opts.num_bins = conf.num_bins;
fbank_opts_.mel_opts.low_freq = conf.low_freq;
fbank_opts_.mel_opts.high_freq = conf.high_freq;
fbank_opts_.frame_opts.dither = conf.dither;
fbank_opts_.use_log_fbank = false;
// init dict
if (conf.dict_file_path_ != "") {
ReadFileToVector(conf.dict_file_path_, &dict_);
}
// init model
fastdeploy::RuntimeOption runtime_option;
#ifdef USE_PADDLE_INFERENCE_BACKEND
runtime_option.SetModelPath(conf.model_file_path_,
conf.param_file_path_,
fastdeploy::ModelFormat::PADDLE);
runtime_option.UsePaddleInferBackend();
#elif defined(USE_ORT_BACKEND)
runtime_option.SetModelPath(
conf.model_file_path_, "", fastdeploy::ModelFormat::ONNX); // onnx
runtime_option.UseOrtBackend(); // onnx
#elif defined(USE_PADDLE_LITE_BACKEND)
runtime_option.SetModelPath(conf.model_file_path_,
conf.param_file_path_,
fastdeploy::ModelFormat::PADDLE);
runtime_option.UseLiteBackend();
#endif
runtime_option.SetCpuThreadNum(conf.num_cpu_thread_);
// runtime_option.DeletePaddleBackendPass("simplify_with_basic_ops_pass");
runtime_ = std::unique_ptr<fastdeploy::Runtime>(new fastdeploy::Runtime());
if (!runtime_->Init(runtime_option)) {
std::cerr << "--- Init FastDeploy Runitme Failed! "
<< "\n--- Model: " << conf.model_file_path_ << std::endl;
return -1;
} else {
std::cout << "--- Init FastDeploy Runitme Done! "
<< "\n--- Model: " << conf.model_file_path_ << std::endl;
}
Reset();
return 0;
}
int ClsNnet::Forward(const char* wav_path,
int topk,
char* result,
int result_max_len) {
#ifdef WITH_PROFILING
kaldi::Timer timer;
timer.Reset();
#endif
// read wav
std::ifstream infile(wav_path, std::ifstream::in);
kaldi::WaveData wave_data;
wave_data.Read(infile);
int32 this_channel = 0;
kaldi::Matrix<float> wavform_kaldi = wave_data.Data();
// only get channel 0
int wavform_len = wavform_kaldi.NumCols();
std::vector<float> wavform(wavform_kaldi.Data(),
wavform_kaldi.Data() + wavform_len);
WaveformFloatNormal(&wavform);
WaveformNormal(&wavform,
conf_.wav_normal_,
conf_.wav_normal_type_,
conf_.wav_norm_mul_factor_);
#ifdef PPS_DEBUG
{
std::ofstream fp("cls.wavform", std::ios::out);
for (int i = 0; i < wavform.size(); ++i) {
fp << std::setprecision(18) << wavform[i] << " ";
}
fp << "\n";
}
#endif
#ifdef WITH_PROFILING
printf("wav read consume: %fs\n", timer.Elapsed());
#endif
#ifdef WITH_PROFILING
timer.Reset();
#endif
std::vector<float> feats;
std::unique_ptr<ppspeech::FrontendInterface> data_source(
new ppspeech::DataCache());
ppspeech::Fbank fbank(fbank_opts_, std::move(data_source));
fbank.Accept(wavform);
fbank.SetFinished();
fbank.Read(&feats);
int feat_dim = fbank_opts_.mel_opts.num_bins;
int num_frames = feats.size() / feat_dim;
for (int i = 0; i < num_frames; ++i) {
for (int j = 0; j < feat_dim; ++j) {
feats[i * feat_dim + j] = PowerTodb(feats[i * feat_dim + j]);
}
}
#ifdef PPS_DEBUG
{
std::ofstream fp("cls.feat", std::ios::out);
for (int i = 0; i < num_frames; ++i) {
for (int j = 0; j < feat_dim; ++j) {
fp << std::setprecision(18) << feats[i * feat_dim + j] << " ";
}
fp << "\n";
}
}
#endif
#ifdef WITH_PROFILING
printf("extract fbank consume: %fs\n", timer.Elapsed());
#endif
// infer
std::vector<float> model_out;
#ifdef WITH_PROFILING
timer.Reset();
#endif
ModelForward(feats.data(), num_frames, feat_dim, &model_out);
#ifdef WITH_PROFILING
printf("fast deploy infer consume: %fs\n", timer.Elapsed());
#endif
#ifdef PPS_DEBUG
{
std::ofstream fp("cls.logits", std::ios::out);
for (int i = 0; i < model_out.size(); ++i) {
fp << std::setprecision(18) << model_out[i] << "\n";
}
}
#endif
// construct result str
ss_ << "{";
GetTopkResult(topk, model_out);
ss_ << "}";
if (result_max_len <= ss_.str().size()) {
printf("result_max_len is short than result len\n");
}
snprintf(result, result_max_len, "%s", ss_.str().c_str());
return 0;
}
int ClsNnet::ModelForward(float* features,
const int num_frames,
const int feat_dim,
std::vector<float>* model_out) {
// init input tensor shape
fastdeploy::TensorInfo info = runtime_->GetInputInfo(0);
info.shape = {1, num_frames, feat_dim};
std::vector<fastdeploy::FDTensor> input_tensors(1);
std::vector<fastdeploy::FDTensor> output_tensors(1);
input_tensors[0].SetExternalData({1, num_frames, feat_dim},
fastdeploy::FDDataType::FP32,
static_cast<void*>(features));
// get input name
input_tensors[0].name = info.name;
runtime_->Infer(input_tensors, &output_tensors);
// output_tensors[0].PrintInfo();
std::vector<int64_t> output_shape = output_tensors[0].Shape();
model_out->resize(output_shape[0] * output_shape[1]);
memcpy(static_cast<void*>(model_out->data()),
output_tensors[0].Data(),
output_shape[0] * output_shape[1] * sizeof(float));
return 0;
}
int ClsNnet::GetTopkResult(int k, const std::vector<float>& model_out) {
std::vector<float> values;
std::vector<int> indics;
TopK(model_out, k, &values, &indics);
for (int i = 0; i < k; ++i) {
if (i != 0) {
ss_ << ",";
}
ss_ << "\"" << dict_[indics[i]] << "\":\"" << values[i] << "\"";
}
return 0;
}
} // namespace ppspeech

@ -0,0 +1,74 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "common/frontend/data_cache.h"
#include "common/frontend/fbank.h"
#include "common/frontend/feature-fbank.h"
#include "common/frontend/frontend_itf.h"
#include "common/frontend/wave-reader.h"
#include "common/utils/audio_process.h"
#include "common/utils/file_utils.h"
#include "fastdeploy/runtime.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
namespace ppspeech {
struct ClsNnetConf {
// wav
bool wav_normal_;
std::string wav_normal_type_;
float wav_norm_mul_factor_;
// model
std::string model_file_path_;
std::string param_file_path_;
std::string dict_file_path_;
int num_cpu_thread_;
// fbank
float samp_freq;
float frame_length_ms;
float frame_shift_ms;
int num_bins;
float low_freq;
float high_freq;
float dither;
};
class ClsNnet {
public:
ClsNnet();
int Init(const ClsNnetConf& conf);
int Forward(const char* wav_path,
int topk,
char* result,
int result_max_len);
void Reset();
private:
int ModelForward(float* features,
const int num_frames,
const int feat_dim,
std::vector<float>* model_out);
int ModelForwardStream(std::vector<float>* feats);
int GetTopkResult(int k, const std::vector<float>& model_out);
ClsNnetConf conf_;
knf::FbankOptions fbank_opts_;
std::unique_ptr<fastdeploy::Runtime> runtime_;
std::vector<std::string> dict_;
std::stringstream ss_;
};
} // namespace ppspeech

@ -0,0 +1,51 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <string>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "audio_classification/nnet/panns_interface.h"
DEFINE_string(conf_path, "", "config path");
DEFINE_string(scp_path, "", "wav scp path");
DEFINE_string(topk, "", "print topk results");
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
CHECK_GT(FLAGS_conf_path.size(), 0);
CHECK_GT(FLAGS_scp_path.size(), 0);
CHECK_GT(FLAGS_topk.size(), 0);
void* instance = ppspeech::ClsCreateInstance(FLAGS_conf_path.c_str());
int ret = 0;
// read wav
std::ifstream ifs(FLAGS_scp_path);
std::string line = "";
int topk = std::atoi(FLAGS_topk.c_str());
while (getline(ifs, line)) {
// read wav
char result[1024] = {0};
ret = ppspeech::ClsFeedForward(
instance, line.c_str(), topk, result, 1024);
printf("%s %s\n", line.c_str(), result);
ret = ppspeech::ClsReset(instance);
}
ret = ppspeech::ClsDestroyInstance(instance);
return 0;
}

@ -0,0 +1,6 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
if(ANDROID)
else() #Unix
add_subdirectory(glog)
endif()

@ -1,8 +1,8 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(glog_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_main.cc)
target_link_libraries(glog_main glog)
target_link_libraries(glog_main extern_glog)
add_executable(glog_logtostderr_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_main.cc)
target_link_libraries(glog_logtostderr_main glog)
target_link_libraries(glog_logtostderr_main extern_glog)

@ -0,0 +1,19 @@
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../
)
add_subdirectory(base)
add_subdirectory(utils)
add_subdirectory(matrix)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}/frontend
)
add_subdirectory(frontend)
add_library(common INTERFACE)
target_link_libraries(common INTERFACE base utils kaldi-matrix frontend)
install(TARGETS base DESTINATION lib)
install(TARGETS utils DESTINATION lib)
install(TARGETS kaldi-matrix DESTINATION lib)
install(TARGETS frontend DESTINATION lib)

@ -0,0 +1,43 @@
if(WITH_ASR)
add_compile_options(-DWITH_ASR)
set(PPS_FLAGS_LIB "fst/flags.h")
else()
set(PPS_FLAGS_LIB "gflags/gflags.h")
endif()
if(ANDROID)
set(PPS_GLOG_LIB "base/log_impl.h")
else() #UNIX
if(WITH_ASR)
set(PPS_GLOG_LIB "fst/log.h")
else()
set(PPS_GLOG_LIB "glog/logging.h")
endif()
endif()
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/flags.h.in
${CMAKE_CURRENT_SOURCE_DIR}/flags.h @ONLY
)
message(STATUS "Generated ${CMAKE_CURRENT_SOURCE_DIR}/flags.h")
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/log.h.in
${CMAKE_CURRENT_SOURCE_DIR}/log.h @ONLY
)
message(STATUS "Generated ${CMAKE_CURRENT_SOURCE_DIR}/log.h")
if(ANDROID)
set(csrc
log_impl.cc
glog_utils.cc
)
add_library(base ${csrc})
target_link_libraries(base gflags)
else() # UNIX
set(csrc)
add_library(base INTERFACE)
endif()

@ -21,6 +21,8 @@
#include <cstring>
#include <deque>
#include <fstream>
#include <functional>
#include <future>
#include <iomanip>
#include <iostream>
#include <istream>
@ -49,3 +51,4 @@
#include "base/macros.h"
#include "utils/file_utils.h"
#include "utils/math.h"
#include "utils/timer.h"

@ -0,0 +1,343 @@
// Copyright (c) code is from
// https://blog.csdn.net/huixingshao/article/details/45969887.
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
using namespace std;
#pragma once
#ifdef _MSC_VER
#pragma region ParseIniFile
#endif
/*
* \brief Generic configuration Class
*
*/
class Config {
// Data
protected:
std::string m_Delimiter; //!< separator between key and value
std::string m_Comment; //!< separator between value and comments
std::map<std::string, std::string>
m_Contents; //!< extracted keys and values
typedef std::map<std::string, std::string>::iterator mapi;
typedef std::map<std::string, std::string>::const_iterator mapci;
// Methods
public:
Config(std::string filename,
std::string delimiter = "=",
std::string comment = "#");
Config();
template <class T>
T Read(const std::string& in_key) const; //!< Search for key and read value
//! or optional default value, call
//! as read<T>
template <class T>
T Read(const std::string& in_key, const T& in_value) const;
template <class T>
bool ReadInto(T* out_var, const std::string& in_key) const;
template <class T>
bool ReadInto(T* out_var,
const std::string& in_key,
const T& in_value) const;
bool FileExist(std::string filename);
void ReadFile(std::string filename,
std::string delimiter = "=",
std::string comment = "#");
// Check whether key exists in configuration
bool KeyExists(const std::string& in_key) const;
// Modify keys and values
template <class T>
void Add(const std::string& in_key, const T& in_value);
void Remove(const std::string& in_key);
// Check or change configuration syntax
std::string GetDelimiter() const { return m_Delimiter; }
std::string GetComment() const { return m_Comment; }
std::string SetDelimiter(const std::string& in_s) {
std::string old = m_Delimiter;
m_Delimiter = in_s;
return old;
}
std::string SetComment(const std::string& in_s) {
std::string old = m_Comment;
m_Comment = in_s;
return old;
}
// Write or read configuration
friend std::ostream& operator<<(std::ostream& os, const Config& cf);
friend std::istream& operator>>(std::istream& is, Config& cf);
protected:
template <class T>
static std::string T_as_string(const T& t);
template <class T>
static T string_as_T(const std::string& s);
static void Trim(std::string* inout_s);
// Exception types
public:
struct File_not_found {
std::string filename;
explicit File_not_found(const std::string& filename_ = std::string())
: filename(filename_) {}
};
struct Key_not_found { // thrown only by T read(key) variant of read()
std::string key;
explicit Key_not_found(const std::string& key_ = std::string())
: key(key_) {}
};
};
/* static */
template <class T>
std::string Config::T_as_string(const T& t) {
// Convert from a T to a string
// Type T must support << operator
std::ostringstream ost;
ost << t;
return ost.str();
}
/* static */
template <class T>
T Config::string_as_T(const std::string& s) {
// Convert from a string to a T
// Type T must support >> operator
T t;
std::istringstream ist(s);
ist >> t;
return t;
}
/* static */
template <>
inline std::string Config::string_as_T<std::string>(const std::string& s) {
// Convert from a string to a string
// In other words, do nothing
return s;
}
/* static */
template <>
inline bool Config::string_as_T<bool>(const std::string& s) {
// Convert from a string to a bool
// Interpret "false", "F", "no", "n", "0" as false
// Interpret "true", "T", "yes", "y", "1", "-1", or anything else as true
bool b = true;
std::string sup = s;
for (std::string::iterator p = sup.begin(); p != sup.end(); ++p)
*p = toupper(*p); // make string all caps
if (sup == std::string("FALSE") || sup == std::string("F") ||
sup == std::string("NO") || sup == std::string("N") ||
sup == std::string("0") || sup == std::string("NONE"))
b = false;
return b;
}
template <class T>
T Config::Read(const std::string& key) const {
// Read the value corresponding to key
mapci p = m_Contents.find(key);
if (p == m_Contents.end()) throw Key_not_found(key);
return string_as_T<T>(p->second);
}
template <class T>
T Config::Read(const std::string& key, const T& value) const {
// Return the value corresponding to key or given default value
// if key is not found
mapci p = m_Contents.find(key);
if (p == m_Contents.end()) {
printf("%s = %s(default)\n", key.c_str(), T_as_string(value).c_str());
return value;
} else {
printf("%s = %s\n", key.c_str(), T_as_string(p->second).c_str());
return string_as_T<T>(p->second);
}
}
template <class T>
bool Config::ReadInto(T* var, const std::string& key) const {
// Get the value corresponding to key and store in var
// Return true if key is found
// Otherwise leave var untouched
mapci p = m_Contents.find(key);
bool found = (p != m_Contents.end());
if (found) *var = string_as_T<T>(p->second);
return found;
}
template <class T>
bool Config::ReadInto(T* var, const std::string& key, const T& value) const {
// Get the value corresponding to key and store in var
// Return true if key is found
// Otherwise set var to given default
mapci p = m_Contents.find(key);
bool found = (p != m_Contents.end());
if (found)
*var = string_as_T<T>(p->second);
else
var = value;
return found;
}
template <class T>
void Config::Add(const std::string& in_key, const T& value) {
// Add a key with given value
std::string v = T_as_string(value);
std::string key = in_key;
Trim(&key);
Trim(&v);
m_Contents[key] = v;
return;
}
Config::Config(string filename, string delimiter, string comment)
: m_Delimiter(delimiter), m_Comment(comment) {
// Construct a Config, getting keys and values from given file
std::ifstream in(filename.c_str());
if (!in) throw File_not_found(filename);
in >> (*this);
}
Config::Config() : m_Delimiter(string(1, '=')), m_Comment(string(1, '#')) {
// Construct a Config without a file; empty
}
bool Config::KeyExists(const string& key) const {
// Indicate whether key is found
mapci p = m_Contents.find(key);
return (p != m_Contents.end());
}
/* static */
void Config::Trim(string* inout_s) {
// Remove leading and trailing whitespace
static const char whitespace[] = " \n\t\v\r\f";
inout_s->erase(0, inout_s->find_first_not_of(whitespace));
inout_s->erase(inout_s->find_last_not_of(whitespace) + 1U);
}
std::ostream& operator<<(std::ostream& os, const Config& cf) {
// Save a Config to os
for (Config::mapci p = cf.m_Contents.begin(); p != cf.m_Contents.end();
++p) {
os << p->first << " " << cf.m_Delimiter << " ";
os << p->second << std::endl;
}
return os;
}
void Config::Remove(const string& key) {
// Remove key and its value
m_Contents.erase(m_Contents.find(key));
return;
}
std::istream& operator>>(std::istream& is, Config& cf) {
// Load a Config from is
// Read in keys and values, keeping internal whitespace
typedef string::size_type pos;
const string& delim = cf.m_Delimiter; // separator
const string& comm = cf.m_Comment; // comment
const pos skip = delim.length(); // length of separator
string nextline = ""; // might need to read ahead to see where value ends
while (is || nextline.length() > 0) {
// Read an entire line at a time
string line;
if (nextline.length() > 0) {
line = nextline; // we read ahead; use it now
nextline = "";
} else {
std::getline(is, line);
}
// Ignore comments
line = line.substr(0, line.find(comm));
// Parse the line if it contains a delimiter
pos delimPos = line.find(delim);
if (delimPos < string::npos) {
// Extract the key
string key = line.substr(0, delimPos);
line.replace(0, delimPos + skip, "");
// See if value continues on the next line
// Stop at blank line, next line with a key, end of stream,
// or end of file sentry
bool terminate = false;
while (!terminate && is) {
std::getline(is, nextline);
terminate = true;
string nlcopy = nextline;
Config::Trim(&nlcopy);
if (nlcopy == "") continue;
nextline = nextline.substr(0, nextline.find(comm));
if (nextline.find(delim) != string::npos) continue;
nlcopy = nextline;
Config::Trim(&nlcopy);
if (nlcopy != "") line += "\n";
line += nextline;
terminate = false;
}
// Store key and value
Config::Trim(&key);
Config::Trim(&line);
cf.m_Contents[key] = line; // overwrites if key is repeated
}
}
return is;
}
bool Config::FileExist(std::string filename) {
bool exist = false;
std::ifstream in(filename.c_str());
if (in) exist = true;
return exist;
}
void Config::ReadFile(string filename, string delimiter, string comment) {
m_Delimiter = delimiter;
m_Comment = comment;
std::ifstream in(filename.c_str());
if (!in) throw File_not_found(filename);
in >> (*this);
}
#ifdef _MSC_VER
#pragma endregion ParseIniFIle
#endif

@ -14,4 +14,4 @@
#pragma once
#include "fst/log.h"
#include "@PPS_FLAGS_LIB@"

@ -0,0 +1,12 @@
#include "base/glog_utils.h"
namespace google {
void InitGoogleLogging(const char* name) {
LOG(INFO) << "dummpy InitGoogleLogging.";
}
void InstallFailureSignalHandler() {
LOG(INFO) << "dummpy InstallFailureSignalHandler.";
}
} // namespace google

@ -0,0 +1,9 @@
#pragma once
#include "base/common.h"
namespace google {
void InitGoogleLogging(const char* name);
void InstallFailureSignalHandler();
} // namespace google

@ -14,4 +14,4 @@
#pragma once
#include "fst/flags.h"
#include "@PPS_GLOG_LIB@"

@ -0,0 +1,105 @@
#include "base/log.h"
DEFINE_int32(logtostderr, 0, "logging to stderr");
namespace ppspeech {
static char __progname[] = "paddlespeech";
namespace log {
std::mutex LogMessage::lock_;
std::string LogMessage::s_debug_logfile_("");
std::string LogMessage::s_info_logfile_("");
std::string LogMessage::s_warning_logfile_("");
std::string LogMessage::s_error_logfile_("");
std::string LogMessage::s_fatal_logfile_("");
void LogMessage::get_curr_proc_info(std::string* pid, std::string* proc_name) {
std::stringstream ss;
ss << getpid();
ss >> *pid;
*proc_name = ::ppspeech::__progname;
}
LogMessage::LogMessage(const char* file,
int line,
Severity level,
bool verbose,
bool out_to_file /* = false */)
: level_(level), verbose_(verbose), out_to_file_(out_to_file) {
if (FLAGS_logtostderr == 0) {
stream_ = static_cast<std::ostream*>(&std::cout);
} else if (FLAGS_logtostderr == 1) {
stream_ = static_cast<std::ostream*>(&std::cerr);
} else if (out_to_file_) {
// logfile
lock_.lock();
init(file, line);
}
}
LogMessage::~LogMessage() {
stream() << std::endl;
if (out_to_file_) {
lock_.unlock();
}
if (verbose_ && level_ == FATAL) {
std::abort();
}
}
std::ostream* LogMessage::nullstream() {
thread_local static std::ofstream os;
thread_local static bool flag_set = false;
if (!flag_set) {
os.setstate(std::ios_base::badbit);
flag_set = true;
}
return &os;
}
void LogMessage::init(const char* file, int line) {
time_t t = time(0);
char tmp[100];
strftime(tmp, sizeof(tmp), "%Y%m%d-%H%M%S", localtime(&t));
if (s_info_logfile_.empty()) {
std::string pid;
std::string proc_name;
get_curr_proc_info(&pid, &proc_name);
s_debug_logfile_ =
std::string("log." + proc_name + ".log.DEBUG." + tmp + "." + pid);
s_info_logfile_ =
std::string("log." + proc_name + ".log.INFO." + tmp + "." + pid);
s_warning_logfile_ =
std::string("log." + proc_name + ".log.WARNING." + tmp + "." + pid);
s_error_logfile_ =
std::string("log." + proc_name + ".log.ERROR." + tmp + "." + pid);
s_fatal_logfile_ =
std::string("log." + proc_name + ".log.FATAL." + tmp + "." + pid);
}
thread_local static std::ofstream ofs;
if (level_ == DEBUG) {
ofs.open(s_debug_logfile_.c_str(), std::ios::out | std::ios::app);
} else if (level_ == INFO) {
ofs.open(s_info_logfile_.c_str(), std::ios::out | std::ios::app);
} else if (level_ == WARNING) {
ofs.open(s_warning_logfile_.c_str(), std::ios::out | std::ios::app);
} else if (level_ == ERROR) {
ofs.open(s_error_logfile_.c_str(), std::ios::out | std::ios::app);
} else {
ofs.open(s_fatal_logfile_.c_str(), std::ios::out | std::ios::app);
}
stream_ = &ofs;
stream() << tmp << " " << file << " line " << line << "; ";
stream() << std::flush;
}
} // namespace log
} // namespace ppspeech

@ -0,0 +1,173 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from https://github.com/Dounm/dlog
// modified form
// https://android.googlesource.com/platform/art/+/806defa/src/logging.h
#pragma once
#include <stdlib.h>
#include <unistd.h>
#include <fstream>
#include <iostream>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include "base/common.h"
#include "base/macros.h"
#ifndef WITH_GLOG
#include "base/glog_utils.h"
#endif
DECLARE_int32(logtostderr);
namespace ppspeech {
namespace log {
enum Severity {
DEBUG,
INFO,
WARNING,
ERROR,
FATAL,
NUM_SEVERITIES,
};
class LogMessage {
public:
static void get_curr_proc_info(std::string* pid, std::string* proc_name);
LogMessage(const char* file,
int line,
Severity level,
bool verbose,
bool out_to_file = false);
~LogMessage();
std::ostream& stream() { return verbose_ ? *stream_ : *nullstream(); }
private:
void init(const char* file, int line);
std::ostream* nullstream();
private:
std::ostream* stream_;
std::ostream* null_stream_;
Severity level_;
bool verbose_;
bool out_to_file_;
static std::mutex lock_; // stream write lock
static std::string s_debug_logfile_;
static std::string s_info_logfile_;
static std::string s_warning_logfile_;
static std::string s_error_logfile_;
static std::string s_fatal_logfile_;
DISALLOW_COPY_AND_ASSIGN(LogMessage);
};
} // namespace log
} // namespace ppspeech
#ifndef PPS_DEBUG
#define DLOG_INFO \
ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, false)
#define DLOG_WARNING \
ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, false)
#define DLOG_ERROR \
ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, false)
#define DLOG_FATAL \
ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, false)
#else
#define DLOG_INFO \
ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, true)
#define DLOG_WARNING \
ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, true)
#define DLOG_ERROR \
ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, true)
#define DLOG_FATAL \
ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, true)
#endif
#define LOG_INFO \
ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, true)
#define LOG_WARNING \
ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, true)
#define LOG_ERROR \
ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, true)
#define LOG_FATAL \
ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, true)
#define LOG_0 LOG_DEBUG
#define LOG_1 LOG_INFO
#define LOG_2 LOG_WARNING
#define LOG_3 LOG_ERROR
#define LOG_4 LOG_FATAL
#define LOG(level) LOG_##level.stream()
#define DLOG(level) DLOG_##level.stream()
#define VLOG(verboselevel) LOG(verboselevel)
#define CHECK(exp) \
ppspeech::log::LogMessage( \
__FILE__, __LINE__, ppspeech::log::FATAL, !(exp)) \
.stream() \
<< "Check Failed: " #exp
#define CHECK_EQ(x, y) CHECK((x) == (y))
#define CHECK_NE(x, y) CHECK((x) != (y))
#define CHECK_LE(x, y) CHECK((x) <= (y))
#define CHECK_LT(x, y) CHECK((x) < (y))
#define CHECK_GE(x, y) CHECK((x) >= (y))
#define CHECK_GT(x, y) CHECK((x) > (y))
#ifdef PPS_DEBUG
#define DCHECK(x) CHECK(x)
#define DCHECK_EQ(x, y) CHECK_EQ(x, y)
#define DCHECK_NE(x, y) CHECK_NE(x, y)
#define DCHECK_LE(x, y) CHECK_LE(x, y)
#define DCHECK_LT(x, y) CHECK_LT(x, y)
#define DCHECK_GE(x, y) CHECK_GE(x, y)
#define DCHECK_GT(x, y) CHECK_GT(x, y)
#else
#define DCHECK(condition) \
while (false) CHECK(condition)
#define DCHECK_EQ(val1, val2) \
while (false) CHECK_EQ(val1, val2)
#define DCHECK_NE(val1, val2) \
while (false) CHECK_NE(val1, val2)
#define DCHECK_LE(val1, val2) \
while (false) CHECK_LE(val1, val2)
#define DCHECK_LT(val1, val2) \
while (false) CHECK_LT(val1, val2)
#define DCHECK_GE(val1, val2) \
while (false) CHECK_GE(val1, val2)
#define DCHECK_GT(val1, val2) \
while (false) CHECK_GT(val1, val2)
#define DCHECK_STREQ(str1, str2) \
while (false) CHECK_STREQ(str1, str2)
#endif

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save