commit
7cab869d63
@ -0,0 +1,7 @@
|
|||||||
|
engine/common/base/flags.h
|
||||||
|
engine/common/base/log.h
|
||||||
|
|
||||||
|
tools/valgrind*
|
||||||
|
*log
|
||||||
|
fc_patch/*
|
||||||
|
test
|
@ -0,0 +1,211 @@
|
|||||||
|
# CMake >= 3.17 is required for -DCMAKE_FIND_DEBUG_MODE=ON support
|
||||||
|
cmake_minimum_required(VERSION 3.17 FATAL_ERROR)
|
||||||
|
|
||||||
|
set(CMAKE_PROJECT_INCLUDE_BEFORE "${CMAKE_CURRENT_SOURCE_DIR}/cmake/EnableCMP0077.cmake")
|
||||||
|
|
||||||
|
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
|
||||||
|
|
||||||
|
include(system)
|
||||||
|
|
||||||
|
project(paddlespeech VERSION 0.1)
|
||||||
|
|
||||||
|
set(PPS_VERSION_MAJOR 1)
|
||||||
|
set(PPS_VERSION_MINOR 0)
|
||||||
|
set(PPS_VERSION_PATCH 0)
|
||||||
|
set(PPS_VERSION "${PPS_VERSION_MAJOR}.${PPS_VERSION_MINOR}.${PPS_VERSION_PATCH}")
|
||||||
|
|
||||||
|
# compiler option
|
||||||
|
# Keep the same with openfst, -fPIC or -fpic
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ldl")
|
||||||
|
SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb")
|
||||||
|
SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall")
|
||||||
|
|
||||||
|
set(CMAKE_VERBOSE_MAKEFILE ON)
|
||||||
|
set(CMAKE_FIND_DEBUG_MODE OFF)
|
||||||
|
set(PPS_CXX_STANDARD 14)
|
||||||
|
|
||||||
|
# set std-14
|
||||||
|
set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD})
|
||||||
|
|
||||||
|
# Ninja Generator will set CMAKE_BUILD_TYPE to Debug
|
||||||
|
# Default to a Release build when the user did not pick a build type.
# CMAKE_BUILD_TYPE is only meaningful for single-config generators
# (Makefiles, Ninja); multi-config generators (Visual Studio, Xcode,
# Ninja Multi-Config) choose the configuration at build time, so the cache
# entry is left untouched for them.
get_property(pps_is_multi_config GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
if(NOT pps_is_multi_config AND NOT CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" FORCE)
    # Offer the valid choices as a drop-down in cmake-gui / ccmake.
    set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS Debug Release RelWithDebInfo MinSizeRel)
endif()
|
||||||
|
|
||||||
|
# find_* e.g. find_library work when Cross-Compiling
|
||||||
|
if(ANDROID)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM BOTH)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(BUILD_IN_MACOS)
|
||||||
|
add_definitions("-DOS_MACOSX")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# install dir into `build/install`
|
||||||
|
# Install into `<build dir>/install` by default, but keep any prefix the user
# passed explicitly (e.g. -DCMAKE_INSTALL_PREFIX=/opt/pps).
# CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT is true only on the first
# configure when no prefix was supplied, hence the FORCE write to the cache so
# the choice persists across re-configures.
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
    set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/install"
        CACHE PATH "Install path prefix, prepended onto install directories." FORCE)
endif()
|
||||||
|
|
||||||
|
include(FetchContent)
|
||||||
|
include(ExternalProject)
|
||||||
|
|
||||||
|
# fc_patch dir
|
||||||
|
set(FETCHCONTENT_QUIET off)
|
||||||
|
get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
|
||||||
|
set(FETCHCONTENT_BASE_DIR ${fc_patch})
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Option Configurations
|
||||||
|
###############################################################################
|
||||||
|
# https://github.com/google/brotli/pull/655
|
||||||
|
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
|
||||||
|
|
||||||
|
option(WITH_PPS_DEBUG "debug option" OFF)
|
||||||
|
if (WITH_PPS_DEBUG)
|
||||||
|
add_definitions("-DPPS_DEBUG")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
option(WITH_ASR "build asr" ON)
|
||||||
|
option(WITH_CLS "build cls" ON)
|
||||||
|
option(WITH_VAD "build vad" ON)
|
||||||
|
|
||||||
|
option(WITH_GPU "NNet using GPU." OFF)
|
||||||
|
|
||||||
|
option(WITH_PROFILING "enable c++ profling" OFF)
|
||||||
|
option(WITH_TESTING "unit test" ON)
|
||||||
|
|
||||||
|
option(WITH_ONNX "u2 support onnx runtime" OFF)
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Include Third Party
|
||||||
|
###############################################################################
|
||||||
|
include(gflags)
|
||||||
|
|
||||||
|
include(glog)
|
||||||
|
|
||||||
|
include(pybind)
|
||||||
|
|
||||||
|
#onnx
|
||||||
|
if(WITH_ONNX)
|
||||||
|
add_definitions(-DUSE_ONNX)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# gtest
|
||||||
|
if(WITH_TESTING)
|
||||||
|
include(gtest) # download, build, install gtest
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# fastdeploy
|
||||||
|
include(fastdeploy)
|
||||||
|
|
||||||
|
if(WITH_ASR)
|
||||||
|
# openfst
|
||||||
|
include(openfst)
|
||||||
|
add_dependencies(openfst gflags extern_glog)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Find Package
|
||||||
|
###############################################################################
|
||||||
|
# https://github.com/Kitware/CMake/blob/v3.1.0/Modules/FindThreads.cmake#L207
|
||||||
|
find_package(Threads REQUIRED)
|
||||||
|
|
||||||
|
if(WITH_ASR)
|
||||||
|
# https://cmake.org/cmake/help/latest/module/FindPython3.html#module:FindPython3
|
||||||
|
find_package(Python3 COMPONENTS Interpreter Development)
|
||||||
|
find_package(pybind11 CONFIG)
|
||||||
|
|
||||||
|
if(Python3_FOUND)
|
||||||
|
message(STATUS "Python3_FOUND = ${Python3_FOUND}")
|
||||||
|
message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}")
|
||||||
|
message(STATUS "Python3_LIBRARIES = ${Python3_LIBRARIES}")
|
||||||
|
message(STATUS "Python3_INCLUDE_DIRS = ${Python3_INCLUDE_DIRS}")
|
||||||
|
message(STATUS "Python3_LINK_OPTIONS = ${Python3_LINK_OPTIONS}")
|
||||||
|
set(PYTHON_LIBRARIES ${Python3_LIBRARIES} CACHE STRING "python lib" FORCE)
|
||||||
|
set(PYTHON_INCLUDE_DIR ${Python3_INCLUDE_DIRS} CACHE STRING "python inc" FORCE)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
|
||||||
|
message(STATUS "PYTHON_INCLUDE_DIR = ${PYTHON_INCLUDE_DIR}")
|
||||||
|
include_directories(${PYTHON_INCLUDE_DIR})
|
||||||
|
|
||||||
|
if(pybind11_FOUND)
|
||||||
|
message(STATUS "pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}")
|
||||||
|
message(STATUS "pybind11_LIBRARIES=${pybind11_LIBRARIES}")
|
||||||
|
message(STATUS "pybind11_DEFINITIONS=${pybind11_DEFINITIONS}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
# paddle libpaddle.so
|
||||||
|
# paddle include and link option
|
||||||
|
# -L/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so
|
||||||
|
set(EXECUTE_COMMAND "import os"
|
||||||
|
"import paddle"
|
||||||
|
"include_dir = paddle.sysconfig.get_include()"
|
||||||
|
"paddle_dir=os.path.split(include_dir)[0]"
|
||||||
|
"libs_dir=os.path.join(paddle_dir, 'libs')"
|
||||||
|
"fluid_dir=os.path.join(paddle_dir, 'fluid')"
|
||||||
|
"out=' '.join([\"-L\" + libs_dir, \"-L\" + fluid_dir])"
|
||||||
|
"out += \" -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so\"; print(out)"
|
||||||
|
)
|
||||||
|
# Run the inline Python script held in EXECUTE_COMMAND with the `python` found
# on PATH and capture what it prints as the paddle linker flags.
# The exit status is stored in SUCESS (spelling kept as-is in case other build
# files read it) and is now checked: previously a failing interpreter left
# PADDLE_LINK_FLAGS empty and the unquoted string(STRIP ...) below aborted with
# an unrelated "requires two arguments" error.
execute_process(
    COMMAND python -c "${EXECUTE_COMMAND}"
    OUTPUT_VARIABLE PADDLE_LINK_FLAGS
    RESULT_VARIABLE SUCESS
    OUTPUT_STRIP_TRAILING_WHITESPACE)

if(NOT "${SUCESS}" EQUAL 0)
    message(FATAL_ERROR
        "Failed to query paddle link flags with `python` (result: ${SUCESS}). "
        "Make sure `python` is on PATH and the paddle package is importable.")
endif()

# Quoted so that an empty or ';'-containing value is still a single argument.
string(STRIP "${PADDLE_LINK_FLAGS}" PADDLE_LINK_FLAGS)
message(STATUS "PADDLE_LINK_FLAGS=${PADDLE_LINK_FLAGS}")
|
||||||
|
|
||||||
|
# paddle compile option
|
||||||
|
# -I/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/include
|
||||||
|
set(EXECUTE_COMMAND "import paddle"
|
||||||
|
"include_dir = paddle.sysconfig.get_include()"
|
||||||
|
"print(f\"-I{include_dir}\")"
|
||||||
|
)
|
||||||
|
# Run the inline Python script held in EXECUTE_COMMAND with the `python` found
# on PATH and capture the `-I<paddle include dir>` option it prints.
# A failing interpreter (python missing, paddle not importable) is reported
# explicitly instead of surfacing later as a string(STRIP ...) arity error.
execute_process(
    COMMAND python -c "${EXECUTE_COMMAND}"
    OUTPUT_VARIABLE PADDLE_COMPILE_FLAGS
    RESULT_VARIABLE paddle_compile_flags_result
    OUTPUT_STRIP_TRAILING_WHITESPACE)

if(NOT "${paddle_compile_flags_result}" EQUAL 0)
    message(FATAL_ERROR
        "Failed to query paddle compile flags with `python` "
        "(result: ${paddle_compile_flags_result}). "
        "Make sure `python` is on PATH and the paddle package is importable.")
endif()

# Quoted so that an empty or ';'-containing value is still a single argument.
string(STRIP "${PADDLE_COMPILE_FLAGS}" PADDLE_COMPILE_FLAGS)
message(STATUS "PADDLE_COMPILE_FLAGS=${PADDLE_COMPILE_FLAGS}")
|
||||||
|
|
||||||
|
# for LD_LIBRARY_PATH
|
||||||
|
# set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/)
|
||||||
|
set(EXECUTE_COMMAND "import os"
|
||||||
|
"import paddle"
|
||||||
|
"include_dir=paddle.sysconfig.get_include()"
|
||||||
|
"paddle_dir=os.path.split(include_dir)[0]"
|
||||||
|
"libs_dir=os.path.join(paddle_dir, 'libs')"
|
||||||
|
"fluid_dir=os.path.join(paddle_dir, 'fluid')"
|
||||||
|
"out=':'.join([libs_dir, fluid_dir]); print(out)"
|
||||||
|
)
|
||||||
|
# Run the inline Python script held in EXECUTE_COMMAND and capture the
# ':'-separated "<paddle>/libs:<paddle>/fluid" string it prints.
# The trailing newline emitted by print() is stripped so the value can be used
# directly as a path list, matching how PADDLE_LINK_FLAGS and
# PADDLE_COMPILE_FLAGS are stripped.
execute_process(
    COMMAND python -c "${EXECUTE_COMMAND}"
    OUTPUT_VARIABLE PADDLE_LIB_DIRS
    OUTPUT_STRIP_TRAILING_WHITESPACE)
message(STATUS "PADDLE_LIB_DIRS=${PADDLE_LIB_DIRS}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
include(summary)
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Add local library
|
||||||
|
###############################################################################
|
||||||
|
set(ENGINE_ROOT ${CMAKE_SOURCE_DIR}/engine)
|
||||||
|
|
||||||
|
add_subdirectory(engine)
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# CPack library
|
||||||
|
###############################################################################
|
||||||
|
# build a CPack driven installer package
|
||||||
|
include (InstallRequiredSystemLibraries)
|
||||||
|
set(CPACK_PACKAGE_NAME "paddlespeech_library")
|
||||||
|
set(CPACK_PACKAGE_VENDOR "paddlespeech")
|
||||||
|
set(CPACK_PACKAGE_VERSION_MAJOR 1)
|
||||||
|
set(CPACK_PACKAGE_VERSION_MINOR 0)
|
||||||
|
set(CPACK_PACKAGE_VERSION_PATCH 0)
|
||||||
|
set(CPACK_PACKAGE_DESCRIPTION "paddlespeech library")
|
||||||
|
set(CPACK_PACKAGE_CONTACT "paddlespeech@baidu.com")
|
||||||
|
set(CPACK_SOURCE_GENERATOR "TGZ")
|
||||||
|
include (CPack)
|
@ -0,0 +1,33 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -xe
|
||||||
|
|
||||||
|
BUILD_ROOT=build/Linux
|
||||||
|
BUILD_DIR=${BUILD_ROOT}/x86_64
|
||||||
|
|
||||||
|
mkdir -p ${BUILD_DIR}
|
||||||
|
|
||||||
|
BUILD_TYPE=Release
|
||||||
|
#BUILD_TYPE=Debug
|
||||||
|
BUILD_SO=OFF
|
||||||
|
BUILD_ONNX=ON
|
||||||
|
BUILD_ASR=ON
|
||||||
|
BUILD_CLS=ON
|
||||||
|
BUILD_VAD=ON
|
||||||
|
PPS_DEBUG=OFF
|
||||||
|
FASTDEPLOY_INSTALL_DIR=""
|
||||||
|
|
||||||
|
# This build script has been verified in the PaddlePaddle docker image.
|
||||||
|
# Please follow the instructions below to install the PaddlePaddle image.
|
||||||
|
# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html
|
||||||
|
#cmake -B build -DBUILD_SHARED_LIBS=OFF -DWITH_ASR=OFF -DWITH_CLS=OFF -DWITH_VAD=ON -DFASTDEPLOY_INSTALL_DIR=/workspace/zhanghui/paddle/FastDeploy/build/Android/arm64-v8a-api-21/install
|
||||||
|
cmake -B ${BUILD_DIR} \
|
||||||
|
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
|
||||||
|
-DBUILD_SHARED_LIBS=${BUILD_SO} \
|
||||||
|
-DWITH_ONNX=${BUILD_ONNX} \
|
||||||
|
-DWITH_ASR=${BUILD_ASR} \
|
||||||
|
-DWITH_CLS=${BUILD_CLS} \
|
||||||
|
-DWITH_VAD=${BUILD_VAD} \
|
||||||
|
-DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \
|
||||||
|
-DWITH_PPS_DEBUG=${PPS_DEBUG}
|
||||||
|
|
||||||
|
# Build with one job per available core. A bare `-j` hands the job count to
# the native build tool, which for make means unlimited parallel jobs and can
# exhaust memory on a large C++ build. Falls back to 4 jobs if nproc is absent.
cmake --build "${BUILD_DIR}" -j "$(nproc 2>/dev/null || echo 4)"
|
@ -0,0 +1,39 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
# Location of the Android NDK. Can be overridden from the environment
# (ANDROID_NDK=/path/to/ndk ./build_android.sh); the previous hard-coded path
# is kept as the default.
ANDROID_NDK=${ANDROID_NDK:-/mnt/masimeng/workspace/software/android-ndk-r25b/}
|
||||||
|
|
||||||
|
# Setting up Android toolchain
|
||||||
|
ANDROID_ABI=arm64-v8a # 'arm64-v8a', 'armeabi-v7a'
|
||||||
|
ANDROID_PLATFORM="android-21" # API >= 21
|
||||||
|
ANDROID_STL=c++_shared # 'c++_shared', 'c++_static'
|
||||||
|
ANDROID_TOOLCHAIN=clang # 'clang' only
|
||||||
|
TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
|
||||||
|
|
||||||
|
# Create build directory
|
||||||
|
BUILD_ROOT=build/Android
|
||||||
|
BUILD_DIR=${BUILD_ROOT}/${ANDROID_ABI}-api-21
|
||||||
|
# Prebuilt FastDeploy Android SDK to build against. Can be overridden from the
# environment; the previous hard-coded path is kept as the default.
FASTDEPLOY_INSTALL_DIR="${FASTDEPLOY_INSTALL_DIR:-/mnt/masimeng/workspace/FastDeploy/build/Android/arm64-v8a-api-21/install}"
|
||||||
|
|
||||||
|
mkdir -p ${BUILD_DIR}
|
||||||
|
cd ${BUILD_DIR}
|
||||||
|
|
||||||
|
# CMake configuration with Android toolchain
|
||||||
|
cmake -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} \
|
||||||
|
-DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||||
|
-DANDROID_ABI=${ANDROID_ABI} \
|
||||||
|
-DANDROID_NDK=${ANDROID_NDK} \
|
||||||
|
-DANDROID_PLATFORM=${ANDROID_PLATFORM} \
|
||||||
|
-DANDROID_STL=${ANDROID_STL} \
|
||||||
|
-DANDROID_TOOLCHAIN=${ANDROID_TOOLCHAIN} \
|
||||||
|
-DBUILD_SHARED_LIBS=OFF \
|
||||||
|
-DWITH_ASR=OFF \
|
||||||
|
-DWITH_CLS=OFF \
|
||||||
|
-DWITH_VAD=ON \
|
||||||
|
-DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \
|
||||||
|
-DCMAKE_FIND_DEBUG_MODE=OFF \
|
||||||
|
-Wno-dev ../../..
|
||||||
|
|
||||||
|
# Build FastDeploy Android C++ SDK
|
||||||
|
make
|
@ -0,0 +1,91 @@
|
|||||||
|
# https://www.jianshu.com/p/33672fb819f5
|
||||||
|
|
||||||
|
PATH="/Applications/CMake.app/Contents/bin":"$PATH"
|
||||||
|
tools_dir=$1
|
||||||
|
ios_toolchain_cmake=${tools_dir}/"/ios-cmake-4.2.0/ios.toolchain.cmake"
|
||||||
|
fastdeploy_dir=${tools_dir}"/fastdeploy-ort-mac-build/"
|
||||||
|
build_targets=("OS64")
|
||||||
|
build_type_array=("Release")
|
||||||
|
|
||||||
|
#static_name="libocr"
|
||||||
|
#lib_name="libocr"
|
||||||
|
|
||||||
|
# Switch to workpath
|
||||||
|
current_path=`cd $(dirname $0);pwd`
|
||||||
|
work_path=${current_path}/
|
||||||
|
build_path=${current_path}/build/
|
||||||
|
output_path=${current_path}/output/
|
||||||
|
cd ${work_path}
|
||||||
|
|
||||||
|
# Clean
|
||||||
|
rm -rf ${build_path}
|
||||||
|
rm -rf ${output_path}
|
||||||
|
|
||||||
|
if [ "$1"x = "clean"x ]; then
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Build Every Target
|
||||||
|
for target in "${build_targets[@]}"
|
||||||
|
do
|
||||||
|
for build_type in "${build_type_array[@]}"
|
||||||
|
do
|
||||||
|
echo -e "\033[1;36;40mBuilding ${build_type} ${target} ... \033[0m"
|
||||||
|
target_build_path=${build_path}/${target}/${build_type}/
|
||||||
|
mkdir -p ${target_build_path}
|
||||||
|
|
||||||
|
cd ${target_build_path}
|
||||||
|
if [ $? -ne 0 ];then
|
||||||
|
echo -e "\033[1;31;40mcd ${target_build_path} failed \033[0m"
|
||||||
|
exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${target} == "OS64" ];then
|
||||||
|
fastdeploy_install_dir=${fastdeploy_dir}/arm64
|
||||||
|
else
|
||||||
|
fastdeploy_install_dir=""
|
||||||
|
echo "fastdeploy_install_dir is null"
|
||||||
|
exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
cmake -DCMAKE_TOOLCHAIN_FILE=${ios_toolchain_cmake} \
|
||||||
|
-DBUILD_IN_MACOS=ON \
|
||||||
|
-DBUILD_SHARED_LIBS=OFF \
|
||||||
|
-DWITH_ASR=OFF \
|
||||||
|
-DWITH_CLS=OFF \
|
||||||
|
-DWITH_VAD=ON \
|
||||||
|
-DFASTDEPLOY_INSTALL_DIR=${fastdeploy_install_dir} \
|
||||||
|
-DPLATFORM=${target} ../../../
|
||||||
|
|
||||||
|
cmake --build . --config ${build_type}
|
||||||
|
|
||||||
|
mkdir output
|
||||||
|
cp engine/vad/interface/libpps_vad_interface.a output
|
||||||
|
cp engine/vad/interface/vad_interface_main.app/vad_interface_main output
|
||||||
|
cp ${fastdeploy_install_dir}/lib/libfastdeploy.dylib output
|
||||||
|
cp ${fastdeploy_install_dir}/third_libs/install/onnxruntime/lib/libonnxruntime.dylib output
|
||||||
|
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
## combine all ios libraries
|
||||||
|
#DEVROOT=/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/
|
||||||
|
#LIPO_TOOL=${DEVROOT}/usr/bin/lipo
|
||||||
|
#LIBRARY_PATH=${build_path}
|
||||||
|
#LIBRARY_OUTPUT_PATH=${output_path}/IOS
|
||||||
|
#mkdir -p ${LIBRARY_OUTPUT_PATH}
|
||||||
|
#
|
||||||
|
#${LIPO_TOOL} \
|
||||||
|
# -arch i386 ${LIBRARY_PATH}/ios_x86/Release/${lib_name}.a \
|
||||||
|
# -arch x86_64 ${LIBRARY_PATH}/ios_x86_64/Release/${lib_name}.a \
|
||||||
|
# -arch armv7 ${LIBRARY_PATH}/ios_armv7/Release/${lib_name}.a \
|
||||||
|
# -arch armv7s ${LIBRARY_PATH}/ios_armv7s/Release/${lib_name}.a \
|
||||||
|
# -arch arm64 ${LIBRARY_PATH}/ios_armv8/Release/${lib_name}.a \
|
||||||
|
# -output ${LIBRARY_OUTPUT_PATH}/${lib_name}.a -create
|
||||||
|
#
|
||||||
|
#cp ${work_path}/lib/houyi/lib/ios/libhouyi_score.a ${LIBRARY_OUTPUT_PATH}/
|
||||||
|
#cp ${work_path}/interface/ocr-interface.h ${output_path}
|
||||||
|
#cp ${work_path}/version/release.v ${output_path}
|
||||||
|
#
|
||||||
|
#echo -e "\033[1;36;40mBuild All Target Success At:\n${output_path}\033[0m"
|
||||||
|
#exit 0
|
@ -0,0 +1 @@
|
|||||||
|
cmake_policy(SET CMP0077 NEW)
|
@ -0,0 +1,116 @@
|
|||||||
|
include(FetchContent)
|
||||||
|
|
||||||
|
set(EXTERNAL_PROJECT_LOG_ARGS
|
||||||
|
LOG_DOWNLOAD 1 # Wrap download in script to log output
|
||||||
|
LOG_UPDATE 1 # Wrap update in script to log output
|
||||||
|
LOG_PATCH 1
|
||||||
|
LOG_CONFIGURE 1# Wrap configure in script to log output
|
||||||
|
LOG_BUILD 1 # Wrap build in script to log output
|
||||||
|
LOG_INSTALL 1
|
||||||
|
LOG_TEST 1 # Wrap test in script to log output
|
||||||
|
LOG_MERGED_STDOUTERR 1
|
||||||
|
LOG_OUTPUT_ON_FAILURE 1
|
||||||
|
)
|
||||||
|
|
||||||
|
if(NOT FASTDEPLOY_INSTALL_DIR)
|
||||||
|
if(ANDROID)
|
||||||
|
FetchContent_Declare(
|
||||||
|
fastdeploy
|
||||||
|
URL https://bj.bcebos.com/fastdeploy/release/android/fastdeploy-android-1.0.4-shared.tgz
|
||||||
|
URL_HASH MD5=2a15301158e9eb157a4f11283689e7ba
|
||||||
|
${EXTERNAL_PROJECT_LOG_ARGS}
|
||||||
|
)
|
||||||
|
add_definitions("-DUSE_PADDLE_LITE_BAKEND")
|
||||||
|
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE")
|
||||||
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -g0 -O3 -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE")
|
||||||
|
else() # Linux
|
||||||
|
FetchContent_Declare(
|
||||||
|
fastdeploy
|
||||||
|
URL https://paddlespeech.bj.bcebos.com/speechx/fastdeploy/fastdeploy-1.0.5-x86_64-onnx.tar.gz
|
||||||
|
URL_HASH MD5=33900d986ea71aa78635e52f0733227c
|
||||||
|
${EXTERNAL_PROJECT_LOG_ARGS}
|
||||||
|
)
|
||||||
|
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -msse -msse2")
|
||||||
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -msse -msse2 -mavx -O3")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
FetchContent_MakeAvailable(fastdeploy)
|
||||||
|
|
||||||
|
set(FASTDEPLOY_INSTALL_DIR ${fc_patch}/fastdeploy-src)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
|
||||||
|
|
||||||
|
# Fix the compiler-flag conflict: FastDeploy.cmake sets the project to C++11.
|
||||||
|
# This line must come after `include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)`.
|
||||||
|
set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD})
|
||||||
|
|
||||||
|
include_directories(${FASTDEPLOY_INCS})
|
||||||
|
|
||||||
|
# install fastdeploy and dependents lib
|
||||||
|
# install_fastdeploy_libraries(${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
|
||||||
|
# No dynamic libs need to install while using
|
||||||
|
# FastDeploy static lib.
|
||||||
|
if(ANDROID AND WITH_ANDROID_STATIC_LIB)
|
||||||
|
return()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(DYN_LIB_SUFFIX "*.so*")
|
||||||
|
if(WIN32)
|
||||||
|
set(DYN_LIB_SUFFIX "*.dll")
|
||||||
|
elseif(APPLE)
|
||||||
|
set(DYN_LIB_SUFFIX "*.dylib*")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(FastDeploy_DIR)
|
||||||
|
set(DYN_SEARCH_DIR ${FastDeploy_DIR})
|
||||||
|
elseif(FASTDEPLOY_INSTALL_DIR)
|
||||||
|
set(DYN_SEARCH_DIR ${FASTDEPLOY_INSTALL_DIR})
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "Please set FastDeploy_DIR/FASTDEPLOY_INSTALL_DIR before call install_fastdeploy_libraries.")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
file(GLOB_RECURSE ALL_NEED_DYN_LIBS ${DYN_SEARCH_DIR}/lib/${DYN_LIB_SUFFIX})
|
||||||
|
file(GLOB_RECURSE ALL_DEPS_DYN_LIBS ${DYN_SEARCH_DIR}/third_libs/${DYN_LIB_SUFFIX})
|
||||||
|
|
||||||
|
if(ENABLE_VISION)
|
||||||
|
# OpenCV
|
||||||
|
if(ANDROID)
|
||||||
|
file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${DYN_LIB_SUFFIX})
|
||||||
|
else()
|
||||||
|
file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_DIR}/../../${DYN_LIB_SUFFIX})
|
||||||
|
endif()
|
||||||
|
|
||||||
|
list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_OPENCV_DYN_LIBS})
|
||||||
|
|
||||||
|
if(WIN32)
|
||||||
|
file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/x64/vc15/bin/${DYN_LIB_SUFFIX})
|
||||||
|
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
|
||||||
|
elseif(ANDROID AND (NOT WITH_ANDROID_OPENCV_STATIC))
|
||||||
|
file(GLOB OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${ANDROID_ABI}/${DYN_LIB_SUFFIX})
|
||||||
|
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
|
||||||
|
else() # linux/mac
|
||||||
|
file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/lib/${DYN_LIB_SUFFIX})
|
||||||
|
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# FlyCV
|
||||||
|
if(ENABLE_FLYCV)
|
||||||
|
file(GLOB_RECURSE ALL_FLYCV_DYN_LIBS ${FLYCV_LIB_DIR}/${DYN_LIB_SUFFIX})
|
||||||
|
list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_FLYCV_DYN_LIBS})
|
||||||
|
if(ANDROID AND (NOT WITH_ANDROID_FLYCV_STATIC))
|
||||||
|
install(FILES ${ALL_FLYCV_DYN_LIBS} DESTINATION lib)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(ENABLE_OPENVINO_BACKEND)
|
||||||
|
# need plugins.xml for openvino backend
|
||||||
|
set(OPENVINO_RUNTIME_BIN_DIR ${OPENVINO_DIR}/bin)
|
||||||
|
file(GLOB OPENVINO_PLUGIN_XML ${OPENVINO_RUNTIME_BIN_DIR}/*.xml)
|
||||||
|
install(FILES ${OPENVINO_PLUGIN_XML} DESTINATION lib)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Install other libraries
|
||||||
|
install(FILES ${ALL_NEED_DYN_LIBS} DESTINATION lib)
|
||||||
|
install(FILES ${ALL_DEPS_DYN_LIBS} DESTINATION lib)
|
@ -0,0 +1,14 @@
|
|||||||
|
include(FetchContent)
|
||||||
|
|
||||||
|
FetchContent_Declare(
|
||||||
|
gflags
|
||||||
|
URL https://paddleaudio.bj.bcebos.com/build/gflag-2.2.2.zip
|
||||||
|
URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5
|
||||||
|
)
|
||||||
|
FetchContent_MakeAvailable(gflags)
|
||||||
|
|
||||||
|
# openfst need
|
||||||
|
include_directories(${gflags_BINARY_DIR}/include)
|
||||||
|
link_directories(${gflags_BINARY_DIR})
|
||||||
|
|
||||||
|
#install(FILES ${gflags_BINARY_DIR}/libgflags_nothreads.a DESTINATION lib)
|
@ -0,0 +1,35 @@
|
|||||||
|
include(FetchContent)
|
||||||
|
|
||||||
|
if(ANDROID)
|
||||||
|
else() # UNIX
|
||||||
|
add_definitions(-DWITH_GLOG)
|
||||||
|
# Fetch the glog 0.4.0 sources.
# FetchContent only performs the download/update/patch steps of
# ExternalProject, so configure-step options such as CMAKE_ARGS are ignored
# here. The former `CMAKE_ARGS -D...` list (including -DWITH_GFLAGS=OFF) never
# reached glog and has been removed; the one option that mattered is applied
# as a variable below.
FetchContent_Declare(
    glog
    URL      https://paddleaudio.bj.bcebos.com/build/glog-0.4.0.zip
    URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
)

# Options for the glog sub-build are passed as normal variables, which
# option() in the sub-project honours when policy CMP0077 is NEW
# (NOTE(review): assumes the top-level build enables CMP0077 for sub-projects).
set(BUILD_TESTING OFF)
set(WITH_GFLAGS OFF)    # build glog without gflags support
FetchContent_MakeAvailable(glog)
# Do not leak the glog-specific switch into sub-projects configured later.
unset(WITH_GFLAGS)
|
||||||
|
include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
if(ANDROID)
|
||||||
|
add_library(extern_glog INTERFACE)
|
||||||
|
add_dependencies(extern_glog gflags)
|
||||||
|
else() # UNIX
|
||||||
|
add_library(extern_glog ALIAS glog)
|
||||||
|
add_dependencies(glog gflags)
|
||||||
|
endif()
|
@ -0,0 +1,27 @@
|
|||||||
|
|
||||||
|
include(FetchContent)
|
||||||
|
|
||||||
|
if(ANDROID)
|
||||||
|
else() # UNIX
|
||||||
|
FetchContent_Declare(
|
||||||
|
gtest
|
||||||
|
URL https://paddleaudio.bj.bcebos.com/build/gtest-release-1.11.0.zip
|
||||||
|
URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a
|
||||||
|
)
|
||||||
|
FetchContent_MakeAvailable(gtest)
|
||||||
|
|
||||||
|
include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if(ANDROID)
|
||||||
|
add_library(extern_gtest INTERFACE)
|
||||||
|
else() # UNIX
|
||||||
|
add_dependencies(gtest gflags extern_glog)
|
||||||
|
add_library(extern_gtest ALIAS gtest)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(WITH_TESTING)
|
||||||
|
enable_testing()
|
||||||
|
endif()
|
@ -0,0 +1,42 @@
|
|||||||
|
#the pybind11 is from:https://github.com/pybind/pybind11
|
||||||
|
# Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>, All rights reserved.
|
||||||
|
|
||||||
|
SET(PYBIND_ZIP "v2.10.0.zip")
|
||||||
|
SET(LOCAL_PYBIND_ZIP ${FETCHCONTENT_BASE_DIR}/${PYBIND_ZIP})
|
||||||
|
SET(PYBIND_SRC ${FETCHCONTENT_BASE_DIR}/pybind11)
|
||||||
|
SET(DOWNLOAD_URL "https://paddleaudio.bj.bcebos.com/build/v2.10.0.zip")
|
||||||
|
SET(PYBIND_TIMEOUT 600 CACHE STRING "Timeout in seconds when downloading pybind.")
|
||||||
|
|
||||||
|
# Download the pybind11 archive once; an existing local copy is reused.
if(NOT EXISTS "${LOCAL_PYBIND_ZIP}")
    file(DOWNLOAD ${DOWNLOAD_URL}
        "${LOCAL_PYBIND_ZIP}"
        TIMEOUT ${PYBIND_TIMEOUT}
        STATUS ERR
        SHOW_PROGRESS
    )

    # STATUS is a two-element list "<numeric code>;<message>". Read the code
    # explicitly rather than comparing the whole list as a number.
    list(GET ERR 0 pybind_download_code)
    list(GET ERR 1 pybind_download_message)

    if(pybind_download_code EQUAL 0)
        message(STATUS "download pybind success")
    else()
        # A failed transfer can leave an empty or partial file behind, which
        # would make the EXISTS guard skip the download on the next configure.
        file(REMOVE "${LOCAL_PYBIND_ZIP}")
        message(FATAL_ERROR "download pybind fail: ${pybind_download_message}")
    endif()
endif()
|
||||||
|
|
||||||
|
# Unpack the pybind11 archive into FETCHCONTENT_BASE_DIR and rename the
# versioned directory to the stable ${PYBIND_SRC} path. Skipped when
# ${PYBIND_SRC} already exists.
if(NOT EXISTS "${PYBIND_SRC}")
    execute_process(
        COMMAND ${CMAKE_COMMAND} -E tar xfz "${LOCAL_PYBIND_ZIP}"
        WORKING_DIRECTORY "${FETCHCONTENT_BASE_DIR}"
        RESULT_VARIABLE tar_result
    )

    # Check the exit status before touching the extracted tree. EQUAL is a
    # numeric comparison; the previous `MATCHES 0` was a regex test that also
    # accepted any status containing the digit 0 (e.g. "10").
    if(NOT "${tar_result}" EQUAL 0)
        message(FATAL_ERROR "unzip pybind fail")
    endif()

    file(RENAME "${FETCHCONTENT_BASE_DIR}/pybind11-2.10.0" "${PYBIND_SRC}")
    message(STATUS "unzip pybind success")
endif()
|
||||||
|
|
||||||
|
include_directories(${PYBIND_SRC}/include)
|
@ -0,0 +1,64 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# pps_summary()
#
# Prints a configure-time summary of the build: toolchain, CMake paths,
# feature switches and FastDeploy locations, plus Android / Python details
# when relevant. Takes no arguments and only reads variables visible in the
# calling scope; it sets nothing for the caller.
function(pps_summary)
    message(STATUS "")
    message(STATUS "*************PaddleSpeech Building Summary**********")
    message(STATUS " PPS_VERSION : ${PPS_VERSION}")
    message(STATUS " CMake version : ${CMAKE_VERSION}")
    message(STATUS " CMake command : ${CMAKE_COMMAND}")
    message(STATUS " UNIX : ${UNIX}")
    message(STATUS " ANDROID : ${ANDROID}")
    message(STATUS " System : ${CMAKE_SYSTEM_NAME}")
    message(STATUS " C++ compiler : ${CMAKE_CXX_COMPILER}")
    message(STATUS " C++ compiler version : ${CMAKE_CXX_COMPILER_VERSION}")
    message(STATUS " CXX flags : ${CMAKE_CXX_FLAGS}")
    message(STATUS " Build type : ${CMAKE_BUILD_TYPE}")
    message(STATUS " BUILD_SHARED_LIBS : ${BUILD_SHARED_LIBS}")
    # Directory-level compile definitions accumulated via add_definitions().
    get_directory_property(tmp DIRECTORY ${PROJECT_SOURCE_DIR} COMPILE_DEFINITIONS)
    message(STATUS " Compile definitions : ${tmp}")
    message(STATUS " CMAKE_PREFIX_PATH : ${CMAKE_PREFIX_PATH}")
    message(STATUS " CMAKE_CURRENT_BINARY_DIR : ${CMAKE_CURRENT_BINARY_DIR}")
    message(STATUS " CMAKE_INSTALL_PREFIX : ${CMAKE_INSTALL_PREFIX}")
    message(STATUS " CMAKE_INSTALL_LIBDIR : ${CMAKE_INSTALL_LIBDIR}")
    message(STATUS " CMAKE_MODULE_PATH : ${CMAKE_MODULE_PATH}")
    message(STATUS " CMAKE_SYSTEM_NAME : ${CMAKE_SYSTEM_NAME}")
    message(STATUS "")

    # Feature switches.
    message(STATUS " WITH_ASR : ${WITH_ASR}")
    message(STATUS " WITH_CLS : ${WITH_CLS}")
    message(STATUS " WITH_VAD : ${WITH_VAD}")
    message(STATUS " WITH_GPU : ${WITH_GPU}")
    message(STATUS " WITH_TESTING : ${WITH_TESTING}")
    message(STATUS " WITH_PROFILING : ${WITH_PROFILING}")
    message(STATUS " WITH_ONNX : ${WITH_ONNX}")
    message(STATUS " WITH_PPS_DEBUG : ${WITH_PPS_DEBUG}")
    message(STATUS " FASTDEPLOY_INSTALL_DIR : ${FASTDEPLOY_INSTALL_DIR}")
    message(STATUS " FASTDEPLOY_INCS : ${FASTDEPLOY_INCS}")
    message(STATUS " FASTDEPLOY_LIBS : ${FASTDEPLOY_LIBS}")
    if(WITH_GPU)
        message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}")
    endif()

    if(ANDROID)
        message(STATUS " ANDROID_ABI : ${ANDROID_ABI}")
        message(STATUS " ANDROID_PLATFORM : ${ANDROID_PLATFORM}")
        message(STATUS " ANDROID_NDK : ${ANDROID_NDK}")
        message(STATUS " ANDROID_NDK_VERSION : ${CMAKE_ANDROID_NDK_VERSION}")
    endif()

    if(WITH_ASR)
        # PYTHON_EXECUTABLE is the legacy variable name; FindPython3 reports
        # the interpreter as Python3_EXECUTABLE. Fall back to the latter so
        # the line is not blank when only the modern variable is set.
        if(PYTHON_EXECUTABLE)
            set(pps_python_executable "${PYTHON_EXECUTABLE}")
        else()
            set(pps_python_executable "${Python3_EXECUTABLE}")
        endif()
        message(STATUS " Python executable : ${pps_python_executable}")
        message(STATUS " Python includes : ${PYTHON_INCLUDE_DIR}")
    endif()
endfunction()
|
||||||
|
|
||||||
|
pps_summary()
|
@ -0,0 +1,106 @@
|
|||||||
|
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Detects the OS and sets appropriate variables.
|
||||||
|
# CMAKE_SYSTEM_NAME only give us a coarse-grained name of the OS CMake is
|
||||||
|
# building for, but the host processor name like centos is necessary
|
||||||
|
# in some scenes to distinguish system for customization.
|
||||||
|
#
|
||||||
|
# for instance, protobuf libs path is <install_dir>/lib64
|
||||||
|
# on CentOS, but <install_dir>/lib on other systems.
|
||||||
|
|
||||||
|
if(UNIX AND NOT APPLE)
|
||||||
|
# exclude Apple from the *nix OS family
|
||||||
|
set(LINUX TRUE)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(WIN32)
|
||||||
|
set(HOST_SYSTEM "win32")
|
||||||
|
else()
|
||||||
|
if(APPLE)
|
||||||
|
set(HOST_SYSTEM "macosx")
|
||||||
|
# Ask macOS for its product version (e.g. "13.4.1").
# execute_process() replaces the long-deprecated exec_program(); the trailing
# newline is stripped so the value prints cleanly.
execute_process(
    COMMAND sw_vers -productVersion
    OUTPUT_VARIABLE HOST_SYSTEM_VERSION
    OUTPUT_STRIP_TRAILING_WHITESPACE)

# Keep only "<major>.<minor>". The dot is escaped so it matches a literal '.'
# instead of any character.
string(REGEX MATCH "[0-9]+\\.[0-9]+" MACOS_VERSION "${HOST_SYSTEM_VERSION}")

# Only seed the deployment target when the environment does not already
# provide one. `DEFINED ENV{...}` tests the environment variable itself; the
# previous `DEFINED $ENV{...}` expanded the value first and then tested whether
# a CMake variable with that *name* existed.
if(NOT DEFINED ENV{MACOSX_DEPLOYMENT_TARGET})
    # Set cache variable - end user may change this during ccmake or cmake-gui configure.
    set(CMAKE_OSX_DEPLOYMENT_TARGET
        ${MACOS_VERSION}
        CACHE
        STRING
        "Minimum OS X version to target for deployment (at runtime); newer APIs weak linked. Set to empty string for default value."
    )
endif()
|
||||||
|
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
|
||||||
|
else()
|
||||||
|
|
||||||
|
if(EXISTS "/etc/issue")
|
||||||
|
file(READ "/etc/issue" LINUX_ISSUE)
|
||||||
|
if(LINUX_ISSUE MATCHES "CentOS")
|
||||||
|
set(HOST_SYSTEM "centos")
|
||||||
|
elseif(LINUX_ISSUE MATCHES "Debian")
|
||||||
|
set(HOST_SYSTEM "debian")
|
||||||
|
elseif(LINUX_ISSUE MATCHES "Ubuntu")
|
||||||
|
set(HOST_SYSTEM "ubuntu")
|
||||||
|
elseif(LINUX_ISSUE MATCHES "Red Hat")
|
||||||
|
set(HOST_SYSTEM "redhat")
|
||||||
|
elseif(LINUX_ISSUE MATCHES "Fedora")
|
||||||
|
set(HOST_SYSTEM "fedora")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
string(REGEX MATCH "(([0-9]+)\\.)+([0-9]+)" HOST_SYSTEM_VERSION
|
||||||
|
"${LINUX_ISSUE}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(EXISTS "/etc/redhat-release")
|
||||||
|
file(READ "/etc/redhat-release" LINUX_ISSUE)
|
||||||
|
if(LINUX_ISSUE MATCHES "CentOS")
|
||||||
|
set(HOST_SYSTEM "centos")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(NOT HOST_SYSTEM)
|
||||||
|
set(HOST_SYSTEM ${CMAKE_SYSTEM_NAME})
|
||||||
|
endif()
|
||||||
|
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# query number of logical cores
|
||||||
|
cmake_host_system_information(RESULT CPU_CORES QUERY NUMBER_OF_LOGICAL_CORES)
|
||||||
|
|
||||||
|
mark_as_advanced(HOST_SYSTEM CPU_CORES)
|
||||||
|
|
||||||
|
message(
|
||||||
|
STATUS
|
||||||
|
"Found Paddle host system: ${HOST_SYSTEM}, version: ${HOST_SYSTEM_VERSION}")
|
||||||
|
message(STATUS "Found Paddle host system's CPU: ${CPU_CORES} cores")
|
||||||
|
|
||||||
|
# external dependencies log output
|
||||||
|
set(EXTERNAL_PROJECT_LOG_ARGS
|
||||||
|
LOG_DOWNLOAD
|
||||||
|
0 # Wrap download in script to log output
|
||||||
|
LOG_UPDATE
|
||||||
|
1 # Wrap update in script to log output
|
||||||
|
LOG_CONFIGURE
|
||||||
|
1 # Wrap configure in script to log output
|
||||||
|
LOG_BUILD
|
||||||
|
0 # Wrap build in script to log output
|
||||||
|
LOG_TEST
|
||||||
|
1 # Wrap test in script to log output
|
||||||
|
LOG_INSTALL
|
||||||
|
0 # Wrap install in script to log output
|
||||||
|
)
|
@ -0,0 +1,22 @@
|
|||||||
|
project(speechx LANGUAGES CXX)

# Header search roots visible to every subdirectory added below.
include_directories(
  ${CMAKE_CURRENT_SOURCE_DIR}
  ${CMAKE_CURRENT_SOURCE_DIR}/kaldi
  ${CMAKE_CURRENT_SOURCE_DIR}/common
)

# Components that are always built.
add_subdirectory(kaldi)
add_subdirectory(common)

# Components selected by the WITH_* switches.
if(WITH_ASR)
  add_subdirectory(asr)
endif()

if(WITH_CLS)
  add_subdirectory(audio_classification)
endif()

if(WITH_VAD)
  add_subdirectory(vad)
endif()

add_subdirectory(codelab)
|
@ -0,0 +1,11 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

project(ASR LANGUAGES CXX)

# Make this directory and its server/ subdirectory available as include roots.
include_directories(
  ${CMAKE_CURRENT_SOURCE_DIR}
  ${CMAKE_CURRENT_SOURCE_DIR}/server
)

# ASR components, added in the original order.
add_subdirectory(decoder)
add_subdirectory(recognizer)
add_subdirectory(nnet)
add_subdirectory(server)
|
@ -0,0 +1,24 @@
|
|||||||
|
# Sources of the CTC decoder library.
set(srcs
  ctc_prefix_beam_search_decoder.cc
  ctc_tlg_decoder.cc
)

add_library(decoder STATIC ${srcs})
target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder)

# Stand-alone test drivers: one executable per <name>.cc in this directory.
set(TEST_BINS
  ctc_prefix_beam_search_decoder_main
  ctc_tlg_decoder_main
)

foreach(bin_name IN LISTS TEST_BINS)
  add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)

  target_include_directories(${bin_name} PRIVATE
    ${SPEECHX_ROOT}
    ${SPEECHX_ROOT}/kaldi
    ${pybind11_INCLUDE_DIRS}
    ${PROJECT_SOURCE_DIR}
  )

  target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})

  # Use the keyword signature (an executable has no consumers, so PRIVATE is
  # the right scope) and CMAKE_DL_LIBS instead of a hard-coded -ldl.
  target_link_libraries(${bin_name} PRIVATE
    nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util
    ${PYTHON_LIBRARIES}
    ${PADDLE_LINK_FLAGS}
    ${CMAKE_DL_LIBS}
  )
endforeach()
|
||||||
|
|
@ -0,0 +1,21 @@
|
|||||||
|
# Sources that always go into the nnet library.
set(srcs
  decodable.cc
  nnet_producer.cc
  u2_nnet.cc
)

# The ONNX backend is compiled in only on request.
if(WITH_ONNX)
  list(APPEND srcs u2_onnx_nnet.cc)
endif()

add_library(nnet STATIC ${srcs})
target_link_libraries(nnet utils)

if(WITH_ONNX)
  target_link_libraries(nnet ${FASTDEPLOY_LIBS})
endif()

target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
target_include_directories(nnet PUBLIC
  ${pybind11_INCLUDE_DIRS}
  ${PROJECT_SOURCE_DIR}
)

# test bin (currently disabled)
#set(bin_name u2_nnet_main)
#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
|
@ -0,0 +1,99 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "nnet/nnet_producer.h"
|
||||||
|
|
||||||
|
#include "matrix/kaldi-matrix.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
using kaldi::BaseFloat;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
// Stores the network, the frontend and the blank-skipping threshold, then
// calls Reset() so the producer starts from a clean state.
NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
                           std::shared_ptr<FrontendInterface> frontend,
                           float blank_threshold)
    // Initializers follow the member declaration order (frontend_ is declared
    // before nnet_), which is the order the members are really constructed in.
    : frontend_(frontend), nnet_(nnet), blank_threshold_(blank_threshold) {
    Reset();
}
|
||||||
|
|
||||||
|
// Hands a chunk of input (features or waveform samples) to the frontend.
void NnetProducer::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
    FrontendInterface& sink = *frontend_;
    sink.Accept(inputs);
}
|
||||||
|
|
||||||
|
void NnetProducer::Acceptlikelihood(
|
||||||
|
const kaldi::Matrix<BaseFloat>& likelihood) {
|
||||||
|
std::vector<BaseFloat> prob;
|
||||||
|
prob.resize(likelihood.NumCols());
|
||||||
|
for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) {
|
||||||
|
for (size_t col = 0; col < likelihood.NumCols(); ++col) {
|
||||||
|
prob[col] = likelihood(idx, col);
|
||||||
|
}
|
||||||
|
cache_.push_back(prob);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pops the next cached frame into *nnet_prob and returns what the cache's
// pop() reports.
bool NnetProducer::Read(std::vector<kaldi::BaseFloat>* nnet_prob) {
    return cache_.pop(nnet_prob);
}
|
||||||
|
|
||||||
|
bool NnetProducer::Compute() {
|
||||||
|
vector<BaseFloat> features;
|
||||||
|
if (frontend_ == NULL || frontend_->Read(&features) == false) {
|
||||||
|
// no feat or frontend_ not init.
|
||||||
|
if (frontend_->IsFinished() == true) {
|
||||||
|
finished_ = true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
CHECK_GE(frontend_->Dim(), 0);
|
||||||
|
VLOG(1) << "Forward in " << features.size() / frontend_->Dim() << " feats.";
|
||||||
|
|
||||||
|
NnetOut out;
|
||||||
|
nnet_->FeedForward(features, frontend_->Dim(), &out);
|
||||||
|
int32& vocab_dim = out.vocab_dim;
|
||||||
|
size_t nframes = out.logprobs.size() / vocab_dim;
|
||||||
|
VLOG(1) << "Forward out " << nframes << " decoder frames.";
|
||||||
|
for (size_t idx = 0; idx < nframes; ++idx) {
|
||||||
|
std::vector<BaseFloat> logprob(
|
||||||
|
out.logprobs.data() + idx * vocab_dim,
|
||||||
|
out.logprobs.data() + (idx + 1) * vocab_dim);
|
||||||
|
// process blank prob
|
||||||
|
float blank_prob = std::exp(logprob[0]);
|
||||||
|
if (blank_prob > blank_threshold_) {
|
||||||
|
last_frame_logprob_ = logprob;
|
||||||
|
is_last_frame_skip_ = true;
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
int cur_max = std::max(logprob.begin(), logprob.end()) - logprob.begin();
|
||||||
|
if (cur_max == last_max_elem_ && cur_max != 0 && is_last_frame_skip_) {
|
||||||
|
cache_.push_back(last_frame_logprob_);
|
||||||
|
last_max_elem_ = cur_max;
|
||||||
|
}
|
||||||
|
last_max_elem_ = cur_max;
|
||||||
|
is_last_frame_skip_ = false;
|
||||||
|
cache_.push_back(logprob);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void NnetProducer::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
|
||||||
|
float reverse_weight,
|
||||||
|
std::vector<float>* rescoring_score) {
|
||||||
|
nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,77 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "base/common.h"
|
||||||
|
#include "base/safe_queue.h"
|
||||||
|
#include "frontend/frontend_itf.h"
|
||||||
|
#include "nnet/nnet_itf.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
class NnetProducer {
|
||||||
|
public:
|
||||||
|
explicit NnetProducer(std::shared_ptr<NnetBase> nnet,
|
||||||
|
std::shared_ptr<FrontendInterface> frontend,
|
||||||
|
float blank_threshold);
|
||||||
|
// Feed feats or waves
|
||||||
|
void Accept(const std::vector<kaldi::BaseFloat>& inputs);
|
||||||
|
|
||||||
|
void Acceptlikelihood(const kaldi::Matrix<BaseFloat>& likelihood);
|
||||||
|
|
||||||
|
// nnet
|
||||||
|
bool Read(std::vector<kaldi::BaseFloat>* nnet_prob);
|
||||||
|
|
||||||
|
bool Empty() const { return cache_.empty(); }
|
||||||
|
|
||||||
|
void SetInputFinished() {
|
||||||
|
LOG(INFO) << "set finished";
|
||||||
|
frontend_->SetFinished();
|
||||||
|
}
|
||||||
|
|
||||||
|
// the compute thread exit
|
||||||
|
bool IsFinished() const {
|
||||||
|
return (frontend_->IsFinished() && finished_);
|
||||||
|
}
|
||||||
|
|
||||||
|
~NnetProducer() {}
|
||||||
|
|
||||||
|
void Reset() {
|
||||||
|
if (frontend_ != NULL) frontend_->Reset();
|
||||||
|
if (nnet_ != NULL) nnet_->Reset();
|
||||||
|
cache_.clear();
|
||||||
|
finished_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
|
||||||
|
float reverse_weight,
|
||||||
|
std::vector<float>* rescoring_score);
|
||||||
|
|
||||||
|
bool Compute();
|
||||||
|
private:
|
||||||
|
|
||||||
|
std::shared_ptr<FrontendInterface> frontend_;
|
||||||
|
std::shared_ptr<NnetBase> nnet_;
|
||||||
|
SafeQueue<std::vector<kaldi::BaseFloat>> cache_;
|
||||||
|
std::vector<BaseFloat> last_frame_logprob_;
|
||||||
|
bool is_last_frame_skip_ = false;
|
||||||
|
int last_max_elem_ = -1;
|
||||||
|
float blank_threshold_ = 0.0;
|
||||||
|
bool finished_;
|
||||||
|
|
||||||
|
DISALLOW_COPY_AND_ASSIGN(NnetProducer);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,145 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef USE_ONNX
|
||||||
|
#include "nnet/u2_nnet.h"
|
||||||
|
#else
|
||||||
|
#include "nnet/u2_onnx_nnet.h"
|
||||||
|
#endif
|
||||||
|
#include "base/common.h"
|
||||||
|
#include "decoder/param.h"
|
||||||
|
#include "frontend/feature_pipeline.h"
|
||||||
|
#include "frontend/wave-reader.h"
|
||||||
|
#include "kaldi/util/table-types.h"
|
||||||
|
#include "nnet/decodable.h"
|
||||||
|
#include "nnet/nnet_producer.h"
|
||||||
|
#include "nnet/u2_nnet.h"
|
||||||
|
|
||||||
|
// Command-line flags of this test driver.
DEFINE_string(wav_rspecifier, "", "test wav rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet prob wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
|
||||||
|
|
||||||
|
using kaldi::BaseFloat;
|
||||||
|
using kaldi::Matrix;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
gflags::SetUsageMessage("Usage:");
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||||
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
google::InstallFailureSignalHandler();
|
||||||
|
FLAGS_logtostderr = 1;
|
||||||
|
|
||||||
|
int32 num_done = 0, num_err = 0;
|
||||||
|
int sample_rate = FLAGS_sample_rate;
|
||||||
|
float streaming_chunk = FLAGS_streaming_chunk;
|
||||||
|
int chunk_sample_size = streaming_chunk * sample_rate;
|
||||||
|
|
||||||
|
CHECK_GT(FLAGS_wav_rspecifier.size(), 0);
|
||||||
|
CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0);
|
||||||
|
CHECK_GT(FLAGS_model_path.size(), 0);
|
||||||
|
LOG(INFO) << "input rspecifier: " << FLAGS_wav_rspecifier;
|
||||||
|
LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
|
||||||
|
LOG(INFO) << "model path: " << FLAGS_model_path;
|
||||||
|
|
||||||
|
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
|
||||||
|
FLAGS_wav_rspecifier);
|
||||||
|
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
|
||||||
|
|
||||||
|
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
|
||||||
|
ppspeech::FeaturePipelineOptions feature_opts =
|
||||||
|
ppspeech::FeaturePipelineOptions::InitFromFlags();
|
||||||
|
feature_opts.assembler_opts.fill_zero = false;
|
||||||
|
|
||||||
|
#ifndef USE_ONNX
|
||||||
|
std::shared_ptr<ppspeech::U2Nnet> nnet(new ppspeech::U2Nnet(model_opts));
|
||||||
|
#else
|
||||||
|
std::shared_ptr<ppspeech::U2OnnxNnet> nnet(new ppspeech::U2OnnxNnet(model_opts));
|
||||||
|
#endif
|
||||||
|
std::shared_ptr<ppspeech::FeaturePipeline> feature_pipeline(
|
||||||
|
new ppspeech::FeaturePipeline(feature_opts));
|
||||||
|
std::shared_ptr<ppspeech::NnetProducer> nnet_producer(
|
||||||
|
new ppspeech::NnetProducer(nnet, feature_pipeline));
|
||||||
|
kaldi::Timer timer;
|
||||||
|
float tot_wav_duration = 0;
|
||||||
|
|
||||||
|
for (; !wav_reader.Done(); wav_reader.Next()) {
|
||||||
|
std::string utt = wav_reader.Key();
|
||||||
|
const kaldi::WaveData& wave_data = wav_reader.Value();
|
||||||
|
LOG(INFO) << "utt: " << utt;
|
||||||
|
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
|
||||||
|
double dur = wave_data.Duration();
|
||||||
|
tot_wav_duration += dur;
|
||||||
|
|
||||||
|
int32 this_channel = 0;
|
||||||
|
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
|
||||||
|
this_channel);
|
||||||
|
int tot_samples = waveform.Dim();
|
||||||
|
LOG(INFO) << "wav len (sample): " << tot_samples;
|
||||||
|
|
||||||
|
int sample_offset = 0;
|
||||||
|
kaldi::Timer timer;
|
||||||
|
|
||||||
|
while (sample_offset < tot_samples) {
|
||||||
|
int cur_chunk_size =
|
||||||
|
std::min(chunk_sample_size, tot_samples - sample_offset);
|
||||||
|
|
||||||
|
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
|
||||||
|
for (int i = 0; i < cur_chunk_size; ++i) {
|
||||||
|
wav_chunk[i] = waveform(sample_offset + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
nnet_producer->Accept(wav_chunk);
|
||||||
|
if (cur_chunk_size < chunk_sample_size) {
|
||||||
|
nnet_producer->SetInputFinished();
|
||||||
|
}
|
||||||
|
|
||||||
|
// no overlap
|
||||||
|
sample_offset += cur_chunk_size;
|
||||||
|
}
|
||||||
|
CHECK(sample_offset == tot_samples);
|
||||||
|
|
||||||
|
std::vector<std::vector<kaldi::BaseFloat>> prob_vec;
|
||||||
|
while (1) {
|
||||||
|
std::vector<kaldi::BaseFloat> logprobs;
|
||||||
|
bool isok = nnet_producer->Read(&logprobs);
|
||||||
|
if (nnet_producer->IsFinished()) break;
|
||||||
|
if (isok == false) continue;
|
||||||
|
prob_vec.push_back(logprobs);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// writer nnet output
|
||||||
|
kaldi::MatrixIndexT nrow = prob_vec.size();
|
||||||
|
kaldi::MatrixIndexT ncol = prob_vec[0].size();
|
||||||
|
LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol;
|
||||||
|
kaldi::Matrix<kaldi::BaseFloat> nnet_out(nrow, ncol);
|
||||||
|
for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
|
||||||
|
for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
|
||||||
|
nnet_out(row_idx, col_idx) = prob_vec[row_idx][col_idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nnet_out_writer.Write(utt, nnet_out);
|
||||||
|
}
|
||||||
|
nnet_producer->Reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
nnet_producer->Wait();
|
||||||
|
double elapsed = timer.Elapsed();
|
||||||
|
LOG(INFO) << "Program cost:" << elapsed << " sec";
|
||||||
|
|
||||||
|
LOG(INFO) << "Done " << num_done << " utterances, " << num_err
|
||||||
|
<< " with errors.";
|
||||||
|
return (num_done != 0 ? 0 : 1);
|
||||||
|
}
|
@ -0,0 +1,414 @@
|
|||||||
|
// Copyright 2022 Horizon Robotics. All Rights Reserved.
|
||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// modified from
|
||||||
|
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.cc
|
||||||
|
|
||||||
|
#include "nnet/u2_onnx_nnet.h"
|
||||||
|
#include "common/base/config.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
void U2OnnxNnet::LoadModel(const std::string& model_dir) {
|
||||||
|
std::string encoder_onnx_path = model_dir + "/encoder.onnx";
|
||||||
|
std::string rescore_onnx_path = model_dir + "/decoder.onnx";
|
||||||
|
std::string ctc_onnx_path = model_dir + "/ctc.onnx";
|
||||||
|
std::string param_path = model_dir + "/param.onnx";
|
||||||
|
// 1. Load sessions
|
||||||
|
try {
|
||||||
|
encoder_ = std::make_shared<fastdeploy::Runtime>();
|
||||||
|
ctc_ = std::make_shared<fastdeploy::Runtime>();
|
||||||
|
rescore_ = std::make_shared<fastdeploy::Runtime>();
|
||||||
|
fastdeploy::RuntimeOption runtime_option;
|
||||||
|
runtime_option.UseOrtBackend();
|
||||||
|
runtime_option.UseCpu();
|
||||||
|
runtime_option.SetCpuThreadNum(1);
|
||||||
|
runtime_option.SetModelPath(encoder_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
|
||||||
|
assert(encoder_->Init(runtime_option));
|
||||||
|
runtime_option.SetModelPath(rescore_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
|
||||||
|
assert(rescore_->Init(runtime_option));
|
||||||
|
runtime_option.SetModelPath(ctc_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
|
||||||
|
assert(ctc_->Init(runtime_option));
|
||||||
|
} catch (std::exception const& e) {
|
||||||
|
LOG(ERROR) << "error when load onnx model: " << e.what();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
Config conf(param_path);
|
||||||
|
encoder_output_size_ = conf.Read("output_size", encoder_output_size_);
|
||||||
|
num_blocks_ = conf.Read("num_blocks", num_blocks_);
|
||||||
|
head_ = conf.Read("head", head_);
|
||||||
|
cnn_module_kernel_ = conf.Read("cnn_module_kernel", cnn_module_kernel_);
|
||||||
|
subsampling_rate_ = conf.Read("subsampling_rate", subsampling_rate_);
|
||||||
|
right_context_ = conf.Read("right_context", right_context_);
|
||||||
|
sos_= conf.Read("sos_symbol", sos_);
|
||||||
|
eos_= conf.Read("eos_symbol", eos_);
|
||||||
|
is_bidecoder_= conf.Read("is_bidirectional_decoder", is_bidecoder_);
|
||||||
|
chunk_size_= conf.Read("chunk_size", chunk_size_);
|
||||||
|
num_left_chunks_ = conf.Read("left_chunks", num_left_chunks_);
|
||||||
|
|
||||||
|
LOG(INFO) << "Onnx Model Info:";
|
||||||
|
LOG(INFO) << "\tencoder_output_size " << encoder_output_size_;
|
||||||
|
LOG(INFO) << "\tnum_blocks " << num_blocks_;
|
||||||
|
LOG(INFO) << "\thead " << head_;
|
||||||
|
LOG(INFO) << "\tcnn_module_kernel " << cnn_module_kernel_;
|
||||||
|
LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_;
|
||||||
|
LOG(INFO) << "\tright_context " << right_context_;
|
||||||
|
LOG(INFO) << "\tsos " << sos_;
|
||||||
|
LOG(INFO) << "\teos " << eos_;
|
||||||
|
LOG(INFO) << "\tis bidirectional decoder " << is_bidecoder_;
|
||||||
|
LOG(INFO) << "\tchunk_size " << chunk_size_;
|
||||||
|
LOG(INFO) << "\tnum_left_chunks " << num_left_chunks_;
|
||||||
|
|
||||||
|
// 3. Read model nodes
|
||||||
|
LOG(INFO) << "Onnx Encoder:";
|
||||||
|
GetInputOutputInfo(encoder_, &encoder_in_names_, &encoder_out_names_);
|
||||||
|
LOG(INFO) << "Onnx CTC:";
|
||||||
|
GetInputOutputInfo(ctc_, &ctc_in_names_, &ctc_out_names_);
|
||||||
|
LOG(INFO) << "Onnx Rescore:";
|
||||||
|
GetInputOutputInfo(rescore_, &rescore_in_names_, &rescore_out_names_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keeps a copy of the options and loads the models from opts_.model_path.
U2OnnxNnet::U2OnnxNnet(const ModelOptions& opts) : opts_(opts) {
    const std::string& model_dir = opts_.model_path;
    LoadModel(model_dir);
}
|
||||||
|
|
||||||
|
// shallow copy
|
||||||
|
U2OnnxNnet::U2OnnxNnet(const U2OnnxNnet& other) {
|
||||||
|
// metadatas
|
||||||
|
encoder_output_size_ = other.encoder_output_size_;
|
||||||
|
num_blocks_ = other.num_blocks_;
|
||||||
|
head_ = other.head_;
|
||||||
|
cnn_module_kernel_ = other.cnn_module_kernel_;
|
||||||
|
right_context_ = other.right_context_;
|
||||||
|
subsampling_rate_ = other.subsampling_rate_;
|
||||||
|
sos_ = other.sos_;
|
||||||
|
eos_ = other.eos_;
|
||||||
|
is_bidecoder_ = other.is_bidecoder_;
|
||||||
|
chunk_size_ = other.chunk_size_;
|
||||||
|
num_left_chunks_ = other.num_left_chunks_;
|
||||||
|
offset_ = other.offset_;
|
||||||
|
|
||||||
|
// session
|
||||||
|
encoder_ = other.encoder_;
|
||||||
|
ctc_ = other.ctc_;
|
||||||
|
rescore_ = other.rescore_;
|
||||||
|
|
||||||
|
// node names
|
||||||
|
encoder_in_names_ = other.encoder_in_names_;
|
||||||
|
encoder_out_names_ = other.encoder_out_names_;
|
||||||
|
ctc_in_names_ = other.ctc_in_names_;
|
||||||
|
ctc_out_names_ = other.ctc_out_names_;
|
||||||
|
rescore_in_names_ = other.rescore_in_names_;
|
||||||
|
rescore_out_names_ = other.rescore_out_names_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logs name, dtype and shape of every input and output tensor of `runtime`,
// and stores the tensor names, in order, in *in_names and *out_names.
void U2OnnxNnet::GetInputOutputInfo(
    const std::shared_ptr<fastdeploy::Runtime>& runtime,
    std::vector<std::string>* in_names,
    std::vector<std::string>* out_names) {
    // Shared by both directions; `label` is the log-line prefix.
    const auto describe = [](const std::vector<fastdeploy::TensorInfo>& infos,
                             const char* label,
                             std::vector<std::string>* names) {
        names->resize(infos.size());
        const int count = static_cast<int>(infos.size());
        for (int i = 0; i < count; ++i) {
            fastdeploy::TensorInfo info = infos[i];

            std::stringstream dims;
            for (const auto& extent : info.shape) {
                dims << extent << " ";
            }
            LOG(INFO) << label << i << " : name=" << info.name
                      << " type=" << info.dtype << " dims=" << dims.str();
            (*names)[i] = info.name;
        }
    };

    describe(runtime->GetInputInfos(), "\tInput ", in_names);
    describe(runtime->GetOutputInfos(), "\tOutput ", out_names);
}
|
||||||
|
|
||||||
|
std::shared_ptr<NnetBase> U2OnnxNnet::Clone() const {
|
||||||
|
auto asr_model = std::make_shared<U2OnnxNnet>(*this);
|
||||||
|
// reset inner state for new decoding
|
||||||
|
asr_model->Reset();
|
||||||
|
return asr_model;
|
||||||
|
}
|
||||||
|
|
||||||
|
void U2OnnxNnet::Reset() {
|
||||||
|
offset_ = 0;
|
||||||
|
encoder_outs_.clear();
|
||||||
|
cached_feats_.clear();
|
||||||
|
// Reset att_cache
|
||||||
|
if (num_left_chunks_ > 0) {
|
||||||
|
int required_cache_size = chunk_size_ * num_left_chunks_;
|
||||||
|
offset_ = required_cache_size;
|
||||||
|
att_cache_.resize(num_blocks_ * head_ * required_cache_size *
|
||||||
|
encoder_output_size_ / head_ * 2,
|
||||||
|
0.0);
|
||||||
|
const std::vector<int64_t> att_cache_shape = {num_blocks_, head_, required_cache_size,
|
||||||
|
encoder_output_size_ / head_ * 2};
|
||||||
|
att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data());
|
||||||
|
} else {
|
||||||
|
att_cache_.resize(0, 0.0);
|
||||||
|
const std::vector<int64_t> att_cache_shape = {num_blocks_, head_, 0,
|
||||||
|
encoder_output_size_ / head_ * 2};
|
||||||
|
att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset cnn_cache
|
||||||
|
cnn_cache_.resize(
|
||||||
|
num_blocks_ * encoder_output_size_ * (cnn_module_kernel_ - 1), 0.0);
|
||||||
|
const std::vector<int64_t> cnn_cache_shape = {num_blocks_, 1, encoder_output_size_,
|
||||||
|
cnn_module_kernel_ - 1};
|
||||||
|
cnn_cache_ort_.SetExternalData(cnn_cache_shape, fastdeploy::FDDataType::FP32, cnn_cache_.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
void U2OnnxNnet::FeedForward(const std::vector<BaseFloat>& features,
|
||||||
|
const int32& feature_dim,
|
||||||
|
NnetOut* out) {
|
||||||
|
kaldi::Timer timer;
|
||||||
|
|
||||||
|
std::vector<kaldi::BaseFloat> ctc_probs;
|
||||||
|
ForwardEncoderChunkImpl(
|
||||||
|
features, feature_dim, &out->logprobs, &out->vocab_dim);
|
||||||
|
VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. "
|
||||||
|
<< features.size() / feature_dim << " frames.";
|
||||||
|
}
|
||||||
|
|
||||||
|
void U2OnnxNnet::ForwardEncoderChunkImpl(
|
||||||
|
const std::vector<kaldi::BaseFloat>& chunk_feats,
|
||||||
|
const int32& feat_dim,
|
||||||
|
std::vector<kaldi::BaseFloat>* out_prob,
|
||||||
|
int32* vocab_dim) {
|
||||||
|
|
||||||
|
// 1. Prepare onnx required data, splice cached_feature_ and chunk_feats
|
||||||
|
// chunk
|
||||||
|
int num_frames = chunk_feats.size() / feat_dim;
|
||||||
|
VLOG(3) << "num_frames: " << num_frames;
|
||||||
|
VLOG(3) << "feat_dim: " << feat_dim;
|
||||||
|
const int feature_dim = feat_dim;
|
||||||
|
std::vector<float> feats;
|
||||||
|
feats.insert(feats.end(), chunk_feats.begin(), chunk_feats.end());
|
||||||
|
fastdeploy::FDTensor feats_ort;
|
||||||
|
const std::vector<int64_t> feats_shape = {1, num_frames, feature_dim};
|
||||||
|
feats_ort.SetExternalData(feats_shape, fastdeploy::FDDataType::FP32, feats.data());
|
||||||
|
|
||||||
|
// offset
|
||||||
|
int64_t offset_int64 = static_cast<int64_t>(offset_);
|
||||||
|
fastdeploy::FDTensor offset_ort;
|
||||||
|
offset_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &offset_int64);
|
||||||
|
|
||||||
|
// required_cache_size
|
||||||
|
int64_t required_cache_size = chunk_size_ * num_left_chunks_;
|
||||||
|
fastdeploy::FDTensor required_cache_size_ort("");
|
||||||
|
required_cache_size_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &required_cache_size);
|
||||||
|
|
||||||
|
// att_mask
|
||||||
|
fastdeploy::FDTensor att_mask_ort;
|
||||||
|
std::vector<uint8_t> att_mask(required_cache_size + chunk_size_, 1);
|
||||||
|
if (num_left_chunks_ > 0) {
|
||||||
|
int chunk_idx = offset_ / chunk_size_ - num_left_chunks_;
|
||||||
|
if (chunk_idx < num_left_chunks_) {
|
||||||
|
for (int i = 0; i < (num_left_chunks_ - chunk_idx) * chunk_size_; ++i) {
|
||||||
|
att_mask[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const std::vector<int64_t> att_mask_shape = {1, 1, required_cache_size + chunk_size_};
|
||||||
|
att_mask_ort.SetExternalData(att_mask_shape, fastdeploy::FDDataType::BOOL, reinterpret_cast<bool*>(att_mask.data()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Encoder chunk forward
|
||||||
|
std::vector<fastdeploy::FDTensor> inputs(encoder_in_names_.size());
|
||||||
|
for (int i = 0; i < encoder_in_names_.size(); ++i) {
|
||||||
|
std::string name = encoder_in_names_[i];
|
||||||
|
if (!strcmp(name.data(), "chunk")) {
|
||||||
|
inputs[i] = std::move(feats_ort);
|
||||||
|
inputs[i].name = "chunk";
|
||||||
|
} else if (!strcmp(name.data(), "offset")) {
|
||||||
|
inputs[i] = std::move(offset_ort);
|
||||||
|
inputs[i].name = "offset";
|
||||||
|
} else if (!strcmp(name.data(), "required_cache_size")) {
|
||||||
|
inputs[i] = std::move(required_cache_size_ort);
|
||||||
|
inputs[i].name = "required_cache_size";
|
||||||
|
} else if (!strcmp(name.data(), "att_cache")) {
|
||||||
|
inputs[i] = std::move(att_cache_ort_);
|
||||||
|
inputs[i].name = "att_cache";
|
||||||
|
} else if (!strcmp(name.data(), "cnn_cache")) {
|
||||||
|
inputs[i] = std::move(cnn_cache_ort_);
|
||||||
|
inputs[i].name = "cnn_cache";
|
||||||
|
} else if (!strcmp(name.data(), "att_mask")) {
|
||||||
|
inputs[i] = std::move(att_mask_ort);
|
||||||
|
inputs[i].name = "att_mask";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<fastdeploy::FDTensor> ort_outputs;
|
||||||
|
assert(encoder_->Infer(inputs, &ort_outputs));
|
||||||
|
|
||||||
|
offset_ += static_cast<int>(ort_outputs[0].shape[1]);
|
||||||
|
att_cache_ort_ = std::move(ort_outputs[1]);
|
||||||
|
cnn_cache_ort_ = std::move(ort_outputs[2]);
|
||||||
|
|
||||||
|
std::vector<fastdeploy::FDTensor> ctc_inputs;
|
||||||
|
ctc_inputs.emplace_back(std::move(ort_outputs[0]));
|
||||||
|
// ctc_inputs[0] = std::move(ort_outputs[0]);
|
||||||
|
ctc_inputs[0].name = ctc_in_names_[0];
|
||||||
|
|
||||||
|
std::vector<fastdeploy::FDTensor> ctc_ort_outputs;
|
||||||
|
assert(ctc_->Infer(ctc_inputs, &ctc_ort_outputs));
|
||||||
|
encoder_outs_.emplace_back(std::move(ctc_inputs[0])); // *****
|
||||||
|
|
||||||
|
float* logp_data = reinterpret_cast<float*>(ctc_ort_outputs[0].Data());
|
||||||
|
|
||||||
|
// Copy to output, (B=1,T,D)
|
||||||
|
std::vector<int64_t> ctc_log_probs_shape = ctc_ort_outputs[0].shape;
|
||||||
|
CHECK_EQ(ctc_log_probs_shape.size(), 3);
|
||||||
|
int B = ctc_log_probs_shape[0];
|
||||||
|
CHECK_EQ(B, 1);
|
||||||
|
int T = ctc_log_probs_shape[1];
|
||||||
|
int D = ctc_log_probs_shape[2];
|
||||||
|
*vocab_dim = D;
|
||||||
|
|
||||||
|
out_prob->resize(T * D);
|
||||||
|
std::memcpy(
|
||||||
|
out_prob->data(), logp_data, T * D * sizeof(kaldi::BaseFloat));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sums the scores of a hypothesis from a row-major table `prob` whose rows are
// `decode_out_len` values wide: row j contributes the entry at column hyp[j],
// and the row after the last token contributes the entry at column `eos`.
float U2OnnxNnet::ComputeAttentionScore(const float* prob,
                                        const std::vector<int>& hyp,
                                        int eos,
                                        int decode_out_len) {
    float total = 0.0f;
    const float* row = prob;
    for (int token : hyp) {
        total += row[token];
        row += decode_out_len;
    }
    // `row` now points at the row following the last token.
    total += row[eos];
    return total;
}
|
||||||
|
|
||||||
|
void U2OnnxNnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
|
||||||
|
float reverse_weight,
|
||||||
|
std::vector<float>* rescoring_score) {
|
||||||
|
CHECK(rescoring_score != nullptr);
|
||||||
|
int num_hyps = hyps.size();
|
||||||
|
rescoring_score->resize(num_hyps, 0.0f);
|
||||||
|
|
||||||
|
if (num_hyps == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// No encoder output
|
||||||
|
if (encoder_outs_.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> hyps_lens;
|
||||||
|
int max_hyps_len = 0;
|
||||||
|
for (size_t i = 0; i < num_hyps; ++i) {
|
||||||
|
int length = hyps[i].size() + 1;
|
||||||
|
max_hyps_len = std::max(length, max_hyps_len);
|
||||||
|
hyps_lens.emplace_back(static_cast<int64_t>(length));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> rescore_input;
|
||||||
|
int encoder_len = 0;
|
||||||
|
for (int i = 0; i < encoder_outs_.size(); i++) {
|
||||||
|
float* encoder_outs_data = reinterpret_cast<float*>(encoder_outs_[i].Data());
|
||||||
|
for (int j = 0; j < encoder_outs_[i].Numel(); j++) {
|
||||||
|
rescore_input.emplace_back(encoder_outs_data[j]);
|
||||||
|
}
|
||||||
|
encoder_len += encoder_outs_[i].shape[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> hyps_pad;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < num_hyps; ++i) {
|
||||||
|
const std::vector<int>& hyp = hyps[i];
|
||||||
|
hyps_pad.emplace_back(sos_);
|
||||||
|
size_t j = 0;
|
||||||
|
for (; j < hyp.size(); ++j) {
|
||||||
|
hyps_pad.emplace_back(hyp[j]);
|
||||||
|
}
|
||||||
|
if (j == max_hyps_len - 1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (; j < max_hyps_len - 1; ++j) {
|
||||||
|
hyps_pad.emplace_back(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<int64_t> hyps_pad_shape = {num_hyps, max_hyps_len};
|
||||||
|
const std::vector<int64_t> hyps_lens_shape = {num_hyps};
|
||||||
|
const std::vector<int64_t> decode_input_shape = {1, encoder_len, encoder_output_size_};
|
||||||
|
|
||||||
|
fastdeploy::FDTensor hyps_pad_tensor_;
|
||||||
|
hyps_pad_tensor_.SetExternalData(hyps_pad_shape, fastdeploy::FDDataType::INT64, hyps_pad.data());
|
||||||
|
fastdeploy::FDTensor hyps_lens_tensor_;
|
||||||
|
hyps_lens_tensor_.SetExternalData(hyps_lens_shape, fastdeploy::FDDataType::INT64, hyps_lens.data());
|
||||||
|
fastdeploy::FDTensor decode_input_tensor_;
|
||||||
|
decode_input_tensor_.SetExternalData(decode_input_shape, fastdeploy::FDDataType::FP32, rescore_input.data());
|
||||||
|
|
||||||
|
std::vector<fastdeploy::FDTensor> rescore_inputs(3);
|
||||||
|
|
||||||
|
rescore_inputs[0] = std::move(hyps_pad_tensor_);
|
||||||
|
rescore_inputs[0].name = rescore_in_names_[0];
|
||||||
|
rescore_inputs[1] = std::move(hyps_lens_tensor_);
|
||||||
|
rescore_inputs[1].name = rescore_in_names_[1];
|
||||||
|
rescore_inputs[2] = std::move(decode_input_tensor_);
|
||||||
|
rescore_inputs[2].name = rescore_in_names_[2];
|
||||||
|
|
||||||
|
std::vector<fastdeploy::FDTensor> rescore_outputs;
|
||||||
|
assert(rescore_->Infer(rescore_inputs, &rescore_outputs));
|
||||||
|
|
||||||
|
float* decoder_outs_data = reinterpret_cast<float*>(rescore_outputs[0].Data());
|
||||||
|
float* r_decoder_outs_data = reinterpret_cast<float*>(rescore_outputs[1].Data());
|
||||||
|
|
||||||
|
int decode_out_len = rescore_outputs[0].shape[2];
|
||||||
|
|
||||||
|
for (size_t i = 0; i < num_hyps; ++i) {
|
||||||
|
const std::vector<int>& hyp = hyps[i];
|
||||||
|
float score = 0.0f;
|
||||||
|
// left to right decoder score
|
||||||
|
score = ComputeAttentionScore(
|
||||||
|
decoder_outs_data + max_hyps_len * decode_out_len * i, hyp, eos_,
|
||||||
|
decode_out_len);
|
||||||
|
// Optional: Used for right to left score
|
||||||
|
float r_score = 0.0f;
|
||||||
|
if (is_bidecoder_ && reverse_weight > 0) {
|
||||||
|
std::vector<int> r_hyp(hyp.size());
|
||||||
|
std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
|
||||||
|
// right to left decoder score
|
||||||
|
r_score = ComputeAttentionScore(
|
||||||
|
r_decoder_outs_data + max_hyps_len * decode_out_len * i, r_hyp, eos_,
|
||||||
|
decode_out_len);
|
||||||
|
}
|
||||||
|
// combined left-to-right and right-to-left score
|
||||||
|
(*rescoring_score)[i] =
|
||||||
|
score * (1 - reverse_weight) + r_score * reverse_weight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void U2OnnxNnet::EncoderOuts(
|
||||||
|
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const {
|
||||||
|
}
|
||||||
|
|
||||||
|
}  // namespace ppspeech
|
@ -0,0 +1,97 @@
|
|||||||
|
// Copyright 2022 Horizon Robotics. All Rights Reserved.
|
||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// modified from
|
||||||
|
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.h
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "base/common.h"
|
||||||
|
#include "matrix/kaldi-matrix.h"
|
||||||
|
#include "nnet/nnet_itf.h"
|
||||||
|
#include "nnet/u2_nnet.h"
|
||||||
|
|
||||||
|
#include "fastdeploy/runtime.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
class U2OnnxNnet : public U2NnetBase {
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit U2OnnxNnet(const ModelOptions& opts);
|
||||||
|
U2OnnxNnet(const U2OnnxNnet& other);
|
||||||
|
|
||||||
|
void FeedForward(const std::vector<kaldi::BaseFloat>& features,
|
||||||
|
const int32& feature_dim,
|
||||||
|
NnetOut* out) override;
|
||||||
|
|
||||||
|
void Reset() override;
|
||||||
|
|
||||||
|
bool IsLogProb() override { return true; }
|
||||||
|
|
||||||
|
void Dim();
|
||||||
|
|
||||||
|
void LoadModel(const std::string& model_dir);
|
||||||
|
|
||||||
|
std::shared_ptr<NnetBase> Clone() const override;
|
||||||
|
|
||||||
|
void ForwardEncoderChunkImpl(
|
||||||
|
const std::vector<kaldi::BaseFloat>& chunk_feats,
|
||||||
|
const int32& feat_dim,
|
||||||
|
std::vector<kaldi::BaseFloat>* ctc_probs,
|
||||||
|
int32* vocab_dim) override;
|
||||||
|
|
||||||
|
float ComputeAttentionScore(const float* prob, const std::vector<int>& hyp,
|
||||||
|
int eos, int decode_out_len);
|
||||||
|
|
||||||
|
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
|
||||||
|
float reverse_weight,
|
||||||
|
std::vector<float>* rescoring_score) override;
|
||||||
|
|
||||||
|
void EncoderOuts(
|
||||||
|
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const;
|
||||||
|
|
||||||
|
void GetInputOutputInfo(const std::shared_ptr<fastdeploy::Runtime>& runtime,
|
||||||
|
std::vector<std::string>* in_names,
|
||||||
|
std::vector<std::string>* out_names);
|
||||||
|
private:
|
||||||
|
ModelOptions opts_;
|
||||||
|
|
||||||
|
int encoder_output_size_ = 0;
|
||||||
|
int num_blocks_ = 0;
|
||||||
|
int cnn_module_kernel_ = 0;
|
||||||
|
int head_ = 0;
|
||||||
|
|
||||||
|
// sessions
|
||||||
|
std::shared_ptr<fastdeploy::Runtime> encoder_ = nullptr;
|
||||||
|
std::shared_ptr<fastdeploy::Runtime> rescore_ = nullptr;
|
||||||
|
std::shared_ptr<fastdeploy::Runtime> ctc_ = nullptr;
|
||||||
|
|
||||||
|
|
||||||
|
// node names
|
||||||
|
std::vector<std::string> encoder_in_names_, encoder_out_names_;
|
||||||
|
std::vector<std::string> ctc_in_names_, ctc_out_names_;
|
||||||
|
std::vector<std::string> rescore_in_names_, rescore_out_names_;
|
||||||
|
|
||||||
|
// caches
|
||||||
|
fastdeploy::FDTensor att_cache_ort_;
|
||||||
|
fastdeploy::FDTensor cnn_cache_ort_;
|
||||||
|
std::vector<fastdeploy::FDTensor> encoder_outs_;
|
||||||
|
|
||||||
|
std::vector<float> att_cache_;
|
||||||
|
std::vector<float> cnn_cache_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,26 @@
|
|||||||
|
# Static library with the recognizer implementation. `decoder` is linked
# PUBLIC, so it propagates to everything that links `recognizer`.
set(srcs
    recognizer_controller.cc
    recognizer_controller_impl.cc
    recognizer_instance.cc
    recognizer.cc
)

add_library(recognizer STATIC ${srcs})
target_link_libraries(recognizer PUBLIC decoder)

# Command-line drivers; each one is built from <name>.cc in this directory.
set(TEST_BINS
    recognizer_batch_main
    recognizer_batch_main2
    recognizer_main
)

foreach(bin_name IN LISTS TEST_BINS)
    add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
    target_include_directories(${bin_name} PRIVATE
        ${SPEECHX_ROOT}
        ${SPEECHX_ROOT}/kaldi
        ${pybind11_INCLUDE_DIRS}
        ${PROJECT_SOURCE_DIR}
    )
    target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
    # Use the keyword signature (executables never export link requirements)
    # and the portable ${CMAKE_DL_LIBS} instead of a hard-coded -ldl.
    target_link_libraries(${bin_name} PRIVATE
        recognizer
        nnet
        decoder
        fst
        utils
        gflags
        glog
        kaldi-base
        kaldi-matrix
        kaldi-util
        ${PYTHON_LIBRARIES}
        ${PADDLE_LINK_FLAGS}
        ${CMAKE_DL_LIBS}
    )
endforeach()
|
@ -0,0 +1,46 @@
|
|||||||
|
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "recognizer/recognizer.h"
|
||||||
|
#include "recognizer/recognizer_instance.h"
|
||||||
|
|
||||||
|
bool InitRecognizer(const std::string& model_file,
|
||||||
|
const std::string& word_symbol_table_file,
|
||||||
|
const std::string& fst_file,
|
||||||
|
int num_instance) {
|
||||||
|
return ppspeech::RecognizerInstance::GetInstance().Init(model_file,
|
||||||
|
word_symbol_table_file,
|
||||||
|
fst_file,
|
||||||
|
num_instance);
|
||||||
|
}
|
||||||
|
|
||||||
|
int GetRecognizerInstanceId() {
|
||||||
|
return ppspeech::RecognizerInstance::GetInstance().GetRecognizerInstanceId();
|
||||||
|
}
|
||||||
|
|
||||||
|
void InitDecoder(int instance_id) {
|
||||||
|
return ppspeech::RecognizerInstance::GetInstance().InitDecoder(instance_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AcceptData(const std::vector<float>& waves, int instance_id) {
|
||||||
|
return ppspeech::RecognizerInstance::GetInstance().Accept(waves, instance_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetInputFinished(int instance_id) {
|
||||||
|
return ppspeech::RecognizerInstance::GetInstance().SetInputFinished(instance_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GetFinalResult(int instance_id) {
|
||||||
|
return ppspeech::RecognizerInstance::GetInstance().GetResult(instance_id);
|
||||||
|
}
|
@ -0,0 +1,28 @@
|
|||||||
|
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
// Free-function facade over the ppspeech::RecognizerInstance singleton.
// Each function forwards its arguments unchanged.

// Forwards to RecognizerInstance::Init() and returns its result.
bool InitRecognizer(const std::string& model_file,
                    const std::string& word_symbol_table_file,
                    const std::string& fst_file,
                    int num_instance);

// Forwards to RecognizerInstance::GetRecognizerInstanceId().
// NOTE(review): the batch drivers treat -1 as "none available, retry".
int GetRecognizerInstanceId();

// Forwards to RecognizerInstance::InitDecoder(instance_id).
void InitDecoder(int instance_id);

// Forwards to RecognizerInstance::Accept(waves, instance_id).
void AcceptData(const std::vector<float>& waves, int instance_id);

// Forwards to RecognizerInstance::SetInputFinished(instance_id).
void SetInputFinished(int instance_id);

// Forwards to RecognizerInstance::GetResult(instance_id).
std::string GetFinalResult(int instance_id);
|
@ -0,0 +1,172 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "common/base/thread_pool.h"
|
||||||
|
#include "common/utils/file_utils.h"
|
||||||
|
#include "common/utils/strings.h"
|
||||||
|
#include "decoder/param.h"
|
||||||
|
#include "frontend/wave-reader.h"
|
||||||
|
#include "kaldi/util/table-types.h"
|
||||||
|
#include "nnet/u2_nnet.h"
|
||||||
|
#include "recognizer/recognizer_controller.h"
|
||||||
|
|
||||||
|
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
|
||||||
|
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
||||||
|
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
|
||||||
|
DEFINE_int32(sample_rate, 16000, "sample rate");
|
||||||
|
DEFINE_int32(njob, 3, "njob");
|
||||||
|
|
||||||
|
using std::string;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
// Reads `wavlist_file` line by line and distributes its entries round-robin
// over `njob` buckets.
//
// Every line must split (on space / tab) into exactly two fields: the
// utterance id and the wav path. Entry idx goes to bucket idx % njob; ids are
// appended to (*uttlists)[bucket] and paths to (*wavlists)[bucket].
// Aborts (CHECK) on a malformed line or a non-positive njob.
void SplitUtt(string wavlist_file,
              vector<vector<string>>* uttlists,
              vector<vector<string>>* wavlists,
              int njob) {
    CHECK(uttlists != nullptr);
    CHECK(wavlists != nullptr);
    // njob is used as a modulus below; 0 would divide by zero.
    CHECK_GT(njob, 0) << "njob must be positive";

    vector<string> wavlist;
    wavlists->resize(njob);
    uttlists->resize(njob);
    ppspeech::ReadFileToVector(wavlist_file, &wavlist);
    for (size_t idx = 0; idx < wavlist.size(); ++idx) {
        string utt_str = wavlist[idx];
        vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
        // Validate the field count BEFORE touching utt_wav[0]; a blank line
        // would otherwise be an out-of-bounds read.
        CHECK_EQ(utt_wav.size(), size_t(2));
        LOG(INFO) << utt_wav[0];
        const size_t bucket = idx % static_cast<size_t>(njob);
        uttlists->at(bucket).push_back(utt_wav[0]);
        wavlists->at(bucket).push_back(utt_wav[1]);
    }
}
|
||||||
|
|
||||||
|
// Worker body: decodes every wav in `wavlist` through `recognizer_controller`
// and appends one result string per utterance to `results`.
//
// wavlist / uttlist: parallel lists (wav path, utterance id) of equal length.
// results:           output; an empty recognition result is stored as " ".
// Audio is fed in chunks of FLAGS_streaming_chunk * FLAGS_sample_rate samples
// taken from channel 0.
void recognizer_func(ppspeech::RecognizerController* recognizer_controller,
                     std::vector<string> wavlist,
                     std::vector<string> uttlist,
                     std::vector<string>* results) {
    int32 num_done = 0, num_err = 0;
    double tot_wav_duration = 0.0;
    double tot_decode_time = 0.0;
    const int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
    if (wavlist.empty()) return;

    CHECK(recognizer_controller != nullptr);
    CHECK(results != nullptr);
    CHECK_EQ(wavlist.size(), uttlist.size());
    // A zero-sized chunk would never advance sample_offset below.
    CHECK_GT(chunk_sample_size, 0);

    results->reserve(wavlist.size());
    for (size_t idx = 0; idx < wavlist.size(); ++idx) {
        const std::string& utt = uttlist[idx];
        const std::string& wav_file = wavlist[idx];
        std::ifstream infile(wav_file,
                             std::ifstream::in | std::ifstream::binary);
        CHECK(infile.is_open()) << "failed to open wav file: " << wav_file;
        kaldi::WaveData wave_data;
        wave_data.Read(infile);

        // Spin until the controller hands out a recognizer (-1 means none is
        // available right now).
        int32 recog_id = -1;
        while (recog_id == -1) {
            recog_id = recognizer_controller->GetRecognizerInstanceId();
        }
        recognizer_controller->InitDecoder(recog_id);
        LOG(INFO) << "utt: " << utt;
        LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
        const double dur = wave_data.Duration();
        tot_wav_duration += dur;

        const int32 this_channel = 0;
        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
                                                    this_channel);
        const int tot_samples = waveform.Dim();
        LOG(INFO) << "wav len (sample): " << tot_samples;

        int sample_offset = 0;
        kaldi::Timer local_timer;

        while (sample_offset < tot_samples) {
            const int cur_chunk_size =
                std::min(chunk_sample_size, tot_samples - sample_offset);

            std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
            for (int i = 0; i < cur_chunk_size; ++i) {
                wav_chunk[i] = waveform(sample_offset + i);
            }

            recognizer_controller->Accept(wav_chunk, recog_id);
            // no overlap
            sample_offset += cur_chunk_size;
        }
        recognizer_controller->SetInputFinished(recog_id);
        CHECK_EQ(sample_offset, tot_samples);
        std::string result = recognizer_controller->GetFinalResult(recog_id);
        if (result.empty()) {
            // the TokenWriter can not write empty string.
            ++num_err;
            LOG(INFO) << " the result of " << utt << " is empty";
            result = " ";
        }

        // Read the timer once so the accumulated and the logged cost agree.
        const double cost = local_timer.Elapsed();
        tot_decode_time += cost;
        LOG(INFO) << utt << " " << result;
        LOG(INFO) << " RTF: " << cost / dur << " dur: " << dur
                  << " cost: " << cost;

        results->push_back(result);
        ++num_done;
    }
    LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
    LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
    LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
    LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
gflags::SetUsageMessage("Usage:");
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||||
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
google::InstallFailureSignalHandler();
|
||||||
|
FLAGS_logtostderr = 1;
|
||||||
|
|
||||||
|
int sample_rate = FLAGS_sample_rate;
|
||||||
|
float streaming_chunk = FLAGS_streaming_chunk;
|
||||||
|
int chunk_sample_size = streaming_chunk * sample_rate;
|
||||||
|
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
|
||||||
|
int njob = FLAGS_njob;
|
||||||
|
LOG(INFO) << "sr: " << sample_rate;
|
||||||
|
LOG(INFO) << "chunk size (s): " << streaming_chunk;
|
||||||
|
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
|
||||||
|
|
||||||
|
ppspeech::RecognizerResource resource =
|
||||||
|
ppspeech::RecognizerResource::InitFromFlags();
|
||||||
|
ppspeech::RecognizerController recognizer_controller(njob, resource);
|
||||||
|
ThreadPool threadpool(njob);
|
||||||
|
vector<vector<string>> wavlist;
|
||||||
|
vector<vector<string>> uttlist;
|
||||||
|
vector<vector<string>> resultlist(njob);
|
||||||
|
vector<std::future<void>> futurelist;
|
||||||
|
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
|
||||||
|
for (size_t i = 0; i < njob; ++i) {
|
||||||
|
std::future<void> f = threadpool.enqueue(recognizer_func,
|
||||||
|
&recognizer_controller,
|
||||||
|
wavlist[i],
|
||||||
|
uttlist[i],
|
||||||
|
&resultlist[i]);
|
||||||
|
futurelist.push_back(std::move(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < njob; ++i) {
|
||||||
|
futurelist[i].get();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t idx = 0; idx < njob; ++idx) {
|
||||||
|
for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
|
||||||
|
string utt = uttlist[idx][utt_idx];
|
||||||
|
string result = resultlist[idx][utt_idx];
|
||||||
|
result_writer.Write(utt, result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
@ -0,0 +1,168 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "common/base/thread_pool.h"
|
||||||
|
#include "common/utils/file_utils.h"
|
||||||
|
#include "common/utils/strings.h"
|
||||||
|
#include "decoder/param.h"
|
||||||
|
#include "frontend/wave-reader.h"
|
||||||
|
#include "kaldi/util/table-types.h"
|
||||||
|
#include "nnet/u2_nnet.h"
|
||||||
|
#include "recognizer/recognizer.h"
|
||||||
|
|
||||||
|
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
|
||||||
|
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
||||||
|
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
|
||||||
|
DEFINE_int32(sample_rate, 16000, "sample rate");
|
||||||
|
DEFINE_int32(njob, 3, "njob");
|
||||||
|
|
||||||
|
using std::string;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
// Reads `wavlist_file` line by line and distributes its entries round-robin
// over `njob` buckets.
//
// Every line must split (on space / tab) into exactly two fields: the
// utterance id and the wav path. Entry idx goes to bucket idx % njob; ids are
// appended to (*uttlists)[bucket] and paths to (*wavlists)[bucket].
// Aborts (CHECK) on a malformed line or a non-positive njob.
void SplitUtt(string wavlist_file,
              vector<vector<string>>* uttlists,
              vector<vector<string>>* wavlists,
              int njob) {
    CHECK(uttlists != nullptr);
    CHECK(wavlists != nullptr);
    // njob is used as a modulus below; 0 would divide by zero.
    CHECK_GT(njob, 0) << "njob must be positive";

    vector<string> wavlist;
    wavlists->resize(njob);
    uttlists->resize(njob);
    ppspeech::ReadFileToVector(wavlist_file, &wavlist);
    for (size_t idx = 0; idx < wavlist.size(); ++idx) {
        string utt_str = wavlist[idx];
        vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
        // Validate the field count BEFORE touching utt_wav[0]; a blank line
        // would otherwise be an out-of-bounds read.
        CHECK_EQ(utt_wav.size(), size_t(2));
        LOG(INFO) << utt_wav[0];
        const size_t bucket = idx % static_cast<size_t>(njob);
        uttlists->at(bucket).push_back(utt_wav[0]);
        wavlists->at(bucket).push_back(utt_wav[1]);
    }
}
|
||||||
|
|
||||||
|
void recognizer_func(std::vector<string> wavlist,
|
||||||
|
std::vector<string> uttlist,
|
||||||
|
std::vector<string>* results) {
|
||||||
|
int32 num_done = 0, num_err = 0;
|
||||||
|
double tot_wav_duration = 0.0;
|
||||||
|
double tot_attention_rescore_time = 0.0;
|
||||||
|
double tot_decode_time = 0.0;
|
||||||
|
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
|
||||||
|
if (wavlist.empty()) return;
|
||||||
|
|
||||||
|
results->reserve(wavlist.size());
|
||||||
|
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
|
||||||
|
std::string utt = uttlist[idx];
|
||||||
|
std::string wav_file = wavlist[idx];
|
||||||
|
std::ifstream infile;
|
||||||
|
infile.open(wav_file, std::ifstream::in);
|
||||||
|
kaldi::WaveData wave_data;
|
||||||
|
wave_data.Read(infile);
|
||||||
|
int32 recog_id = -1;
|
||||||
|
while (recog_id == -1) {
|
||||||
|
recog_id = GetRecognizerInstanceId();
|
||||||
|
}
|
||||||
|
InitDecoder(recog_id);
|
||||||
|
LOG(INFO) << "utt: " << utt;
|
||||||
|
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
|
||||||
|
double dur = wave_data.Duration();
|
||||||
|
tot_wav_duration += dur;
|
||||||
|
|
||||||
|
int32 this_channel = 0;
|
||||||
|
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
|
||||||
|
this_channel);
|
||||||
|
int tot_samples = waveform.Dim();
|
||||||
|
LOG(INFO) << "wav len (sample): " << tot_samples;
|
||||||
|
|
||||||
|
int sample_offset = 0;
|
||||||
|
kaldi::Timer local_timer;
|
||||||
|
|
||||||
|
while (sample_offset < tot_samples) {
|
||||||
|
int cur_chunk_size =
|
||||||
|
std::min(chunk_sample_size, tot_samples - sample_offset);
|
||||||
|
|
||||||
|
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
|
||||||
|
for (int i = 0; i < cur_chunk_size; ++i) {
|
||||||
|
wav_chunk[i] = waveform(sample_offset + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
AcceptData(wav_chunk, recog_id);
|
||||||
|
// no overlap
|
||||||
|
sample_offset += cur_chunk_size;
|
||||||
|
}
|
||||||
|
SetInputFinished(recog_id);
|
||||||
|
CHECK(sample_offset == tot_samples);
|
||||||
|
std::string result = GetFinalResult(recog_id);
|
||||||
|
if (result.empty()) {
|
||||||
|
// the TokenWriter can not write empty string.
|
||||||
|
++num_err;
|
||||||
|
LOG(INFO) << " the result of " << utt << " is empty";
|
||||||
|
result = " ";
|
||||||
|
}
|
||||||
|
|
||||||
|
tot_decode_time += local_timer.Elapsed();
|
||||||
|
LOG(INFO) << utt << " " << result;
|
||||||
|
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
|
||||||
|
<< " cost: " << local_timer.Elapsed();
|
||||||
|
|
||||||
|
results->push_back(result);
|
||||||
|
++num_done;
|
||||||
|
}
|
||||||
|
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
|
||||||
|
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
|
||||||
|
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
|
||||||
|
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
gflags::SetUsageMessage("Usage:");
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||||
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
google::InstallFailureSignalHandler();
|
||||||
|
FLAGS_logtostderr = 1;
|
||||||
|
|
||||||
|
int sample_rate = FLAGS_sample_rate;
|
||||||
|
float streaming_chunk = FLAGS_streaming_chunk;
|
||||||
|
int chunk_sample_size = streaming_chunk * sample_rate;
|
||||||
|
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
|
||||||
|
int njob = FLAGS_njob;
|
||||||
|
LOG(INFO) << "sr: " << sample_rate;
|
||||||
|
LOG(INFO) << "chunk size (s): " << streaming_chunk;
|
||||||
|
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
|
||||||
|
|
||||||
|
InitRecognizer(FLAGS_model_path, FLAGS_word_symbol_table, FLAGS_graph_path, njob);
|
||||||
|
ThreadPool threadpool(njob);
|
||||||
|
vector<vector<string>> wavlist;
|
||||||
|
vector<vector<string>> uttlist;
|
||||||
|
vector<vector<string>> resultlist(njob);
|
||||||
|
vector<std::future<void>> futurelist;
|
||||||
|
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
|
||||||
|
for (size_t i = 0; i < njob; ++i) {
|
||||||
|
std::future<void> f = threadpool.enqueue(recognizer_func,
|
||||||
|
wavlist[i],
|
||||||
|
uttlist[i],
|
||||||
|
&resultlist[i]);
|
||||||
|
futurelist.push_back(std::move(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < njob; ++i) {
|
||||||
|
futurelist[i].get();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t idx = 0; idx < njob; ++idx) {
|
||||||
|
for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
|
||||||
|
string utt = uttlist[idx][utt_idx];
|
||||||
|
string result = resultlist[idx][utt_idx];
|
||||||
|
result_writer.Write(utt, result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
@ -0,0 +1,70 @@
|
|||||||
|
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "recognizer/recognizer_controller.h"
|
||||||
|
#include "nnet/u2_nnet.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
// Creates `num_worker` RecognizerControllerImpl workers from `resource` and
// queues every worker index as idle.
RecognizerController::RecognizerController(int num_worker,
                                           RecognizerResource resource) {
    recognizer_workers.resize(num_worker);
    int worker_id = 0;
    for (auto& worker : recognizer_workers) {
        worker.reset(new ppspeech::RecognizerControllerImpl(resource));
        waiting_workers.push(worker_id);
        ++worker_id;
    }
}
|
||||||
|
|
||||||
|
int RecognizerController::GetRecognizerInstanceId() {
|
||||||
|
if (waiting_workers.empty()) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
int idx = -1;
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
|
idx = waiting_workers.front();
|
||||||
|
waiting_workers.pop();
|
||||||
|
}
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calls WaitFinished() on every worker, in index order, before the workers
// are destroyed.
RecognizerController::~RecognizerController() {
    for (auto& worker : recognizer_workers) {
        worker->WaitFinished();
    }
}
|
||||||
|
|
||||||
|
// Calls InitDecoder() on worker `idx`; the index is not range-checked.
void RecognizerController::InitDecoder(int idx) {
    auto& worker = recognizer_workers[idx];
    worker->InitDecoder();
}
|
||||||
|
|
||||||
|
std::string RecognizerController::GetFinalResult(int idx) {
|
||||||
|
recognizer_workers[idx]->WaitDecoderFinished();
|
||||||
|
recognizer_workers[idx]->AttentionRescoring();
|
||||||
|
std::string result = recognizer_workers[idx]->GetFinalResult();
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
|
waiting_workers.push(idx);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hands `data` to Accept() of worker `idx`; the index is not range-checked.
void RecognizerController::Accept(std::vector<float> data, int idx) {
    auto& worker = recognizer_workers[idx];
    worker->Accept(data);
}
|
||||||
|
|
||||||
|
// Calls SetInputFinished() on worker `idx`; the index is not range-checked.
void RecognizerController::SetInputFinished(int idx) {
    auto& worker = recognizer_workers[idx];
    worker->SetInputFinished();
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,42 @@
|
|||||||
|
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <queue>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "recognizer/recognizer_controller_impl.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
class RecognizerController {
|
||||||
|
public:
|
||||||
|
explicit RecognizerController(int num_worker, RecognizerResource resource);
|
||||||
|
~RecognizerController();
|
||||||
|
int GetRecognizerInstanceId();
|
||||||
|
void InitDecoder(int idx);
|
||||||
|
void Accept(std::vector<float> data, int idx);
|
||||||
|
void SetInputFinished(int idx);
|
||||||
|
std::string GetFinalResult(int idx);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::queue<int> waiting_workers;
|
||||||
|
std::mutex mutex_;
|
||||||
|
std::vector<std::unique_ptr<ppspeech::RecognizerControllerImpl>> recognizer_workers;
|
||||||
|
|
||||||
|
DISALLOW_COPY_AND_ASSIGN(RecognizerController);
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,89 @@
|
|||||||
|
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "decoder/common.h"
|
||||||
|
#include "fst/fstlib.h"
|
||||||
|
#include "fst/symbol-table.h"
|
||||||
|
#include "nnet/u2_nnet.h"
|
||||||
|
#include "nnet/nnet_producer.h"
|
||||||
|
#ifdef USE_ONNX
|
||||||
|
#include "nnet/u2_onnx_nnet.h"
|
||||||
|
#endif
|
||||||
|
#include "nnet/decodable.h"
|
||||||
|
#include "recognizer/recognizer_resource.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
class RecognizerControllerImpl {
|
||||||
|
public:
|
||||||
|
explicit RecognizerControllerImpl(const RecognizerResource& resource);
|
||||||
|
~RecognizerControllerImpl();
|
||||||
|
void Accept(std::vector<float> data);
|
||||||
|
void InitDecoder();
|
||||||
|
void SetInputFinished();
|
||||||
|
std::string GetFinalResult();
|
||||||
|
std::string GetPartialResult();
|
||||||
|
void Rescoring();
|
||||||
|
void Reset();
|
||||||
|
void WaitDecoderFinished();
|
||||||
|
void WaitFinished();
|
||||||
|
void AttentionRescoring();
|
||||||
|
bool DecodedSomething() const {
|
||||||
|
return !result_.empty() && !result_[0].sentence.empty();
|
||||||
|
}
|
||||||
|
int FrameShiftInMs() const {
|
||||||
|
return 1; //todo
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
static void RunNnetEvaluation(RecognizerControllerImpl* me);
|
||||||
|
void RunNnetEvaluationInternal();
|
||||||
|
static void RunDecoder(RecognizerControllerImpl* me);
|
||||||
|
void RunDecoderInternal();
|
||||||
|
void UpdateResult(bool finish = false);
|
||||||
|
|
||||||
|
std::shared_ptr<Decodable> decodable_;
|
||||||
|
std::unique_ptr<DecoderBase> decoder_;
|
||||||
|
std::shared_ptr<NnetProducer> nnet_producer_;
|
||||||
|
|
||||||
|
// e2e unit symbol table
|
||||||
|
std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
|
||||||
|
std::vector<DecodeResult> result_;
|
||||||
|
|
||||||
|
RecognizerResource opts_;
|
||||||
|
bool abort_ = false;
|
||||||
|
// global decoded frame offset
|
||||||
|
int global_frame_offset_;
|
||||||
|
// cur decoded frame num
|
||||||
|
int num_frames_;
|
||||||
|
// timestamp gap between words in a sentence
|
||||||
|
const int time_stamp_gap_ = 100;
|
||||||
|
bool input_finished_;
|
||||||
|
|
||||||
|
std::mutex nnet_mutex_;
|
||||||
|
std::mutex decoder_mutex_;
|
||||||
|
std::condition_variable nnet_condition_;
|
||||||
|
std::condition_variable decoder_condition_;
|
||||||
|
std::thread nnet_thread_;
|
||||||
|
std::thread decoder_thread_;
|
||||||
|
|
||||||
|
DISALLOW_COPY_AND_ASSIGN(RecognizerControllerImpl);
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,66 @@
|
|||||||
|
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "recognizer/recognizer_instance.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
// Returns the process-wide instance, created on first use as a
// function-local static.
RecognizerInstance& RecognizerInstance::GetInstance() {
    static RecognizerInstance singleton;
    return singleton;
}
|
||||||
|
|
||||||
|
// Builds the recognizer resource from the command-line flags, overrides the
// model / decoder paths with the given arguments and creates the controller.
//
// model_file:             path of the acoustic model.
// word_symbol_table_file: path of the word symbol table.
// fst_file:               TLG graph; when empty, CTC prefix search is
//                         configured instead of the TLG decoder.
// num_instance:           number of workers passed to the controller.
// Always returns true.
bool RecognizerInstance::Init(const std::string& model_file,
                              const std::string& word_symbol_table_file,
                              const std::string& fst_file,
                              int num_instance) {
    RecognizerResource resource = RecognizerResource::InitFromFlags();
    resource.model_opts.model_path = model_file;
    if (!fst_file.empty()) {
        // The original assigned fst_path twice, so the symbol table path
        // clobbered the graph path and the symbol table was never set.
        resource.decoder_opts.tlg_decoder_opts.fst_path = fst_file;
        resource.decoder_opts.tlg_decoder_opts.word_symbol_table =
            word_symbol_table_file;
    } else {
        resource.decoder_opts.ctc_prefix_search_opts.word_symbol_table =
            word_symbol_table_file;
    }
    recognizer_controller_ =
        std::make_unique<RecognizerController>(num_instance, resource);
    return true;
}
|
||||||
|
|
||||||
|
// Forwards to the controller: initialises the decoder of worker `idx`.
void RecognizerInstance::InitDecoder(int idx) {
    recognizer_controller_->InitDecoder(idx);
}
|
||||||
|
|
||||||
|
int RecognizerInstance::GetRecognizerInstanceId() {
|
||||||
|
return recognizer_controller_->GetRecognizerInstanceId();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forwards the samples in `waves` to worker `idx`.
void RecognizerInstance::Accept(const std::vector<float>& waves,
                                int idx) const {
    recognizer_controller_->Accept(waves, idx);
}
|
||||||
|
|
||||||
|
// Forwards to the controller: marks the input of worker `idx` as finished.
void RecognizerInstance::SetInputFinished(int idx) const {
    recognizer_controller_->SetInputFinished(idx);
}
|
||||||
|
|
||||||
|
std::string RecognizerInstance::GetResult(int idx) const {
|
||||||
|
return recognizer_controller_->GetFinalResult(idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,42 @@
|
|||||||
|
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "base/common.h"
|
||||||
|
#include "recognizer/recognizer_controller.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
class RecognizerInstance {
|
||||||
|
public:
|
||||||
|
static RecognizerInstance& GetInstance();
|
||||||
|
RecognizerInstance() {}
|
||||||
|
~RecognizerInstance() {}
|
||||||
|
bool Init(const std::string& model_file,
|
||||||
|
const std::string& word_symbol_table_file,
|
||||||
|
const std::string& fst_file,
|
||||||
|
int num_instance);
|
||||||
|
int GetRecognizerInstanceId();
|
||||||
|
void InitDecoder(int idx);
|
||||||
|
void Accept(const std::vector<float>& waves, int idx) const;
|
||||||
|
void SetInputFinished(int idx) const;
|
||||||
|
std::string GetResult(int idx) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<RecognizerController> recognizer_controller_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1 @@
|
|||||||
|
#add_subdirectory(websocket)
|
@ -0,0 +1,3 @@
|
|||||||
|
# FastDeploy backend selection for the audio classification code: the ONNX
# Runtime backend is enabled here. Define USE_PADDLE_INFERENCE_BACKEND instead
# to build against Paddle Inference.
add_compile_definitions(USE_ORT_BACKEND)

add_subdirectory(nnet)
|
@ -0,0 +1,11 @@
|
|||||||
|
# Shared library with the PANNs audio classification network and its C-style
# interface.
set(srcs
  panns_nnet.cc
  panns_interface.cc
)

add_library(cls SHARED ${srcs})
target_link_libraries(cls
  PRIVATE
    ${FASTDEPLOY_LIBS}
    kaldi-matrix
    kaldi-base
    frontend
    utils
)

# Command-line driver for the library.
set(bin_name panns_nnet_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
# Explicit keyword: the keyword-less signature has legacy semantics and
# cannot be mixed with keyword calls on the same target.
target_link_libraries(${bin_name} PRIVATE gflags glog cls)
|
@ -0,0 +1,79 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "audio_classification/nnet/panns_interface.h"
|
||||||
|
|
||||||
|
#include "audio_classification/nnet/panns_nnet.h"
|
||||||
|
#include "common/base/config.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
void* ClsCreateInstance(const char* conf_path) {
|
||||||
|
Config conf(conf_path);
|
||||||
|
// cls init
|
||||||
|
ppspeech::ClsNnetConf cls_nnet_conf;
|
||||||
|
cls_nnet_conf.wav_normal_ = conf.Read("wav_normal", true);
|
||||||
|
cls_nnet_conf.wav_normal_type_ =
|
||||||
|
conf.Read("wav_normal_type", std::string("linear"));
|
||||||
|
cls_nnet_conf.wav_norm_mul_factor_ = conf.Read("wav_norm_mul_factor", 1.0);
|
||||||
|
cls_nnet_conf.model_file_path_ = conf.Read("model_path", std::string(""));
|
||||||
|
cls_nnet_conf.param_file_path_ = conf.Read("param_path", std::string(""));
|
||||||
|
cls_nnet_conf.dict_file_path_ = conf.Read("dict_path", std::string(""));
|
||||||
|
cls_nnet_conf.num_cpu_thread_ = conf.Read("num_cpu_thread", 12);
|
||||||
|
cls_nnet_conf.samp_freq = conf.Read("samp_freq", 32000);
|
||||||
|
cls_nnet_conf.frame_length_ms = conf.Read("frame_length_ms", 32);
|
||||||
|
cls_nnet_conf.frame_shift_ms = conf.Read("frame_shift_ms", 10);
|
||||||
|
cls_nnet_conf.num_bins = conf.Read("num_bins", 64);
|
||||||
|
cls_nnet_conf.low_freq = conf.Read("low_freq", 50);
|
||||||
|
cls_nnet_conf.high_freq = conf.Read("high_freq", 14000);
|
||||||
|
cls_nnet_conf.dither = conf.Read("dither", 0.0);
|
||||||
|
|
||||||
|
ppspeech::ClsNnet* cls_model = new ppspeech::ClsNnet();
|
||||||
|
int ret = cls_model->Init(cls_nnet_conf);
|
||||||
|
return static_cast<void*>(cls_model);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Destroys a handle created by ClsCreateInstance(). A NULL handle is
// accepted. Always returns 0.
int ClsDestroyInstance(void* instance) {
    // Deleting a null pointer is a no-op, so no explicit check is required.
    delete static_cast<ppspeech::ClsNnet*>(instance);
    return 0;
}
|
||||||
|
|
||||||
|
// Classifies the wav file at `wav_path` and writes the top-`topk` labels and
// scores into `result` (at most `result_max_len` bytes, NUL-terminated).
//
// Returns -1 when `instance`, `wav_path` or `result` is null or
// `result_max_len` is not positive; otherwise returns the status reported by
// ClsNnet::Forward (0 on success).
int ClsFeedForward(void* instance,
                   const char* wav_path,
                   int topk,
                   char* result,
                   int result_max_len) {
    ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
    if (cls_model == NULL) {
        printf("instance is null\n");
        return -1;
    }
    if (wav_path == NULL || result == NULL || result_max_len <= 0) {
        printf("invalid argument\n");
        return -1;
    }
    // Propagate the status instead of unconditionally reporting success.
    return cls_model->Forward(wav_path, topk, result, result_max_len);
}
|
||||||
|
|
||||||
|
int ClsReset(void* instance) {
|
||||||
|
ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
|
||||||
|
if (cls_model == NULL) {
|
||||||
|
printf("instance is null\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
cls_model->Reset();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,27 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
// Creates a classifier configured from the file at `conf_path` and returns
// it as an opaque handle.
void* ClsCreateInstance(const char* conf_path);

// Destroys a handle obtained from ClsCreateInstance().
int ClsDestroyInstance(void* instance);

// Classifies the wav file at `wav_path`; the top-`topk` result string is
// written into `result`, which holds at most `result_max_len` bytes.
// Returns -1 when `instance` is null.
int ClsFeedForward(void* instance,
                   const char* wav_path,
                   int topk,
                   char* result,
                   int result_max_len);

// Resets the classifier. Returns -1 when `instance` is null.
int ClsReset(void* instance);
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,227 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "audio_classification/nnet/panns_nnet.h"
|
||||||
|
#ifdef WITH_PROFILING
|
||||||
|
#include "kaldi/base/timer.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
// Starts without a runtime; Init() creates it.
ClsNnet::ClsNnet() : runtime_(nullptr) {}
|
||||||
|
|
||||||
|
// Empties the buffer in which Forward() assembles its result string.
void ClsNnet::Reset() {
    ss_.str(std::string());
}
|
||||||
|
|
||||||
|
int ClsNnet::Init(const ClsNnetConf& conf) {
|
||||||
|
conf_ = conf;
|
||||||
|
// init fbank opts
|
||||||
|
fbank_opts_.frame_opts.samp_freq = conf.samp_freq;
|
||||||
|
fbank_opts_.frame_opts.frame_length_ms = conf.frame_length_ms;
|
||||||
|
fbank_opts_.frame_opts.frame_shift_ms = conf.frame_shift_ms;
|
||||||
|
fbank_opts_.mel_opts.num_bins = conf.num_bins;
|
||||||
|
fbank_opts_.mel_opts.low_freq = conf.low_freq;
|
||||||
|
fbank_opts_.mel_opts.high_freq = conf.high_freq;
|
||||||
|
fbank_opts_.frame_opts.dither = conf.dither;
|
||||||
|
fbank_opts_.use_log_fbank = false;
|
||||||
|
|
||||||
|
// init dict
|
||||||
|
if (conf.dict_file_path_ != "") {
|
||||||
|
ReadFileToVector(conf.dict_file_path_, &dict_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// init model
|
||||||
|
fastdeploy::RuntimeOption runtime_option;
|
||||||
|
|
||||||
|
#ifdef USE_PADDLE_INFERENCE_BACKEND
|
||||||
|
runtime_option.SetModelPath(conf.model_file_path_,
|
||||||
|
conf.param_file_path_,
|
||||||
|
fastdeploy::ModelFormat::PADDLE);
|
||||||
|
runtime_option.UsePaddleInferBackend();
|
||||||
|
#elif defined(USE_ORT_BACKEND)
|
||||||
|
runtime_option.SetModelPath(
|
||||||
|
conf.model_file_path_, "", fastdeploy::ModelFormat::ONNX); // onnx
|
||||||
|
runtime_option.UseOrtBackend(); // onnx
|
||||||
|
#elif defined(USE_PADDLE_LITE_BACKEND)
|
||||||
|
runtime_option.SetModelPath(conf.model_file_path_,
|
||||||
|
conf.param_file_path_,
|
||||||
|
fastdeploy::ModelFormat::PADDLE);
|
||||||
|
runtime_option.UseLiteBackend();
|
||||||
|
#endif
|
||||||
|
|
||||||
|
runtime_option.SetCpuThreadNum(conf.num_cpu_thread_);
|
||||||
|
// runtime_option.DeletePaddleBackendPass("simplify_with_basic_ops_pass");
|
||||||
|
runtime_ = std::unique_ptr<fastdeploy::Runtime>(new fastdeploy::Runtime());
|
||||||
|
if (!runtime_->Init(runtime_option)) {
|
||||||
|
std::cerr << "--- Init FastDeploy Runitme Failed! "
|
||||||
|
<< "\n--- Model: " << conf.model_file_path_ << std::endl;
|
||||||
|
return -1;
|
||||||
|
} else {
|
||||||
|
std::cout << "--- Init FastDeploy Runitme Done! "
|
||||||
|
<< "\n--- Model: " << conf.model_file_path_ << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
Reset();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ClsNnet::Forward(const char* wav_path,
|
||||||
|
int topk,
|
||||||
|
char* result,
|
||||||
|
int result_max_len) {
|
||||||
|
#ifdef WITH_PROFILING
|
||||||
|
kaldi::Timer timer;
|
||||||
|
timer.Reset();
|
||||||
|
#endif
|
||||||
|
// read wav
|
||||||
|
std::ifstream infile(wav_path, std::ifstream::in);
|
||||||
|
kaldi::WaveData wave_data;
|
||||||
|
wave_data.Read(infile);
|
||||||
|
int32 this_channel = 0;
|
||||||
|
kaldi::Matrix<float> wavform_kaldi = wave_data.Data();
|
||||||
|
// only get channel 0
|
||||||
|
int wavform_len = wavform_kaldi.NumCols();
|
||||||
|
std::vector<float> wavform(wavform_kaldi.Data(),
|
||||||
|
wavform_kaldi.Data() + wavform_len);
|
||||||
|
WaveformFloatNormal(&wavform);
|
||||||
|
WaveformNormal(&wavform,
|
||||||
|
conf_.wav_normal_,
|
||||||
|
conf_.wav_normal_type_,
|
||||||
|
conf_.wav_norm_mul_factor_);
|
||||||
|
#ifdef PPS_DEBUG
|
||||||
|
{
|
||||||
|
std::ofstream fp("cls.wavform", std::ios::out);
|
||||||
|
for (int i = 0; i < wavform.size(); ++i) {
|
||||||
|
fp << std::setprecision(18) << wavform[i] << " ";
|
||||||
|
}
|
||||||
|
fp << "\n";
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#ifdef WITH_PROFILING
|
||||||
|
printf("wav read consume: %fs\n", timer.Elapsed());
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef WITH_PROFILING
|
||||||
|
timer.Reset();
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::vector<float> feats;
|
||||||
|
std::unique_ptr<ppspeech::FrontendInterface> data_source(
|
||||||
|
new ppspeech::DataCache());
|
||||||
|
ppspeech::Fbank fbank(fbank_opts_, std::move(data_source));
|
||||||
|
fbank.Accept(wavform);
|
||||||
|
fbank.SetFinished();
|
||||||
|
fbank.Read(&feats);
|
||||||
|
|
||||||
|
int feat_dim = fbank_opts_.mel_opts.num_bins;
|
||||||
|
int num_frames = feats.size() / feat_dim;
|
||||||
|
|
||||||
|
for (int i = 0; i < num_frames; ++i) {
|
||||||
|
for (int j = 0; j < feat_dim; ++j) {
|
||||||
|
feats[i * feat_dim + j] = PowerTodb(feats[i * feat_dim + j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#ifdef PPS_DEBUG
|
||||||
|
{
|
||||||
|
std::ofstream fp("cls.feat", std::ios::out);
|
||||||
|
for (int i = 0; i < num_frames; ++i) {
|
||||||
|
for (int j = 0; j < feat_dim; ++j) {
|
||||||
|
fp << std::setprecision(18) << feats[i * feat_dim + j] << " ";
|
||||||
|
}
|
||||||
|
fp << "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#ifdef WITH_PROFILING
|
||||||
|
printf("extract fbank consume: %fs\n", timer.Elapsed());
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// infer
|
||||||
|
std::vector<float> model_out;
|
||||||
|
#ifdef WITH_PROFILING
|
||||||
|
timer.Reset();
|
||||||
|
#endif
|
||||||
|
ModelForward(feats.data(), num_frames, feat_dim, &model_out);
|
||||||
|
#ifdef WITH_PROFILING
|
||||||
|
printf("fast deploy infer consume: %fs\n", timer.Elapsed());
|
||||||
|
#endif
|
||||||
|
#ifdef PPS_DEBUG
|
||||||
|
{
|
||||||
|
std::ofstream fp("cls.logits", std::ios::out);
|
||||||
|
for (int i = 0; i < model_out.size(); ++i) {
|
||||||
|
fp << std::setprecision(18) << model_out[i] << "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// construct result str
|
||||||
|
ss_ << "{";
|
||||||
|
GetTopkResult(topk, model_out);
|
||||||
|
ss_ << "}";
|
||||||
|
|
||||||
|
if (result_max_len <= ss_.str().size()) {
|
||||||
|
printf("result_max_len is short than result len\n");
|
||||||
|
}
|
||||||
|
snprintf(result, result_max_len, "%s", ss_.str().c_str());
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ClsNnet::ModelForward(float* features,
|
||||||
|
const int num_frames,
|
||||||
|
const int feat_dim,
|
||||||
|
std::vector<float>* model_out) {
|
||||||
|
// init input tensor shape
|
||||||
|
fastdeploy::TensorInfo info = runtime_->GetInputInfo(0);
|
||||||
|
info.shape = {1, num_frames, feat_dim};
|
||||||
|
|
||||||
|
std::vector<fastdeploy::FDTensor> input_tensors(1);
|
||||||
|
std::vector<fastdeploy::FDTensor> output_tensors(1);
|
||||||
|
|
||||||
|
input_tensors[0].SetExternalData({1, num_frames, feat_dim},
|
||||||
|
fastdeploy::FDDataType::FP32,
|
||||||
|
static_cast<void*>(features));
|
||||||
|
|
||||||
|
// get input name
|
||||||
|
input_tensors[0].name = info.name;
|
||||||
|
|
||||||
|
runtime_->Infer(input_tensors, &output_tensors);
|
||||||
|
|
||||||
|
// output_tensors[0].PrintInfo();
|
||||||
|
std::vector<int64_t> output_shape = output_tensors[0].Shape();
|
||||||
|
model_out->resize(output_shape[0] * output_shape[1]);
|
||||||
|
memcpy(static_cast<void*>(model_out->data()),
|
||||||
|
output_tensors[0].Data(),
|
||||||
|
output_shape[0] * output_shape[1] * sizeof(float));
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ClsNnet::GetTopkResult(int k, const std::vector<float>& model_out) {
|
||||||
|
std::vector<float> values;
|
||||||
|
std::vector<int> indics;
|
||||||
|
TopK(model_out, k, &values, &indics);
|
||||||
|
for (int i = 0; i < k; ++i) {
|
||||||
|
if (i != 0) {
|
||||||
|
ss_ << ",";
|
||||||
|
}
|
||||||
|
ss_ << "\"" << dict_[indics[i]] << "\":\"" << values[i] << "\"";
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,74 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common/frontend/data_cache.h"
|
||||||
|
#include "common/frontend/fbank.h"
|
||||||
|
#include "common/frontend/feature-fbank.h"
|
||||||
|
#include "common/frontend/frontend_itf.h"
|
||||||
|
#include "common/frontend/wave-reader.h"
|
||||||
|
#include "common/utils/audio_process.h"
|
||||||
|
#include "common/utils/file_utils.h"
|
||||||
|
#include "fastdeploy/runtime.h"
|
||||||
|
#include "kaldi/util/kaldi-io.h"
|
||||||
|
#include "kaldi/util/table-types.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
// Configuration of ClsNnet. The default member initialisers mirror the
// fallback values ClsCreateInstance() uses when a key is missing from the
// config file, so a default-constructed object never holds indeterminate
// values.
struct ClsNnetConf {
    // wav
    bool wav_normal_ = true;                  // apply waveform normalisation
    std::string wav_normal_type_ = "linear";  // normalisation method name
    float wav_norm_mul_factor_ = 1.0f;
    // model
    std::string model_file_path_;
    std::string param_file_path_;  // not used for ONNX models
    std::string dict_file_path_;   // label dictionary; may be empty
    int num_cpu_thread_ = 12;
    // fbank
    float samp_freq = 32000.0f;
    float frame_length_ms = 32.0f;
    float frame_shift_ms = 10.0f;
    int num_bins = 64;
    float low_freq = 50.0f;
    float high_freq = 14000.0f;
    float dither = 0.0f;
};
|
||||||
|
|
||||||
|
class ClsNnet {
|
||||||
|
public:
|
||||||
|
ClsNnet();
|
||||||
|
int Init(const ClsNnetConf& conf);
|
||||||
|
int Forward(const char* wav_path,
|
||||||
|
int topk,
|
||||||
|
char* result,
|
||||||
|
int result_max_len);
|
||||||
|
void Reset();
|
||||||
|
|
||||||
|
private:
|
||||||
|
int ModelForward(float* features,
|
||||||
|
const int num_frames,
|
||||||
|
const int feat_dim,
|
||||||
|
std::vector<float>* model_out);
|
||||||
|
int ModelForwardStream(std::vector<float>* feats);
|
||||||
|
int GetTopkResult(int k, const std::vector<float>& model_out);
|
||||||
|
|
||||||
|
ClsNnetConf conf_;
|
||||||
|
knf::FbankOptions fbank_opts_;
|
||||||
|
std::unique_ptr<fastdeploy::Runtime> runtime_;
|
||||||
|
std::vector<std::string> dict_;
|
||||||
|
std::stringstream ss_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,51 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "gflags/gflags.h"
|
||||||
|
#include "glog/logging.h"
|
||||||
|
#include "audio_classification/nnet/panns_interface.h"
|
||||||
|
|
||||||
|
DEFINE_string(conf_path, "", "config path");
|
||||||
|
DEFINE_string(scp_path, "", "wav scp path");
|
||||||
|
DEFINE_string(topk, "", "print topk results");
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
gflags::SetUsageMessage("Usage:");
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||||
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
google::InstallFailureSignalHandler();
|
||||||
|
FLAGS_logtostderr = 1;
|
||||||
|
CHECK_GT(FLAGS_conf_path.size(), 0);
|
||||||
|
CHECK_GT(FLAGS_scp_path.size(), 0);
|
||||||
|
CHECK_GT(FLAGS_topk.size(), 0);
|
||||||
|
void* instance = ppspeech::ClsCreateInstance(FLAGS_conf_path.c_str());
|
||||||
|
int ret = 0;
|
||||||
|
// read wav
|
||||||
|
std::ifstream ifs(FLAGS_scp_path);
|
||||||
|
std::string line = "";
|
||||||
|
int topk = std::atoi(FLAGS_topk.c_str());
|
||||||
|
while (getline(ifs, line)) {
|
||||||
|
// read wav
|
||||||
|
char result[1024] = {0};
|
||||||
|
ret = ppspeech::ClsFeedForward(
|
||||||
|
instance, line.c_str(), topk, result, 1024);
|
||||||
|
printf("%s %s\n", line.c_str(), result);
|
||||||
|
ret = ppspeech::ClsReset(instance);
|
||||||
|
}
|
||||||
|
ret = ppspeech::ClsDestroyInstance(instance);
|
||||||
|
return 0;
|
||||||
|
}
|
@ -0,0 +1,6 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

# The glog subdirectory is built on every platform except Android.
if(NOT ANDROID)
  add_subdirectory(glog)
endif()
|
@ -1,8 +1,8 @@
|
|||||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||||
|
|
||||||
add_executable(glog_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_main.cc)
|
add_executable(glog_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_main.cc)
|
||||||
target_link_libraries(glog_main glog)
|
target_link_libraries(glog_main extern_glog)
|
||||||
|
|
||||||
|
|
||||||
add_executable(glog_logtostderr_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_main.cc)
|
add_executable(glog_logtostderr_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_main.cc)
|
||||||
target_link_libraries(glog_logtostderr_main glog)
|
target_link_libraries(glog_logtostderr_main extern_glog)
|
@ -0,0 +1,19 @@
|
|||||||
|
# Headers are looked up relative to this directory and to its parent.
include_directories(
  ${CMAKE_CURRENT_SOURCE_DIR}
  ${CMAKE_CURRENT_SOURCE_DIR}/../
)

foreach(component base utils matrix)
  add_subdirectory(${component})
endforeach()

# Added after the subdirectories above, so they do not inherit this path;
# the frontend subdirectory does.
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/frontend)
add_subdirectory(frontend)

# Umbrella target that pulls in every common component.
add_library(common INTERFACE)
target_link_libraries(common INTERFACE base utils kaldi-matrix frontend)

install(TARGETS base utils kaldi-matrix frontend DESTINATION lib)
|
@ -0,0 +1,43 @@
|
|||||||
|
|
||||||
|
|
||||||
|
# Pick the header that provides command-line flags: the one from fst for ASR
# builds, gflags otherwise.
if(WITH_ASR)
  add_compile_options(-DWITH_ASR)
  set(PPS_FLAGS_LIB "fst/flags.h")
else()
  set(PPS_FLAGS_LIB "gflags/gflags.h")
endif()

# Pick the logging header: the bundled implementation on Android, otherwise
# fst's for ASR builds and glog for everything else.
if(ANDROID)
  set(PPS_GLOG_LIB "base/log_impl.h")
elseif(WITH_ASR)
  set(PPS_GLOG_LIB "fst/log.h")
else()
  set(PPS_GLOG_LIB "glog/logging.h")
endif()

# Generate flags.h and log.h from their templates.
# NOTE(review): the output goes into the source tree, not the binary tree.
foreach(header flags.h log.h)
  configure_file(
    ${CMAKE_CURRENT_SOURCE_DIR}/${header}.in
    ${CMAKE_CURRENT_SOURCE_DIR}/${header} @ONLY
  )
  message(STATUS "Generated ${CMAKE_CURRENT_SOURCE_DIR}/${header}")
endforeach()

# Android compiles the logging implementation; elsewhere `base` has no
# sources and is a pure interface target.
if(ANDROID)
  set(csrc
    log_impl.cc
    glog_utils.cc
  )
  add_library(base ${csrc})
  target_link_libraries(base gflags)
else()
  set(csrc)
  add_library(base INTERFACE)
endif()
|
@ -0,0 +1,343 @@
|
|||||||
|
// Copyright (c) code is from
|
||||||
|
// https://blog.csdn.net/huixingshao/article/details/45969887.
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <map>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
#pragma region ParseIniFile
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/*
|
||||||
|
* \brief Generic configuration Class
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
/*
 * \brief Generic configuration class: a string-to-string map of keys and
 *        values with typed accessors, filled from a file that uses a
 *        key/value delimiter and a comment marker.
 */
class Config {
  protected:
    std::string m_Delimiter;  //!< separator between key and value
    std::string m_Comment;    //!< separator between value and comments
    std::map<std::string, std::string> m_Contents;  //!< extracted keys/values

    using mapi = std::map<std::string, std::string>::iterator;
    using mapci = std::map<std::string, std::string>::const_iterator;

  public:
    Config(std::string filename,
           std::string delimiter = "=",
           std::string comment = "#");
    Config();

    //! Search for key and read its value; call as Read<T>.
    template <class T>
    T Read(const std::string& in_key) const;
    //! Same lookup, taking an optional default value.
    template <class T>
    T Read(const std::string& in_key, const T& in_value) const;
    template <class T>
    bool ReadInto(T* out_var, const std::string& in_key) const;
    template <class T>
    bool ReadInto(T* out_var,
                  const std::string& in_key,
                  const T& in_value) const;

    bool FileExist(std::string filename);
    void ReadFile(std::string filename,
                  std::string delimiter = "=",
                  std::string comment = "#");

    // Check whether key exists in configuration
    bool KeyExists(const std::string& in_key) const;

    // Modify keys and values
    template <class T>
    void Add(const std::string& in_key, const T& in_value);
    void Remove(const std::string& in_key);

    // Check or change configuration syntax; each setter returns the value
    // that was in effect before the call.
    std::string GetDelimiter() const { return m_Delimiter; }
    std::string GetComment() const { return m_Comment; }
    std::string SetDelimiter(const std::string& in_s) {
        std::string previous(in_s);
        previous.swap(m_Delimiter);
        return previous;
    }
    std::string SetComment(const std::string& in_s) {
        std::string previous(in_s);
        previous.swap(m_Comment);
        return previous;
    }

    // Write or read configuration
    friend std::ostream& operator<<(std::ostream& os, const Config& cf);
    friend std::istream& operator>>(std::istream& is, Config& cf);

  protected:
    // Conversion helpers between strings and typed values.
    template <class T>
    static std::string T_as_string(const T& t);
    template <class T>
    static T string_as_T(const std::string& s);
    static void Trim(std::string* inout_s);

    // Exception types
  public:
    // Carries the name of the file concerned.
    struct File_not_found {
        std::string filename;
        explicit File_not_found(const std::string& filename_ = std::string())
            : filename(filename_) {}
    };
    // Thrown only by the T Read(key) variant of Read().
    struct Key_not_found {
        std::string key;
        explicit Key_not_found(const std::string& key_ = std::string())
            : key(key_) {}
    };
};
|
||||||
|
|
||||||
|
/* static */
|
||||||
|
template <class T>
|
||||||
|
std::string Config::T_as_string(const T& t) {
|
||||||
|
// Convert from a T to a string
|
||||||
|
// Type T must support << operator
|
||||||
|
std::ostringstream ost;
|
||||||
|
ost << t;
|
||||||
|
return ost.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* static */
|
||||||
|
template <class T>
|
||||||
|
T Config::string_as_T(const std::string& s) {
|
||||||
|
// Convert from a string to a T
|
||||||
|
// Type T must support >> operator
|
||||||
|
T t;
|
||||||
|
std::istringstream ist(s);
|
||||||
|
ist >> t;
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* static */
|
||||||
|
template <>
|
||||||
|
inline std::string Config::string_as_T<std::string>(const std::string& s) {
|
||||||
|
// Convert from a string to a string
|
||||||
|
// In other words, do nothing
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* static */
|
||||||
|
template <>
|
||||||
|
inline bool Config::string_as_T<bool>(const std::string& s) {
|
||||||
|
// Convert from a string to a bool
|
||||||
|
// Interpret "false", "F", "no", "n", "0" as false
|
||||||
|
// Interpret "true", "T", "yes", "y", "1", "-1", or anything else as true
|
||||||
|
bool b = true;
|
||||||
|
std::string sup = s;
|
||||||
|
for (std::string::iterator p = sup.begin(); p != sup.end(); ++p)
|
||||||
|
*p = toupper(*p); // make string all caps
|
||||||
|
if (sup == std::string("FALSE") || sup == std::string("F") ||
|
||||||
|
sup == std::string("NO") || sup == std::string("N") ||
|
||||||
|
sup == std::string("0") || sup == std::string("NONE"))
|
||||||
|
b = false;
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
T Config::Read(const std::string& key) const {
|
||||||
|
// Read the value corresponding to key
|
||||||
|
mapci p = m_Contents.find(key);
|
||||||
|
if (p == m_Contents.end()) throw Key_not_found(key);
|
||||||
|
return string_as_T<T>(p->second);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
T Config::Read(const std::string& key, const T& value) const {
|
||||||
|
// Return the value corresponding to key or given default value
|
||||||
|
// if key is not found
|
||||||
|
mapci p = m_Contents.find(key);
|
||||||
|
if (p == m_Contents.end()) {
|
||||||
|
printf("%s = %s(default)\n", key.c_str(), T_as_string(value).c_str());
|
||||||
|
return value;
|
||||||
|
} else {
|
||||||
|
printf("%s = %s\n", key.c_str(), T_as_string(p->second).c_str());
|
||||||
|
return string_as_T<T>(p->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool Config::ReadInto(T* var, const std::string& key) const {
|
||||||
|
// Get the value corresponding to key and store in var
|
||||||
|
// Return true if key is found
|
||||||
|
// Otherwise leave var untouched
|
||||||
|
mapci p = m_Contents.find(key);
|
||||||
|
bool found = (p != m_Contents.end());
|
||||||
|
if (found) *var = string_as_T<T>(p->second);
|
||||||
|
return found;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool Config::ReadInto(T* var, const std::string& key, const T& value) const {
|
||||||
|
// Get the value corresponding to key and store in var
|
||||||
|
// Return true if key is found
|
||||||
|
// Otherwise set var to given default
|
||||||
|
mapci p = m_Contents.find(key);
|
||||||
|
bool found = (p != m_Contents.end());
|
||||||
|
if (found)
|
||||||
|
*var = string_as_T<T>(p->second);
|
||||||
|
else
|
||||||
|
var = value;
|
||||||
|
return found;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void Config::Add(const std::string& in_key, const T& value) {
|
||||||
|
// Add a key with given value
|
||||||
|
std::string v = T_as_string(value);
|
||||||
|
std::string key = in_key;
|
||||||
|
Trim(&key);
|
||||||
|
Trim(&v);
|
||||||
|
m_Contents[key] = v;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct a Config, getting keys and values from the given file.
// Throws File_not_found when the file cannot be opened.
// Defined in a header, so it is marked inline to avoid multiple-definition
// link errors when the header is included from more than one source file.
inline Config::Config(std::string filename,
                      std::string delimiter,
                      std::string comment)
    : m_Delimiter(delimiter), m_Comment(comment) {
    std::ifstream in(filename.c_str());

    if (!in) throw File_not_found(filename);

    in >> (*this);
}
|
||||||
|
|
||||||
|
|
||||||
|
// Construct an empty Config that uses "=" as delimiter and "#" as comment
// marker. Marked inline because the definition lives in a header.
inline Config::Config()
    : m_Delimiter(std::string(1, '=')), m_Comment(std::string(1, '#')) {}
|
||||||
|
|
||||||
|
|
||||||
|
bool Config::KeyExists(const string& key) const {
|
||||||
|
// Indicate whether key is found
|
||||||
|
mapci p = m_Contents.find(key);
|
||||||
|
return (p != m_Contents.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* static */
|
||||||
|
// Remove leading and trailing whitespace from *inout_s in place.
// A string made only of whitespace becomes empty.
// Marked inline because the definition lives in a header.
inline void Config::Trim(std::string* inout_s) {
    static const char whitespace[] = " \n\t\v\r\f";
    // find_first_not_of() yields npos for an all-whitespace string, which
    // makes erase() drop everything.
    inout_s->erase(0, inout_s->find_first_not_of(whitespace));
    // npos + 1 wraps to 0, so an empty string is handled as well.
    inout_s->erase(inout_s->find_last_not_of(whitespace) + 1U);
}
|
||||||
|
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, const Config& cf) {
|
||||||
|
// Save a Config to os
|
||||||
|
for (Config::mapci p = cf.m_Contents.begin(); p != cf.m_Contents.end();
|
||||||
|
++p) {
|
||||||
|
os << p->first << " " << cf.m_Delimiter << " ";
|
||||||
|
os << p->second << std::endl;
|
||||||
|
}
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove key and its value; removing a key that is not present is a no-op.
// Marked inline because the definition lives in a header.
inline void Config::Remove(const std::string& key) {
    // Erase by key rather than by iterator: erase(find(key)) would pass
    // end() to erase() for a missing key, which is undefined behaviour.
    m_Contents.erase(key);
}
|
||||||
|
|
||||||
|
std::istream& operator>>(std::istream& is, Config& cf) {
|
||||||
|
// Load a Config from is
|
||||||
|
// Read in keys and values, keeping internal whitespace
|
||||||
|
typedef string::size_type pos;
|
||||||
|
const string& delim = cf.m_Delimiter; // separator
|
||||||
|
const string& comm = cf.m_Comment; // comment
|
||||||
|
const pos skip = delim.length(); // length of separator
|
||||||
|
|
||||||
|
string nextline = ""; // might need to read ahead to see where value ends
|
||||||
|
|
||||||
|
while (is || nextline.length() > 0) {
|
||||||
|
// Read an entire line at a time
|
||||||
|
string line;
|
||||||
|
if (nextline.length() > 0) {
|
||||||
|
line = nextline; // we read ahead; use it now
|
||||||
|
nextline = "";
|
||||||
|
} else {
|
||||||
|
std::getline(is, line);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ignore comments
|
||||||
|
line = line.substr(0, line.find(comm));
|
||||||
|
|
||||||
|
// Parse the line if it contains a delimiter
|
||||||
|
pos delimPos = line.find(delim);
|
||||||
|
if (delimPos < string::npos) {
|
||||||
|
// Extract the key
|
||||||
|
string key = line.substr(0, delimPos);
|
||||||
|
line.replace(0, delimPos + skip, "");
|
||||||
|
|
||||||
|
// See if value continues on the next line
|
||||||
|
// Stop at blank line, next line with a key, end of stream,
|
||||||
|
// or end of file sentry
|
||||||
|
bool terminate = false;
|
||||||
|
while (!terminate && is) {
|
||||||
|
std::getline(is, nextline);
|
||||||
|
terminate = true;
|
||||||
|
|
||||||
|
string nlcopy = nextline;
|
||||||
|
Config::Trim(&nlcopy);
|
||||||
|
if (nlcopy == "") continue;
|
||||||
|
|
||||||
|
nextline = nextline.substr(0, nextline.find(comm));
|
||||||
|
if (nextline.find(delim) != string::npos) continue;
|
||||||
|
|
||||||
|
nlcopy = nextline;
|
||||||
|
Config::Trim(&nlcopy);
|
||||||
|
if (nlcopy != "") line += "\n";
|
||||||
|
line += nextline;
|
||||||
|
terminate = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store key and value
|
||||||
|
Config::Trim(&key);
|
||||||
|
Config::Trim(&line);
|
||||||
|
cf.m_Contents[key] = line; // overwrites if key is repeated
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return is;
|
||||||
|
}
|
||||||
|
// Report whether filename can be opened for reading.
// Marked inline because the definition lives in a header.
inline bool Config::FileExist(std::string filename) {
    std::ifstream in(filename.c_str());
    return static_cast<bool>(in);
}
|
||||||
|
|
||||||
|
void Config::ReadFile(string filename, string delimiter, string comment) {
|
||||||
|
m_Delimiter = delimiter;
|
||||||
|
m_Comment = comment;
|
||||||
|
std::ifstream in(filename.c_str());
|
||||||
|
|
||||||
|
if (!in) throw File_not_found(filename);
|
||||||
|
|
||||||
|
in >> (*this);
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
#pragma endregion ParseIniFIle
|
||||||
|
#endif
|
@ -0,0 +1,12 @@
|
|||||||
|
|
||||||
|
#include "base/glog_utils.h"
|
||||||
|
|
||||||
|
namespace google {
|
||||||
|
void InitGoogleLogging(const char* name) {
|
||||||
|
LOG(INFO) << "dummpy InitGoogleLogging.";
|
||||||
|
}
|
||||||
|
|
||||||
|
void InstallFailureSignalHandler() {
|
||||||
|
LOG(INFO) << "dummpy InstallFailureSignalHandler.";
|
||||||
|
}
|
||||||
|
} // namespace google
|
@ -0,0 +1,9 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "base/common.h"
|
||||||
|
|
||||||
|
namespace google {

// Declarations of the placeholder logging entry points; base/log.h includes
// this header only when WITH_GLOG is not defined.

// Accepts the program name; see the definition for what it does with it.
void InitGoogleLogging(const char* name);

// Counterpart of glog's failure-signal-handler installer.
void InstallFailureSignalHandler();

}  // namespace google
|
@ -0,0 +1,105 @@
|
|||||||
|
#include "base/log.h"
|
||||||
|
|
||||||
|
DEFINE_int32(logtostderr, 0, "logging to stderr");
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
static char __progname[] = "paddlespeech";
|
||||||
|
|
||||||
|
namespace log {
|
||||||
|
|
||||||
|
// Storage for LogMessage's static members.
std::mutex LogMessage::lock_;

// Per-severity log file names; empty until LogMessage::init() fills them in.
std::string LogMessage::s_debug_logfile_;
std::string LogMessage::s_info_logfile_;
std::string LogMessage::s_warning_logfile_;
std::string LogMessage::s_error_logfile_;
std::string LogMessage::s_fatal_logfile_;
|
||||||
|
|
||||||
|
// Write the current process id (as decimal text) to *pid and the fixed
// program name to *proc_name.
void LogMessage::get_curr_proc_info(std::string* pid, std::string* proc_name) {
    std::ostringstream pid_text;
    pid_text << getpid();
    *pid = pid_text.str();
    *proc_name = ::ppspeech::__progname;
}
|
||||||
|
|
||||||
|
// Start a log message.
//
// file / line  source location, written only when logging to a file
// level        severity; FATAL makes the destructor abort when verbose
// verbose      when false, everything streamed into the message is discarded
// out_to_file  request file output; honoured only when FLAGS_logtostderr is
//              neither 0 nor 1
//
// Destination: FLAGS_logtostderr == 0 -> std::cout, == 1 -> std::cerr,
// any other value -> per-severity log file if out_to_file, else std::cerr.
LogMessage::LogMessage(const char* file,
                       int line,
                       Severity level,
                       bool verbose,
                       bool out_to_file /* = false */)
    : stream_(&std::cerr),  // safe default so stream_ is never left dangling
      null_stream_(nullptr),
      level_(level),
      verbose_(verbose),
      // Only set to true below, together with taking lock_, so that the
      // destructor's unlock always pairs with a lock taken here.
      out_to_file_(false) {
    if (FLAGS_logtostderr == 0) {
        stream_ = static_cast<std::ostream*>(&std::cout);
    } else if (FLAGS_logtostderr == 1) {
        stream_ = static_cast<std::ostream*>(&std::cerr);
    } else if (out_to_file) {
        // logfile: the lock is held until the destructor runs
        out_to_file_ = true;
        lock_.lock();
        init(file, line);
    }
}
|
||||||
|
|
||||||
|
// Finish the message: terminate and flush the line, release the file lock if
// this message holds it, and abort the process for a verbose FATAL message.
LogMessage::~LogMessage() {
    stream() << std::endl;

    if (out_to_file_) lock_.unlock();

    const bool is_fatal = verbose_ && level_ == FATAL;
    if (is_fatal) std::abort();
}
|
||||||
|
|
||||||
|
// Return a per-thread stream that discards everything written to it: an
// unopened ofstream whose badbit is set once, on first use in each thread.
std::ostream* LogMessage::nullstream() {
    thread_local static std::ofstream sink;
    thread_local static bool prepared = false;

    if (prepared) return &sink;

    sink.setstate(std::ios_base::badbit);
    prepared = true;
    return &sink;
}
|
||||||
|
|
||||||
|
void LogMessage::init(const char* file, int line) {
|
||||||
|
time_t t = time(0);
|
||||||
|
char tmp[100];
|
||||||
|
strftime(tmp, sizeof(tmp), "%Y%m%d-%H%M%S", localtime(&t));
|
||||||
|
|
||||||
|
if (s_info_logfile_.empty()) {
|
||||||
|
std::string pid;
|
||||||
|
std::string proc_name;
|
||||||
|
get_curr_proc_info(&pid, &proc_name);
|
||||||
|
|
||||||
|
s_debug_logfile_ =
|
||||||
|
std::string("log." + proc_name + ".log.DEBUG." + tmp + "." + pid);
|
||||||
|
s_info_logfile_ =
|
||||||
|
std::string("log." + proc_name + ".log.INFO." + tmp + "." + pid);
|
||||||
|
s_warning_logfile_ =
|
||||||
|
std::string("log." + proc_name + ".log.WARNING." + tmp + "." + pid);
|
||||||
|
s_error_logfile_ =
|
||||||
|
std::string("log." + proc_name + ".log.ERROR." + tmp + "." + pid);
|
||||||
|
s_fatal_logfile_ =
|
||||||
|
std::string("log." + proc_name + ".log.FATAL." + tmp + "." + pid);
|
||||||
|
}
|
||||||
|
|
||||||
|
thread_local static std::ofstream ofs;
|
||||||
|
if (level_ == DEBUG) {
|
||||||
|
ofs.open(s_debug_logfile_.c_str(), std::ios::out | std::ios::app);
|
||||||
|
} else if (level_ == INFO) {
|
||||||
|
ofs.open(s_info_logfile_.c_str(), std::ios::out | std::ios::app);
|
||||||
|
} else if (level_ == WARNING) {
|
||||||
|
ofs.open(s_warning_logfile_.c_str(), std::ios::out | std::ios::app);
|
||||||
|
} else if (level_ == ERROR) {
|
||||||
|
ofs.open(s_error_logfile_.c_str(), std::ios::out | std::ios::app);
|
||||||
|
} else {
|
||||||
|
ofs.open(s_fatal_logfile_.c_str(), std::ios::out | std::ios::app);
|
||||||
|
}
|
||||||
|
|
||||||
|
stream_ = &ofs;
|
||||||
|
|
||||||
|
stream() << tmp << " " << file << " line " << line << "; ";
|
||||||
|
stream() << std::flush;
|
||||||
|
}
|
||||||
|
} // namespace log
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,173 @@
|
|||||||
|
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// modified from https://github.com/Dounm/dlog
|
||||||
|
// modified from
|
||||||
|
// https://android.googlesource.com/platform/art/+/806defa/src/logging.h
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <mutex>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#include "base/common.h"
|
||||||
|
#include "base/macros.h"
|
||||||
|
#ifndef WITH_GLOG
|
||||||
|
#include "base/glog_utils.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
DECLARE_int32(logtostderr);
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
namespace log {
|
||||||
|
|
||||||
|
// Log severities in increasing order of importance. NUM_SEVERITIES is the
// number of real severities, not a severity itself.
enum Severity {
    DEBUG = 0,
    INFO = 1,
    WARNING = 2,
    ERROR = 3,
    FATAL = 4,
    NUM_SEVERITIES = 5,
};
|
||||||
|
|
||||||
|
class LogMessage {
|
||||||
|
public:
|
||||||
|
static void get_curr_proc_info(std::string* pid, std::string* proc_name);
|
||||||
|
|
||||||
|
LogMessage(const char* file,
|
||||||
|
int line,
|
||||||
|
Severity level,
|
||||||
|
bool verbose,
|
||||||
|
bool out_to_file = false);
|
||||||
|
|
||||||
|
~LogMessage();
|
||||||
|
|
||||||
|
std::ostream& stream() { return verbose_ ? *stream_ : *nullstream(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void init(const char* file, int line);
|
||||||
|
std::ostream* nullstream();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::ostream* stream_;
|
||||||
|
std::ostream* null_stream_;
|
||||||
|
Severity level_;
|
||||||
|
bool verbose_;
|
||||||
|
bool out_to_file_;
|
||||||
|
|
||||||
|
static std::mutex lock_; // stream write lock
|
||||||
|
static std::string s_debug_logfile_;
|
||||||
|
static std::string s_info_logfile_;
|
||||||
|
static std::string s_warning_logfile_;
|
||||||
|
static std::string s_error_logfile_;
|
||||||
|
static std::string s_fatal_logfile_;
|
||||||
|
|
||||||
|
DISALLOW_COPY_AND_ASSIGN(LogMessage);
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace log
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
||||||
|
|
||||||
|
|
||||||
|
// Debug-only log statements: the message is emitted when PPS_DEBUG is
// defined and discarded otherwise.
#ifdef PPS_DEBUG
#define PPS_DLOG_VERBOSE_ true
#else
#define PPS_DLOG_VERBOSE_ false
#endif

#define DLOG_INFO                                                        \
    ppspeech::log::LogMessage(                                           \
        __FILE__, __LINE__, ppspeech::log::INFO, PPS_DLOG_VERBOSE_)
#define DLOG_WARNING                                                     \
    ppspeech::log::LogMessage(                                           \
        __FILE__, __LINE__, ppspeech::log::WARNING, PPS_DLOG_VERBOSE_)
#define DLOG_ERROR                                                       \
    ppspeech::log::LogMessage(                                           \
        __FILE__, __LINE__, ppspeech::log::ERROR, PPS_DLOG_VERBOSE_)
#define DLOG_FATAL                                                       \
    ppspeech::log::LogMessage(                                           \
        __FILE__, __LINE__, ppspeech::log::FATAL, PPS_DLOG_VERBOSE_)
|
||||||
|
|
||||||
|
|
||||||
|
// Always-on log statements, one macro per severity. Each expands to a
// temporary LogMessage whose destructor finishes the line.
// LOG_DEBUG was missing although LOG_0 expands to it, so LOG(0) / VLOG(0)
// could not compile.
#define LOG_DEBUG \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::DEBUG, true)
#define LOG_INFO \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, true)
#define LOG_WARNING \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, true)
#define LOG_ERROR \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, true)
#define LOG_FATAL \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, true)

// Numeric aliases so that LOG(n) / VLOG(n) work for n in 0..4.
// Note that 4 maps to FATAL, which aborts the process.
#define LOG_0 LOG_DEBUG
#define LOG_1 LOG_INFO
#define LOG_2 LOG_WARNING
#define LOG_3 LOG_ERROR
#define LOG_4 LOG_FATAL

// LOG(INFO) << "text";   level is a severity name or a digit 0..4
#define LOG(level) LOG_##level.stream()

// Same as LOG, but built on the DLOG_<level> macros.
#define DLOG(level) DLOG_##level.stream()

// VLOG(n) is LOG(n): the verbosity level selects the severity.
#define VLOG(verboselevel) LOG(verboselevel)
|
||||||
|
|
||||||
|
// CHECK(exp) creates a FATAL message that is verbose only when exp is false.
// On failure the text "Check Failed: <exp>" plus anything streamed after it
// is emitted and the process aborts; on success the message is discarded.
#define CHECK(exp)                                                        \
    ppspeech::log::LogMessage(                                            \
        __FILE__, __LINE__, ppspeech::log::FATAL, !(exp))                 \
            .stream()                                                     \
        << "Check Failed: " #exp

// Comparison forms of CHECK.
#define CHECK_EQ(x, y) CHECK((x) == (y))
#define CHECK_NE(x, y) CHECK((x) != (y))
#define CHECK_LE(x, y) CHECK((x) <= (y))
#define CHECK_LT(x, y) CHECK((x) < (y))
#define CHECK_GE(x, y) CHECK((x) >= (y))
#define CHECK_GT(x, y) CHECK((x) > (y))
|
||||||
|
// Debug-only checks: identical to CHECK* when PPS_DEBUG is defined; otherwise
// the check is still compiled but never executed (`while (false)`), so the
// operands are type-checked but not evaluated.
//
// DCHECK_STREQ used to exist only in the release branch and expanded to
// CHECK_STREQ, which is not defined in this header. It is now defined in
// both branches directly on top of CHECK. Both arguments must be non-null
// C strings or std::string values.
#ifdef PPS_DEBUG
#define DCHECK(x) CHECK(x)
#define DCHECK_EQ(x, y) CHECK_EQ(x, y)
#define DCHECK_NE(x, y) CHECK_NE(x, y)
#define DCHECK_LE(x, y) CHECK_LE(x, y)
#define DCHECK_LT(x, y) CHECK_LT(x, y)
#define DCHECK_GE(x, y) CHECK_GE(x, y)
#define DCHECK_GT(x, y) CHECK_GT(x, y)
#define DCHECK_STREQ(str1, str2) \
    CHECK(std::string(str1) == std::string(str2))
#else
#define DCHECK(condition) \
    while (false) CHECK(condition)
#define DCHECK_EQ(val1, val2) \
    while (false) CHECK_EQ(val1, val2)
#define DCHECK_NE(val1, val2) \
    while (false) CHECK_NE(val1, val2)
#define DCHECK_LE(val1, val2) \
    while (false) CHECK_LE(val1, val2)
#define DCHECK_LT(val1, val2) \
    while (false) CHECK_LT(val1, val2)
#define DCHECK_GE(val1, val2) \
    while (false) CHECK_GE(val1, val2)
#define DCHECK_GT(val1, val2) \
    while (false) CHECK_GT(val1, val2)
#define DCHECK_STREQ(str1, str2) \
    while (false) CHECK(std::string(str1) == std::string(str2))
#endif
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue