commit 7cab869d63
@@ -0,0 +1,7 @@
|
||||
engine/common/base/flags.h
|
||||
engine/common/base/log.h
|
||||
|
||||
tools/valgrind*
|
||||
*log
|
||||
fc_patch/*
|
||||
test
|
@@ -0,0 +1,211 @@
|
||||
# >=3.17 support -DCMAKE_FIND_DEBUG_MODE=ON
|
||||
cmake_minimum_required(VERSION 3.17 FATAL_ERROR)
|
||||
|
||||
set(CMAKE_PROJECT_INCLUDE_BEFORE "${CMAKE_CURRENT_SOURCE_DIR}/cmake/EnableCMP0077.cmake")
|
||||
|
||||
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
|
||||
|
||||
include(system)
|
||||
|
||||
project(paddlespeech VERSION 0.1)
|
||||
|
||||
set(PPS_VERSION_MAJOR 1)
|
||||
set(PPS_VERSION_MINOR 0)
|
||||
set(PPS_VERSION_PATCH 0)
|
||||
set(PPS_VERSION "${PPS_VERSION_MAJOR}.${PPS_VERSION_MINOR}.${PPS_VERSION_PATCH}")
|
||||
|
||||
# compiler option
|
||||
# Keep consistent with openfst: -fPIC or -fpic
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ldl")
|
||||
SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb")
|
||||
SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall")
|
||||
|
||||
set(CMAKE_VERBOSE_MAKEFILE ON)
|
||||
set(CMAKE_FIND_DEBUG_MODE OFF)
|
||||
set(PPS_CXX_STANDARD 14)
|
||||
|
||||
# set std-14
|
||||
set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD})
|
||||
|
||||
# Ninja Generator will set CMAKE_BUILD_TYPE to Debug
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" FORCE)
|
||||
endif()
|
||||
|
||||
# make find_* (e.g. find_library) work when cross-compiling
|
||||
if(ANDROID)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM BOTH)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
|
||||
endif()
|
||||
|
||||
if(BUILD_IN_MACOS)
|
||||
add_definitions("-DOS_MACOSX")
|
||||
endif()
|
||||
|
||||
# install dir into `build/install`
|
||||
set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/install)
|
||||
|
||||
include(FetchContent)
|
||||
include(ExternalProject)
|
||||
|
||||
# fc_patch dir
|
||||
set(FETCHCONTENT_QUIET off)
|
||||
get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
|
||||
set(FETCHCONTENT_BASE_DIR ${fc_patch})
|
||||
|
||||
###############################################################################
|
||||
# Option Configurations
|
||||
###############################################################################
|
||||
# https://github.com/google/brotli/pull/655
|
||||
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
|
||||
|
||||
option(WITH_PPS_DEBUG "debug option" OFF)
|
||||
if (WITH_PPS_DEBUG)
|
||||
add_definitions("-DPPS_DEBUG")
|
||||
endif()
|
||||
|
||||
option(WITH_ASR "build asr" ON)
|
||||
option(WITH_CLS "build cls" ON)
|
||||
option(WITH_VAD "build vad" ON)
|
||||
|
||||
option(WITH_GPU "NNet using GPU." OFF)
|
||||
|
||||
option(WITH_PROFILING "enable c++ profiling" OFF)
|
||||
option(WITH_TESTING "unit test" ON)
|
||||
|
||||
option(WITH_ONNX "build u2 with ONNX Runtime support" OFF)
|
||||
|
||||
###############################################################################
|
||||
# Include Third Party
|
||||
###############################################################################
|
||||
include(gflags)
|
||||
|
||||
include(glog)
|
||||
|
||||
include(pybind)
|
||||
|
||||
#onnx
|
||||
if(WITH_ONNX)
|
||||
add_definitions(-DUSE_ONNX)
|
||||
endif()
|
||||
|
||||
# gtest
|
||||
if(WITH_TESTING)
|
||||
include(gtest) # download, build, install gtest
|
||||
endif()
|
||||
|
||||
# fastdeploy
|
||||
include(fastdeploy)
|
||||
|
||||
if(WITH_ASR)
|
||||
# openfst
|
||||
include(openfst)
|
||||
add_dependencies(openfst gflags extern_glog)
|
||||
endif()
|
||||
|
||||
###############################################################################
|
||||
# Find Package
|
||||
###############################################################################
|
||||
# https://github.com/Kitware/CMake/blob/v3.1.0/Modules/FindThreads.cmake#L207
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
if(WITH_ASR)
|
||||
# https://cmake.org/cmake/help/latest/module/FindPython3.html#module:FindPython3
|
||||
find_package(Python3 COMPONENTS Interpreter Development)
|
||||
find_package(pybind11 CONFIG)
|
||||
|
||||
if(Python3_FOUND)
|
||||
message(STATUS "Python3_FOUND = ${Python3_FOUND}")
|
||||
message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}")
|
||||
message(STATUS "Python3_LIBRARIES = ${Python3_LIBRARIES}")
|
||||
message(STATUS "Python3_INCLUDE_DIRS = ${Python3_INCLUDE_DIRS}")
|
||||
message(STATUS "Python3_LINK_OPTIONS = ${Python3_LINK_OPTIONS}")
|
||||
set(PYTHON_LIBRARIES ${Python3_LIBRARIES} CACHE STRING "python lib" FORCE)
|
||||
set(PYTHON_INCLUDE_DIR ${Python3_INCLUDE_DIRS} CACHE STRING "python inc" FORCE)
|
||||
endif()
|
||||
|
||||
message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
|
||||
message(STATUS "PYTHON_INCLUDE_DIR = ${PYTHON_INCLUDE_DIR}")
|
||||
include_directories(${PYTHON_INCLUDE_DIR})
|
||||
|
||||
if(pybind11_FOUND)
|
||||
message(STATUS "pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}")
|
||||
message(STATUS "pybind11_LIBRARIES=${pybind11_LIBRARIES}")
|
||||
message(STATUS "pybind11_DEFINITIONS=${pybind11_DEFINITIONS}")
|
||||
endif()
|
||||
|
||||
|
||||
# paddle libpaddle.so
|
||||
# paddle include and link option
|
||||
# -L/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so
|
||||
set(EXECUTE_COMMAND "import os"
|
||||
"import paddle"
|
||||
"include_dir = paddle.sysconfig.get_include()"
|
||||
"paddle_dir=os.path.split(include_dir)[0]"
|
||||
"libs_dir=os.path.join(paddle_dir, 'libs')"
|
||||
"fluid_dir=os.path.join(paddle_dir, 'fluid')"
|
||||
"out=' '.join([\"-L\" + libs_dir, \"-L\" + fluid_dir])"
|
||||
"out += \" -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so\"; print(out)"
|
||||
)
|
||||
execute_process(
|
||||
COMMAND python -c "${EXECUTE_COMMAND}"
|
||||
OUTPUT_VARIABLE PADDLE_LINK_FLAGS
|
||||
RESULT_VARIABLE SUCCESS)
|
||||
|
||||
message(STATUS "PADDLE_LINK_FLAGS = ${PADDLE_LINK_FLAGS}")
|
||||
string(STRIP ${PADDLE_LINK_FLAGS} PADDLE_LINK_FLAGS)
|
||||
|
||||
# paddle compile option
|
||||
# -I/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/include
|
||||
set(EXECUTE_COMMAND "import paddle"
|
||||
"include_dir = paddle.sysconfig.get_include()"
|
||||
"print(f\"-I{include_dir}\")"
|
||||
)
|
||||
execute_process(
|
||||
COMMAND python -c "${EXECUTE_COMMAND}"
|
||||
OUTPUT_VARIABLE PADDLE_COMPILE_FLAGS)
|
||||
message(STATUS "PADDLE_COMPILE_FLAGS = ${PADDLE_COMPILE_FLAGS}")
|
||||
string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS)
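# (Illustrative note) these two variables are consumed later in this commit,
# e.g. engine/asr/decoder/CMakeLists.txt does roughly:
#   target_compile_options(<bin> PRIVATE ${PADDLE_COMPILE_FLAGS})
#   target_link_libraries(<bin> ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl)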
|
||||
|
||||
# for LD_LIBRARY_PATH
|
||||
# set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/)
|
||||
set(EXECUTE_COMMAND "import os"
|
||||
"import paddle"
|
||||
"include_dir=paddle.sysconfig.get_include()"
|
||||
"paddle_dir=os.path.split(include_dir)[0]"
|
||||
"libs_dir=os.path.join(paddle_dir, 'libs')"
|
||||
"fluid_dir=os.path.join(paddle_dir, 'fluid')"
|
||||
"out=':'.join([libs_dir, fluid_dir]); print(out)"
|
||||
)
|
||||
execute_process(
|
||||
COMMAND python -c "${EXECUTE_COMMAND}"
|
||||
OUTPUT_VARIABLE PADDLE_LIB_DIRS)
|
||||
message(STATUS "PADDLE_LIB_DIRS = ${PADDLE_LIB_DIRS}")
|
||||
endif()
|
||||
|
||||
include(summary)
|
||||
|
||||
###############################################################################
|
||||
# Add local library
|
||||
###############################################################################
|
||||
set(ENGINE_ROOT ${CMAKE_SOURCE_DIR}/engine)
|
||||
|
||||
add_subdirectory(engine)
|
||||
|
||||
|
||||
###############################################################################
|
||||
# CPack library
|
||||
###############################################################################
|
||||
# build a CPack driven installer package
|
||||
include (InstallRequiredSystemLibraries)
|
||||
set(CPACK_PACKAGE_NAME "paddlespeech_library")
|
||||
set(CPACK_PACKAGE_VENDOR "paddlespeech")
|
||||
set(CPACK_PACKAGE_VERSION_MAJOR 1)
|
||||
set(CPACK_PACKAGE_VERSION_MINOR 0)
|
||||
set(CPACK_PACKAGE_VERSION_PATCH 0)
|
||||
set(CPACK_PACKAGE_DESCRIPTION "paddlespeech library")
|
||||
set(CPACK_PACKAGE_CONTACT "paddlespeech@baidu.com")
|
||||
set(CPACK_SOURCE_GENERATOR "TGZ")
|
||||
include (CPack)
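# Note (illustrative): after building, run `cpack` (or build the generated
# `package` target) inside the build tree to produce the installer archive.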
|
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env bash
|
||||
set -xe
|
||||
|
||||
BUILD_ROOT=build/Linux
|
||||
BUILD_DIR=${BUILD_ROOT}/x86_64
|
||||
|
||||
mkdir -p ${BUILD_DIR}
|
||||
|
||||
BUILD_TYPE=Release
|
||||
#BUILD_TYPE=Debug
|
||||
BUILD_SO=OFF
|
||||
BUILD_ONNX=ON
|
||||
BUILD_ASR=ON
|
||||
BUILD_CLS=ON
|
||||
BUILD_VAD=ON
|
||||
PPS_DEBUG=OFF
|
||||
FASTDEPLOY_INSTALL_DIR=""
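# Left empty on purpose: cmake/fastdeploy.cmake then falls back to FetchContent
# and downloads a prebuilt FastDeploy package; set this to a local FastDeploy
# install dir to skip the download.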
|
||||
|
||||
# This build script has been verified in the PaddlePaddle docker image.
|
||||
# Please follow the instructions below to install the PaddlePaddle docker image.
|
||||
# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html
|
||||
#cmake -B build -DBUILD_SHARED_LIBS=OFF -DWITH_ASR=OFF -DWITH_CLS=OFF -DWITH_VAD=ON -DFASTDEPLOY_INSTALL_DIR=/workspace/zhanghui/paddle/FastDeploy/build/Android/arm64-v8a-api-21/install
|
||||
cmake -B ${BUILD_DIR} \
|
||||
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
|
||||
-DBUILD_SHARED_LIBS=${BUILD_SO} \
|
||||
-DWITH_ONNX=${BUILD_ONNX} \
|
||||
-DWITH_ASR=${BUILD_ASR} \
|
||||
-DWITH_CLS=${BUILD_CLS} \
|
||||
-DWITH_VAD=${BUILD_VAD} \
|
||||
-DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \
|
||||
-DWITH_PPS_DEBUG=${PPS_DEBUG}
|
||||
|
||||
cmake --build ${BUILD_DIR} -j
|
@@ -0,0 +1,39 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -ex
|
||||
|
||||
ANDROID_NDK=/mnt/masimeng/workspace/software/android-ndk-r25b/
|
||||
|
||||
# Set up the Android toolchain
|
||||
ANDROID_ABI=arm64-v8a # 'arm64-v8a', 'armeabi-v7a'
|
||||
ANDROID_PLATFORM="android-21" # API >= 21
|
||||
ANDROID_STL=c++_shared # 'c++_shared', 'c++_static'
|
||||
ANDROID_TOOLCHAIN=clang # 'clang' only
|
||||
TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
|
||||
|
||||
# Create build directory
|
||||
BUILD_ROOT=build/Android
|
||||
BUILD_DIR=${BUILD_ROOT}/${ANDROID_ABI}-api-21
|
||||
FASTDEPLOY_INSTALL_DIR="/mnt/masimeng/workspace/FastDeploy/build/Android/arm64-v8a-api-21/install"
|
||||
|
||||
mkdir -p ${BUILD_DIR}
|
||||
cd ${BUILD_DIR}
|
||||
|
||||
# CMake configuration with Android toolchain
|
||||
cmake -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} \
|
||||
-DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||
-DANDROID_ABI=${ANDROID_ABI} \
|
||||
-DANDROID_NDK=${ANDROID_NDK} \
|
||||
-DANDROID_PLATFORM=${ANDROID_PLATFORM} \
|
||||
-DANDROID_STL=${ANDROID_STL} \
|
||||
-DANDROID_TOOLCHAIN=${ANDROID_TOOLCHAIN} \
|
||||
-DBUILD_SHARED_LIBS=OFF \
|
||||
-DWITH_ASR=OFF \
|
||||
-DWITH_CLS=OFF \
|
||||
-DWITH_VAD=ON \
|
||||
-DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \
|
||||
-DCMAKE_FIND_DEBUG_MODE=OFF \
|
||||
-Wno-dev ../../..
|
||||
|
||||
# Build the Android C++ libraries
|
||||
make
|
@@ -0,0 +1,91 @@
|
||||
# https://www.jianshu.com/p/33672fb819f5
|
||||
|
||||
PATH="/Applications/CMake.app/Contents/bin":"$PATH"
|
||||
tools_dir=$1
|
||||
ios_toolchain_cmake=${tools_dir}/"/ios-cmake-4.2.0/ios.toolchain.cmake"
|
||||
fastdeploy_dir=${tools_dir}"/fastdeploy-ort-mac-build/"
|
||||
build_targets=("OS64")
|
||||
build_type_array=("Release")
|
||||
|
||||
#static_name="libocr"
|
||||
#lib_name="libocr"
|
||||
|
||||
# Switch to the work path
|
||||
current_path=`cd $(dirname $0);pwd`
|
||||
work_path=${current_path}/
|
||||
build_path=${current_path}/build/
|
||||
output_path=${current_path}/output/
|
||||
cd ${work_path}
|
||||
|
||||
# Clean
|
||||
rm -rf ${build_path}
|
||||
rm -rf ${output_path}
|
||||
|
||||
if [ "$1"x = "clean"x ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Build Every Target
|
||||
for target in "${build_targets[@]}"
|
||||
do
|
||||
for build_type in "${build_type_array[@]}"
|
||||
do
|
||||
echo -e "\033[1;36;40mBuilding ${build_type} ${target} ... \033[0m"
|
||||
target_build_path=${build_path}/${target}/${build_type}/
|
||||
mkdir -p ${target_build_path}
|
||||
|
||||
cd ${target_build_path}
|
||||
if [ $? -ne 0 ];then
|
||||
echo -e "\033[1;31;40mcd ${target_build_path} failed \033[0m"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
if [ ${target} == "OS64" ];then
|
||||
fastdeploy_install_dir=${fastdeploy_dir}/arm64
|
||||
else
|
||||
fastdeploy_install_dir=""
|
||||
echo "fastdeploy_install_dir is null"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
cmake -DCMAKE_TOOLCHAIN_FILE=${ios_toolchain_cmake} \
|
||||
-DBUILD_IN_MACOS=ON \
|
||||
-DBUILD_SHARED_LIBS=OFF \
|
||||
-DWITH_ASR=OFF \
|
||||
-DWITH_CLS=OFF \
|
||||
-DWITH_VAD=ON \
|
||||
-DFASTDEPLOY_INSTALL_DIR=${fastdeploy_install_dir} \
|
||||
-DPLATFORM=${target} ../../../
|
||||
|
||||
cmake --build . --config ${build_type}
|
||||
|
||||
mkdir output
|
||||
cp engine/vad/interface/libpps_vad_interface.a output
|
||||
cp engine/vad/interface/vad_interface_main.app/vad_interface_main output
|
||||
cp ${fastdeploy_install_dir}/lib/libfastdeploy.dylib output
|
||||
cp ${fastdeploy_install_dir}/third_libs/install/onnxruntime/lib/libonnxruntime.dylib output
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
## combine all ios libraries
|
||||
#DEVROOT=/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/
|
||||
#LIPO_TOOL=${DEVROOT}/usr/bin/lipo
|
||||
#LIBRARY_PATH=${build_path}
|
||||
#LIBRARY_OUTPUT_PATH=${output_path}/IOS
|
||||
#mkdir -p ${LIBRARY_OUTPUT_PATH}
|
||||
#
|
||||
#${LIPO_TOOL} \
|
||||
# -arch i386 ${LIBRARY_PATH}/ios_x86/Release/${lib_name}.a \
|
||||
# -arch x86_64 ${LIBRARY_PATH}/ios_x86_64/Release/${lib_name}.a \
|
||||
# -arch armv7 ${LIBRARY_PATH}/ios_armv7/Release/${lib_name}.a \
|
||||
# -arch armv7s ${LIBRARY_PATH}/ios_armv7s/Release/${lib_name}.a \
|
||||
# -arch arm64 ${LIBRARY_PATH}/ios_armv8/Release/${lib_name}.a \
|
||||
# -output ${LIBRARY_OUTPUT_PATH}/${lib_name}.a -create
|
||||
#
|
||||
#cp ${work_path}/lib/houyi/lib/ios/libhouyi_score.a ${LIBRARY_OUTPUT_PATH}/
|
||||
#cp ${work_path}/interface/ocr-interface.h ${output_path}
|
||||
#cp ${work_path}/version/release.v ${output_path}
|
||||
#
|
||||
#echo -e "\033[1;36;40mBuild All Target Success At:\n${output_path}\033[0m"
|
||||
#exit 0
|
@@ -0,0 +1 @@
|
||||
cmake_policy(SET CMP0077 NEW)
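# Note: with CMP0077 set to NEW, option() honors a normal variable that was
# already set, so a parent scope can pre-set e.g. `set(BUILD_TESTING OFF)`
# before a sub-project declares option(BUILD_TESTING ...) (illustrative example).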
|
@@ -0,0 +1,116 @@
|
||||
include(FetchContent)
|
||||
|
||||
set(EXTERNAL_PROJECT_LOG_ARGS
|
||||
LOG_DOWNLOAD 1 # Wrap download in script to log output
|
||||
LOG_UPDATE 1 # Wrap update in script to log output
|
||||
LOG_PATCH 1
|
||||
LOG_CONFIGURE 1 # Wrap configure in script to log output
|
||||
LOG_BUILD 1 # Wrap build in script to log output
|
||||
LOG_INSTALL 1
|
||||
LOG_TEST 1 # Wrap test in script to log output
|
||||
LOG_MERGED_STDOUTERR 1
|
||||
LOG_OUTPUT_ON_FAILURE 1
|
||||
)
|
||||
|
||||
if(NOT FASTDEPLOY_INSTALL_DIR)
|
||||
if(ANDROID)
|
||||
FetchContent_Declare(
|
||||
fastdeploy
|
||||
URL https://bj.bcebos.com/fastdeploy/release/android/fastdeploy-android-1.0.4-shared.tgz
|
||||
URL_HASH MD5=2a15301158e9eb157a4f11283689e7ba
|
||||
${EXTERNAL_PROJECT_LOG_ARGS}
|
||||
)
|
||||
add_definitions("-DUSE_PADDLE_LITE_BAKEND")
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -g0 -O3 -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE")
|
||||
else() # Linux
|
||||
FetchContent_Declare(
|
||||
fastdeploy
|
||||
URL https://paddlespeech.bj.bcebos.com/speechx/fastdeploy/fastdeploy-1.0.5-x86_64-onnx.tar.gz
|
||||
URL_HASH MD5=33900d986ea71aa78635e52f0733227c
|
||||
${EXTERNAL_PROJECT_LOG_ARGS}
|
||||
)
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -msse -msse2")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -msse -msse2 -mavx -O3")
|
||||
endif()
|
||||
|
||||
FetchContent_MakeAvailable(fastdeploy)
|
||||
|
||||
set(FASTDEPLOY_INSTALL_DIR ${fc_patch}/fastdeploy-src)
|
||||
endif()
|
||||
|
||||
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
|
||||
|
||||
# fix compiler flags conflict, since fastdeploy uses c++11 for its project
|
||||
# this line must come after `include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)`
|
||||
set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD})
|
||||
|
||||
include_directories(${FASTDEPLOY_INCS})
|
||||
|
||||
# install fastdeploy and its dependent libs
|
||||
# install_fastdeploy_libraries(${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
|
||||
# No dynamic libs need to be installed when using the
|
||||
# FastDeploy static lib.
|
||||
if(ANDROID AND WITH_ANDROID_STATIC_LIB)
|
||||
return()
|
||||
endif()
|
||||
|
||||
set(DYN_LIB_SUFFIX "*.so*")
|
||||
if(WIN32)
|
||||
set(DYN_LIB_SUFFIX "*.dll")
|
||||
elseif(APPLE)
|
||||
set(DYN_LIB_SUFFIX "*.dylib*")
|
||||
endif()
|
||||
|
||||
if(FastDeploy_DIR)
|
||||
set(DYN_SEARCH_DIR ${FastDeploy_DIR})
|
||||
elseif(FASTDEPLOY_INSTALL_DIR)
|
||||
set(DYN_SEARCH_DIR ${FASTDEPLOY_INSTALL_DIR})
|
||||
else()
|
||||
message(FATAL_ERROR "Please set FastDeploy_DIR/FASTDEPLOY_INSTALL_DIR before call install_fastdeploy_libraries.")
|
||||
endif()
|
||||
|
||||
file(GLOB_RECURSE ALL_NEED_DYN_LIBS ${DYN_SEARCH_DIR}/lib/${DYN_LIB_SUFFIX})
|
||||
file(GLOB_RECURSE ALL_DEPS_DYN_LIBS ${DYN_SEARCH_DIR}/third_libs/${DYN_LIB_SUFFIX})
|
||||
|
||||
if(ENABLE_VISION)
|
||||
# OpenCV
|
||||
if(ANDROID)
|
||||
file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${DYN_LIB_SUFFIX})
|
||||
else()
|
||||
file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_DIR}/../../${DYN_LIB_SUFFIX})
|
||||
endif()
|
||||
|
||||
list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_OPENCV_DYN_LIBS})
|
||||
|
||||
if(WIN32)
|
||||
file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/x64/vc15/bin/${DYN_LIB_SUFFIX})
|
||||
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
|
||||
elseif(ANDROID AND (NOT WITH_ANDROID_OPENCV_STATIC))
|
||||
file(GLOB OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${ANDROID_ABI}/${DYN_LIB_SUFFIX})
|
||||
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
|
||||
else() # linux/mac
|
||||
file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/lib/${DYN_LIB_SUFFIX})
|
||||
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
|
||||
endif()
|
||||
|
||||
# FlyCV
|
||||
if(ENABLE_FLYCV)
|
||||
file(GLOB_RECURSE ALL_FLYCV_DYN_LIBS ${FLYCV_LIB_DIR}/${DYN_LIB_SUFFIX})
|
||||
list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_FLYCV_DYN_LIBS})
|
||||
if(ANDROID AND (NOT WITH_ANDROID_FLYCV_STATIC))
|
||||
install(FILES ${ALL_FLYCV_DYN_LIBS} DESTINATION lib)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(ENABLE_OPENVINO_BACKEND)
|
||||
# need plugins.xml for openvino backend
|
||||
set(OPENVINO_RUNTIME_BIN_DIR ${OPENVINO_DIR}/bin)
|
||||
file(GLOB OPENVINO_PLUGIN_XML ${OPENVINO_RUNTIME_BIN_DIR}/*.xml)
|
||||
install(FILES ${OPENVINO_PLUGIN_XML} DESTINATION lib)
|
||||
endif()
|
||||
|
||||
# Install other libraries
|
||||
install(FILES ${ALL_NEED_DYN_LIBS} DESTINATION lib)
|
||||
install(FILES ${ALL_DEPS_DYN_LIBS} DESTINATION lib)
|
@@ -0,0 +1,14 @@
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
gflags
|
||||
URL https://paddleaudio.bj.bcebos.com/build/gflag-2.2.2.zip
|
||||
URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5
|
||||
)
|
||||
FetchContent_MakeAvailable(gflags)
|
||||
|
||||
# needed by openfst
|
||||
include_directories(${gflags_BINARY_DIR}/include)
|
||||
link_directories(${gflags_BINARY_DIR})
|
||||
|
||||
#install(FILES ${gflags_BINARY_DIR}/libgflags_nothreads.a DESTINATION lib)
|
@@ -0,0 +1,35 @@
|
||||
include(FetchContent)
|
||||
|
||||
if(ANDROID)
|
||||
else() # UNIX
|
||||
add_definitions(-DWITH_GLOG)
|
||||
FetchContent_Declare(
|
||||
glog
|
||||
URL https://paddleaudio.bj.bcebos.com/build/glog-0.4.0.zip
|
||||
URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
|
||||
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
|
||||
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
|
||||
-DCMAKE_CXX_FLAGS=${GLOG_CMAKE_CXX_FLAGS}
|
||||
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
|
||||
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
|
||||
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
|
||||
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
|
||||
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||
-DWITH_GFLAGS=OFF
|
||||
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
|
||||
${EXTERNAL_OPTIONAL_ARGS}
|
||||
)
|
||||
set(BUILD_TESTING OFF)
|
||||
FetchContent_MakeAvailable(glog)
|
||||
include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src)
|
||||
endif()
|
||||
|
||||
|
||||
if(ANDROID)
|
||||
add_library(extern_glog INTERFACE)
|
||||
add_dependencies(extern_glog gflags)
|
||||
else() # UNIX
|
||||
add_library(extern_glog ALIAS glog)
|
||||
add_dependencies(glog gflags)
|
||||
endif()
|
@@ -0,0 +1,27 @@
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
if(ANDROID)
|
||||
else() # UNIX
|
||||
FetchContent_Declare(
|
||||
gtest
|
||||
URL https://paddleaudio.bj.bcebos.com/build/gtest-release-1.11.0.zip
|
||||
URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a
|
||||
)
|
||||
FetchContent_MakeAvailable(gtest)
|
||||
|
||||
include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
if(ANDROID)
|
||||
add_library(extern_gtest INTERFACE)
|
||||
else() # UNIX
|
||||
add_dependencies(gtest gflags extern_glog)
|
||||
add_library(extern_gtest ALIAS gtest)
|
||||
endif()
|
||||
|
||||
if(WITH_TESTING)
|
||||
enable_testing()
|
||||
endif()
|
@@ -0,0 +1,42 @@
|
||||
# pybind11 is from: https://github.com/pybind/pybind11
|
||||
# Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>, All rights reserved.
|
||||
|
||||
SET(PYBIND_ZIP "v2.10.0.zip")
|
||||
SET(LOCAL_PYBIND_ZIP ${FETCHCONTENT_BASE_DIR}/${PYBIND_ZIP})
|
||||
SET(PYBIND_SRC ${FETCHCONTENT_BASE_DIR}/pybind11)
|
||||
SET(DOWNLOAD_URL "https://paddleaudio.bj.bcebos.com/build/v2.10.0.zip")
|
||||
SET(PYBIND_TIMEOUT 600 CACHE STRING "Timeout in seconds when downloading pybind.")
|
||||
|
||||
IF(NOT EXISTS ${LOCAL_PYBIND_ZIP})
|
||||
FILE(DOWNLOAD ${DOWNLOAD_URL}
|
||||
${LOCAL_PYBIND_ZIP}
|
||||
TIMEOUT ${PYBIND_TIMEOUT}
|
||||
STATUS ERR
|
||||
SHOW_PROGRESS
|
||||
)
|
||||
|
||||
IF(ERR EQUAL 0)
|
||||
MESSAGE(STATUS "download pybind success")
|
||||
ELSE()
|
||||
MESSAGE(FATAL_ERROR "download pybind fail")
|
||||
ENDIF()
|
||||
ENDIF()
|
||||
|
||||
IF(NOT EXISTS ${PYBIND_SRC})
|
||||
EXECUTE_PROCESS(
|
||||
COMMAND ${CMAKE_COMMAND} -E tar xfz ${LOCAL_PYBIND_ZIP}
|
||||
WORKING_DIRECTORY ${FETCHCONTENT_BASE_DIR}
|
||||
RESULT_VARIABLE tar_result
|
||||
)
|
||||
|
||||
file(RENAME ${FETCHCONTENT_BASE_DIR}/pybind11-2.10.0 ${PYBIND_SRC})
|
||||
|
||||
IF (tar_result MATCHES 0)
|
||||
MESSAGE(STATUS "unzip pybind success")
|
||||
ELSE()
|
||||
MESSAGE(FATAL_ERROR "unzip pybind fail")
|
||||
ENDIF()
|
||||
|
||||
ENDIF()
|
||||
|
||||
include_directories(${PYBIND_SRC}/include)
|
@@ -0,0 +1,64 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
function(pps_summary)
|
||||
message(STATUS "")
|
||||
message(STATUS "*************PaddleSpeech Building Summary**********")
|
||||
message(STATUS " PPS_VERSION : ${PPS_VERSION}")
|
||||
message(STATUS " CMake version : ${CMAKE_VERSION}")
|
||||
message(STATUS " CMake command : ${CMAKE_COMMAND}")
|
||||
message(STATUS " UNIX : ${UNIX}")
|
||||
message(STATUS " ANDROID : ${ANDROID}")
|
||||
message(STATUS " System : ${CMAKE_SYSTEM_NAME}")
|
||||
message(STATUS " C++ compiler : ${CMAKE_CXX_COMPILER}")
|
||||
message(STATUS " C++ compiler version : ${CMAKE_CXX_COMPILER_VERSION}")
|
||||
message(STATUS " CXX flags : ${CMAKE_CXX_FLAGS}")
|
||||
message(STATUS " Build type : ${CMAKE_BUILD_TYPE}")
|
||||
message(STATUS " BUILD_SHARED_LIBS : ${BUILD_SHARED_LIBS}")
|
||||
get_directory_property(tmp DIRECTORY ${PROJECT_SOURCE_DIR} COMPILE_DEFINITIONS)
|
||||
message(STATUS " Compile definitions : ${tmp}")
|
||||
message(STATUS " CMAKE_PREFIX_PATH : ${CMAKE_PREFIX_PATH}")
|
||||
message(STATUS " CMAKE_CURRENT_BINARY_DIR : ${CMAKE_CURRENT_BINARY_DIR}")
|
||||
message(STATUS " CMAKE_INSTALL_PREFIX : ${CMAKE_INSTALL_PREFIX}")
|
||||
message(STATUS " CMAKE_INSTALL_LIBDIR : ${CMAKE_INSTALL_LIBDIR}")
|
||||
message(STATUS " CMAKE_MODULE_PATH : ${CMAKE_MODULE_PATH}")
|
||||
message(STATUS " CMAKE_SYSTEM_NAME : ${CMAKE_SYSTEM_NAME}")
|
||||
message(STATUS "")
|
||||
|
||||
message(STATUS " WITH_ASR : ${WITH_ASR}")
|
||||
message(STATUS " WITH_CLS : ${WITH_CLS}")
|
||||
message(STATUS " WITH_VAD : ${WITH_VAD}")
|
||||
message(STATUS " WITH_GPU : ${WITH_GPU}")
|
||||
message(STATUS " WITH_TESTING : ${WITH_TESTING}")
|
||||
message(STATUS " WITH_PROFILING : ${WITH_PROFILING}")
|
||||
message(STATUS " FASTDEPLOY_INSTALL_DIR : ${FASTDEPLOY_INSTALL_DIR}")
|
||||
message(STATUS " FASTDEPLOY_INCS : ${FASTDEPLOY_INCS}")
|
||||
message(STATUS " FASTDEPLOY_LIBS : ${FASTDEPLOY_LIBS}")
|
||||
if(WITH_GPU)
|
||||
message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}")
|
||||
endif()
|
||||
|
||||
if(ANDROID)
|
||||
message(STATUS " ANDROID_ABI : ${ANDROID_ABI}")
|
||||
message(STATUS " ANDROID_PLATFORM : ${ANDROID_PLATFORM}")
|
||||
message(STATUS " ANDROID_NDK : ${ANDROID_NDK}")
|
||||
message(STATUS " ANDROID_NDK_VERSION : ${CMAKE_ANDROID_NDK_VERSION}")
|
||||
endif()
|
||||
if (WITH_ASR)
|
||||
message(STATUS " Python executable : ${PYTHON_EXECUTABLE}")
|
||||
message(STATUS " Python includes : ${PYTHON_INCLUDE_DIR}")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
pps_summary()
|
@@ -0,0 +1,106 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Detects the OS and sets appropriate variables.
|
||||
# CMAKE_SYSTEM_NAME only gives us a coarse-grained name of the OS CMake is
|
||||
# building for, but the host system name, like centos, is necessary
|
||||
# in some cases to distinguish systems for customization.
|
||||
#
|
||||
# For instance, the protobuf libs path is <install_dir>/lib64
|
||||
# on CentOS, but <install_dir>/lib on other systems.
|
||||
|
||||
if(UNIX AND NOT APPLE)
|
||||
# exclude Apple from the *nix OS family
|
||||
set(LINUX TRUE)
|
||||
endif()
|
||||
|
||||
if(WIN32)
|
||||
set(HOST_SYSTEM "win32")
|
||||
else()
|
||||
if(APPLE)
|
||||
set(HOST_SYSTEM "macosx")
|
||||
exec_program(
|
||||
sw_vers ARGS
|
||||
-productVersion
|
||||
OUTPUT_VARIABLE HOST_SYSTEM_VERSION)
|
||||
string(REGEX MATCH "[0-9]+.[0-9]+" MACOS_VERSION "${HOST_SYSTEM_VERSION}")
|
||||
if(NOT DEFINED ENV{MACOSX_DEPLOYMENT_TARGET})
|
||||
# Set cache variable - end user may change this during ccmake or cmake-gui configure.
|
||||
set(CMAKE_OSX_DEPLOYMENT_TARGET
|
||||
${MACOS_VERSION}
|
||||
CACHE
|
||||
STRING
|
||||
"Minimum OS X version to target for deployment (at runtime); newer APIs weak linked. Set to empty string for default value."
|
||||
)
|
||||
endif()
|
||||
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
|
||||
else()
|
||||
|
||||
if(EXISTS "/etc/issue")
|
||||
file(READ "/etc/issue" LINUX_ISSUE)
|
||||
if(LINUX_ISSUE MATCHES "CentOS")
|
||||
set(HOST_SYSTEM "centos")
|
||||
elseif(LINUX_ISSUE MATCHES "Debian")
|
||||
set(HOST_SYSTEM "debian")
|
||||
elseif(LINUX_ISSUE MATCHES "Ubuntu")
|
||||
set(HOST_SYSTEM "ubuntu")
|
||||
elseif(LINUX_ISSUE MATCHES "Red Hat")
|
||||
set(HOST_SYSTEM "redhat")
|
||||
elseif(LINUX_ISSUE MATCHES "Fedora")
|
||||
set(HOST_SYSTEM "fedora")
|
||||
endif()
|
||||
|
||||
string(REGEX MATCH "(([0-9]+)\\.)+([0-9]+)" HOST_SYSTEM_VERSION
|
||||
"${LINUX_ISSUE}")
|
||||
endif()
|
||||
|
||||
if(EXISTS "/etc/redhat-release")
|
||||
file(READ "/etc/redhat-release" LINUX_ISSUE)
|
||||
if(LINUX_ISSUE MATCHES "CentOS")
|
||||
set(HOST_SYSTEM "centos")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(NOT HOST_SYSTEM)
|
||||
set(HOST_SYSTEM ${CMAKE_SYSTEM_NAME})
|
||||
endif()
|
||||
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# query number of logical cores
|
||||
cmake_host_system_information(RESULT CPU_CORES QUERY NUMBER_OF_LOGICAL_CORES)
|
||||
|
||||
mark_as_advanced(HOST_SYSTEM CPU_CORES)
|
||||
|
||||
message(
|
||||
STATUS
|
||||
"Found Paddle host system: ${HOST_SYSTEM}, version: ${HOST_SYSTEM_VERSION}")
|
||||
message(STATUS "Found Paddle host system's CPU: ${CPU_CORES} cores")
|
||||
|
||||
# external dependencies log output
|
||||
set(EXTERNAL_PROJECT_LOG_ARGS
|
||||
LOG_DOWNLOAD
|
||||
0 # Wrap download in script to log output
|
||||
LOG_UPDATE
|
||||
1 # Wrap update in script to log output
|
||||
LOG_CONFIGURE
|
||||
1 # Wrap configure in script to log output
|
||||
LOG_BUILD
|
||||
0 # Wrap build in script to log output
|
||||
LOG_TEST
|
||||
1 # Wrap test in script to log output
|
||||
LOG_INSTALL
|
||||
0 # Wrap install in script to log output
|
||||
)
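# Illustrative use (hypothetical target and URL): the list is expanded into an
# ExternalProject_Add call so the selected phases are wrapped in logging scripts:
#   ExternalProject_Add(extern_foo
#     URL https://example.com/foo.tar.gz
#     ${EXTERNAL_PROJECT_LOG_ARGS})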
|
@@ -0,0 +1,22 @@
|
||||
project(speechx LANGUAGES CXX)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/kaldi)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/common)
|
||||
|
||||
add_subdirectory(kaldi)
|
||||
add_subdirectory(common)
|
||||
|
||||
if(WITH_ASR)
|
||||
add_subdirectory(asr)
|
||||
endif()
|
||||
|
||||
if(WITH_CLS)
|
||||
add_subdirectory(audio_classification)
|
||||
endif()
|
||||
|
||||
if(WITH_VAD)
|
||||
add_subdirectory(vad)
|
||||
endif()
|
||||
|
||||
add_subdirectory(codelab)
|
@@ -0,0 +1,11 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
project(ASR LANGUAGES CXX)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/server)
|
||||
|
||||
add_subdirectory(decoder)
|
||||
add_subdirectory(recognizer)
|
||||
add_subdirectory(nnet)
|
||||
add_subdirectory(server)
|
@@ -0,0 +1,24 @@
|
||||
set(srcs)
|
||||
list(APPEND srcs
|
||||
ctc_prefix_beam_search_decoder.cc
|
||||
ctc_tlg_decoder.cc
|
||||
)
|
||||
|
||||
add_library(decoder STATIC ${srcs})
|
||||
target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder)
|
||||
|
||||
# test
|
||||
set(TEST_BINS
|
||||
ctc_prefix_beam_search_decoder_main
|
||||
ctc_tlg_decoder_main
|
||||
)
|
||||
|
||||
foreach(bin_name IN LISTS TEST_BINS)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
|
||||
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
|
||||
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
|
||||
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl)
|
||||
endforeach()
|
||||
|
@@ -0,0 +1,21 @@
|
||||
set(srcs decodable.cc nnet_producer.cc)
|
||||
|
||||
list(APPEND srcs u2_nnet.cc)
|
||||
if(WITH_ONNX)
|
||||
list(APPEND srcs u2_onnx_nnet.cc)
|
||||
endif()
|
||||
add_library(nnet STATIC ${srcs})
|
||||
target_link_libraries(nnet utils)
|
||||
if(WITH_ONNX)
|
||||
target_link_libraries(nnet ${FASTDEPLOY_LIBS})
|
||||
endif()
|
||||
|
||||
target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
|
||||
target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
|
||||
|
||||
# test bin
|
||||
#set(bin_name u2_nnet_main)
|
||||
#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
|
||||
#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
|
||||
#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
|
@@ -0,0 +1,99 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "nnet/nnet_producer.h"
|
||||
|
||||
#include "matrix/kaldi-matrix.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using std::vector;
|
||||
|
||||
NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
|
||||
std::shared_ptr<FrontendInterface> frontend,
|
||||
float blank_threshold)
|
||||
: nnet_(nnet), frontend_(frontend), blank_threshold_(blank_threshold) {
|
||||
Reset();
|
||||
}
|
||||
|
||||
void NnetProducer::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
|
||||
frontend_->Accept(inputs);
|
||||
}
|
||||
|
||||
void NnetProducer::Acceptlikelihood(
|
||||
const kaldi::Matrix<BaseFloat>& likelihood) {
|
||||
std::vector<BaseFloat> prob;
|
||||
prob.resize(likelihood.NumCols());
|
||||
for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) {
|
||||
for (size_t col = 0; col < likelihood.NumCols(); ++col) {
|
||||
prob[col] = likelihood(idx, col);
|
||||
}
|
||||
cache_.push_back(prob);
|
||||
}
|
||||
}
|
||||
|
||||
bool NnetProducer::Read(std::vector<kaldi::BaseFloat>* nnet_prob) {
|
||||
bool flag = cache_.pop(nnet_prob);
|
||||
return flag;
|
||||
}
|
||||
|
||||
bool NnetProducer::Compute() {
|
||||
vector<BaseFloat> features;
|
||||
if (frontend_ == NULL || frontend_->Read(&features) == false) {
|
||||
// no features available, or frontend_ is not initialized.
|
||||
if (frontend_ != NULL && frontend_->IsFinished()) {
|
||||
finished_ = true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
CHECK_GE(frontend_->Dim(), 0);
|
||||
VLOG(1) << "Forward in " << features.size() / frontend_->Dim() << " feats.";
|
||||
|
||||
NnetOut out;
|
||||
nnet_->FeedForward(features, frontend_->Dim(), &out);
|
||||
int32& vocab_dim = out.vocab_dim;
|
||||
size_t nframes = out.logprobs.size() / vocab_dim;
|
||||
VLOG(1) << "Forward out " << nframes << " decoder frames.";
|
||||
for (size_t idx = 0; idx < nframes; ++idx) {
|
||||
std::vector<BaseFloat> logprob(
|
||||
out.logprobs.data() + idx * vocab_dim,
|
||||
out.logprobs.data() + (idx + 1) * vocab_dim);
|
||||
// process blank prob
|
||||
float blank_prob = std::exp(logprob[0]);
|
||||
if (blank_prob > blank_threshold_) {
|
||||
last_frame_logprob_ = logprob;
|
||||
is_last_frame_skip_ = true;
|
||||
continue;
|
||||
} else {
|
||||
int cur_max = std::max_element(logprob.begin(), logprob.end()) - logprob.begin();
|
||||
if (cur_max == last_max_elem_ && cur_max != 0 && is_last_frame_skip_) {
|
||||
cache_.push_back(last_frame_logprob_);
|
||||
last_max_elem_ = cur_max;
|
||||
}
|
||||
last_max_elem_ = cur_max;
|
||||
is_last_frame_skip_ = false;
|
||||
cache_.push_back(logprob);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
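// Worked example of the blank-skipping logic above (illustrative; assumes token
// id 0 is the CTC blank and blank_threshold_ = 0.98):
//   frame t  : exp(logprob[0]) = 0.99 -> the frame is skipped but cached;
//   frame t+1: blank prob is low, and its non-blank argmax equals that of the
//              last emitted frame while a frame was skipped in between -> the
//              cached frame t is pushed first, presumably so the CTC decoder
//              does not merge the two identical tokens.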
|
||||
|
||||
void NnetProducer::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
|
||||
float reverse_weight,
|
||||
std::vector<float>* rescoring_score) {
|
||||
nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
|
||||
}
|
||||
|
||||
} // namespace ppspeech
|
@@ -0,0 +1,77 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/common.h"
|
||||
#include "base/safe_queue.h"
|
||||
#include "frontend/frontend_itf.h"
|
||||
#include "nnet/nnet_itf.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
class NnetProducer {
|
||||
public:
|
||||
explicit NnetProducer(std::shared_ptr<NnetBase> nnet,
|
||||
std::shared_ptr<FrontendInterface> frontend,
|
||||
float blank_threshold);
|
||||
// Feed feats or waves
|
||||
void Accept(const std::vector<kaldi::BaseFloat>& inputs);
|
||||
|
||||
void Acceptlikelihood(const kaldi::Matrix<BaseFloat>& likelihood);
|
||||
|
||||
// read nnet output (per-frame log probabilities) produced by Compute()
|
||||
bool Read(std::vector<kaldi::BaseFloat>* nnet_prob);
|
||||
|
||||
bool Empty() const { return cache_.empty(); }
|
||||
|
||||
void SetInputFinished() {
|
||||
LOG(INFO) << "set finished";
|
||||
frontend_->SetFinished();
|
||||
}
|
||||
|
||||
// true when the frontend is finished and the compute loop has drained its input
|
||||
bool IsFinished() const {
|
||||
return (frontend_->IsFinished() && finished_);
|
||||
}
|
||||
|
||||
~NnetProducer() {}
|
||||
|
||||
void Reset() {
|
||||
if (frontend_ != NULL) frontend_->Reset();
|
||||
if (nnet_ != NULL) nnet_->Reset();
|
||||
cache_.clear();
|
||||
finished_ = false;
|
||||
}
|
||||
|
||||
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
|
||||
float reverse_weight,
|
||||
std::vector<float>* rescoring_score);
|
||||
|
||||
bool Compute();
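// Typical call pattern (as used by u2_nnet_main.cc in this commit): Accept()
// waveform chunks, SetInputFinished() after the last chunk, then Read()
// log-prob frames until IsFinished() returns true.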
|
||||
private:
|
||||
|
||||
std::shared_ptr<FrontendInterface> frontend_;
|
||||
std::shared_ptr<NnetBase> nnet_;
|
||||
SafeQueue<std::vector<kaldi::BaseFloat>> cache_;
|
||||
std::vector<BaseFloat> last_frame_logprob_;
|
||||
bool is_last_frame_skip_ = false;
|
||||
int last_max_elem_ = -1;
|
||||
float blank_threshold_ = 0.0;
|
||||
bool finished_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(NnetProducer);
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@@ -0,0 +1,145 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef USE_ONNX
|
||||
#include "nnet/u2_nnet.h"
|
||||
#else
|
||||
#include "nnet/u2_onnx_nnet.h"
|
||||
#endif
|
||||
#include "base/common.h"
|
||||
#include "decoder/param.h"
|
||||
#include "frontend/feature_pipeline.h"
|
||||
#include "frontend/wave-reader.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/decodable.h"
|
||||
#include "nnet/nnet_producer.h"
|
||||
#include "nnet/u2_nnet.h"
|
||||
|
||||
DEFINE_string(wav_rspecifier, "", "test wav rspecifier");
|
||||
DEFINE_string(nnet_prob_wspecifier, "", "nnet prob wspecifier");
|
||||
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
|
||||
DEFINE_int32(sample_rate, 16000, "sample rate");
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Matrix;
|
||||
using std::vector;
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::SetUsageMessage("Usage:");
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
google::InstallFailureSignalHandler();
|
||||
FLAGS_logtostderr = 1;
|
||||
|
||||
int32 num_done = 0, num_err = 0;
|
||||
int sample_rate = FLAGS_sample_rate;
|
||||
float streaming_chunk = FLAGS_streaming_chunk;
|
||||
int chunk_sample_size = streaming_chunk * sample_rate;
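// With the default flags this is 0.36 s * 16000 Hz = 5760 samples per chunk.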
|
||||
|
||||
CHECK_GT(FLAGS_wav_rspecifier.size(), 0);
|
||||
CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0);
|
||||
CHECK_GT(FLAGS_model_path.size(), 0);
|
||||
LOG(INFO) << "input rspecifier: " << FLAGS_wav_rspecifier;
|
||||
LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
|
||||
LOG(INFO) << "model path: " << FLAGS_model_path;
|
||||
|
||||
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
|
||||
FLAGS_wav_rspecifier);
|
||||
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
|
||||
|
||||
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
|
||||
ppspeech::FeaturePipelineOptions feature_opts =
|
||||
ppspeech::FeaturePipelineOptions::InitFromFlags();
|
||||
feature_opts.assembler_opts.fill_zero = false;
|
||||
|
||||
#ifndef USE_ONNX
|
||||
std::shared_ptr<ppspeech::U2Nnet> nnet(new ppspeech::U2Nnet(model_opts));
|
||||
#else
|
||||
std::shared_ptr<ppspeech::U2OnnxNnet> nnet(new ppspeech::U2OnnxNnet(model_opts));
|
||||
#endif
|
||||
std::shared_ptr<ppspeech::FeaturePipeline> feature_pipeline(
|
||||
new ppspeech::FeaturePipeline(feature_opts));
|
||||
std::shared_ptr<ppspeech::NnetProducer> nnet_producer(
|
||||
new ppspeech::NnetProducer(nnet, feature_pipeline));
|
||||
kaldi::Timer timer;
|
||||
float tot_wav_duration = 0;
|
||||
|
||||
for (; !wav_reader.Done(); wav_reader.Next()) {
|
||||
std::string utt = wav_reader.Key();
|
||||
const kaldi::WaveData& wave_data = wav_reader.Value();
|
||||
LOG(INFO) << "utt: " << utt;
|
||||
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
|
||||
double dur = wave_data.Duration();
|
||||
tot_wav_duration += dur;
|
||||
|
||||
int32 this_channel = 0;
|
||||
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
|
||||
this_channel);
|
||||
int tot_samples = waveform.Dim();
|
||||
LOG(INFO) << "wav len (sample): " << tot_samples;
|
||||
|
||||
int sample_offset = 0;
|
||||
kaldi::Timer timer;
|
||||
|
||||
while (sample_offset < tot_samples) {
|
||||
int cur_chunk_size =
|
||||
std::min(chunk_sample_size, tot_samples - sample_offset);
|
||||
|
||||
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
|
||||
for (int i = 0; i < cur_chunk_size; ++i) {
|
||||
wav_chunk[i] = waveform(sample_offset + i);
|
||||
}
|
||||
|
||||
nnet_producer->Accept(wav_chunk);
|
||||
if (cur_chunk_size < chunk_sample_size) {
|
||||
nnet_producer->SetInputFinished();
|
||||
}
|
||||
|
||||
// no overlap
|
||||
sample_offset += cur_chunk_size;
|
||||
}
|
||||
CHECK(sample_offset == tot_samples);
|
||||
|
||||
std::vector<std::vector<kaldi::BaseFloat>> prob_vec;
|
||||
while (1) {
|
||||
std::vector<kaldi::BaseFloat> logprobs;
|
||||
bool isok = nnet_producer->Read(&logprobs);
|
||||
if (nnet_producer->IsFinished()) break;
|
||||
if (isok == false) continue;
|
||||
prob_vec.push_back(logprobs);
|
||||
}
|
||||
{
|
||||
// write nnet output
|
||||
kaldi::MatrixIndexT nrow = prob_vec.size();
|
||||
kaldi::MatrixIndexT ncol = prob_vec[0].size();
|
||||
LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol;
|
||||
kaldi::Matrix<kaldi::BaseFloat> nnet_out(nrow, ncol);
|
||||
for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
|
||||
for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
|
||||
nnet_out(row_idx, col_idx) = prob_vec[row_idx][col_idx];
|
||||
}
|
||||
}
|
||||
nnet_out_writer.Write(utt, nnet_out);
|
||||
}
|
||||
nnet_producer->Reset();
|
||||
}
|
||||
|
||||
nnet_producer->Wait();
|
||||
double elapsed = timer.Elapsed();
|
||||
LOG(INFO) << "Program cost:" << elapsed << " sec";
|
||||
|
||||
LOG(INFO) << "Done " << num_done << " utterances, " << num_err
|
||||
<< " with errors.";
|
||||
return (num_done != 0 ? 0 : 1);
|
||||
}
|
@@ -0,0 +1,414 @@
|
||||
// Copyright 2022 Horizon Robotics. All Rights Reserved.
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// modified from
|
||||
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.cc
|
||||
|
||||
#include "nnet/u2_onnx_nnet.h"
|
||||
#include "common/base/config.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
void U2OnnxNnet::LoadModel(const std::string& model_dir) {
|
||||
std::string encoder_onnx_path = model_dir + "/encoder.onnx";
|
||||
std::string rescore_onnx_path = model_dir + "/decoder.onnx";
|
||||
std::string ctc_onnx_path = model_dir + "/ctc.onnx";
|
||||
std::string param_path = model_dir + "/param.onnx";
|
||||
// 1. Load sessions
|
||||
try {
|
||||
encoder_ = std::make_shared<fastdeploy::Runtime>();
|
||||
ctc_ = std::make_shared<fastdeploy::Runtime>();
|
||||
rescore_ = std::make_shared<fastdeploy::Runtime>();
|
||||
fastdeploy::RuntimeOption runtime_option;
|
||||
runtime_option.UseOrtBackend();
|
||||
runtime_option.UseCpu();
|
||||
runtime_option.SetCpuThreadNum(1);
|
||||
runtime_option.SetModelPath(encoder_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
|
||||
CHECK(encoder_->Init(runtime_option));  // assert() would drop this call in NDEBUG builds
|
||||
runtime_option.SetModelPath(rescore_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
|
||||
CHECK(rescore_->Init(runtime_option));
|
||||
runtime_option.SetModelPath(ctc_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
|
||||
CHECK(ctc_->Init(runtime_option));
|
||||
} catch (std::exception const& e) {
|
||||
LOG(ERROR) << "error when load onnx model: " << e.what();
|
||||
exit(1);
|
||||
}
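// 2. Read model metadata from the param file.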
|
||||
|
||||
Config conf(param_path);
|
||||
encoder_output_size_ = conf.Read("output_size", encoder_output_size_);
|
||||
num_blocks_ = conf.Read("num_blocks", num_blocks_);
|
||||
head_ = conf.Read("head", head_);
|
||||
cnn_module_kernel_ = conf.Read("cnn_module_kernel", cnn_module_kernel_);
|
||||
subsampling_rate_ = conf.Read("subsampling_rate", subsampling_rate_);
|
||||
right_context_ = conf.Read("right_context", right_context_);
|
||||
sos_= conf.Read("sos_symbol", sos_);
|
||||
eos_= conf.Read("eos_symbol", eos_);
|
||||
is_bidecoder_= conf.Read("is_bidirectional_decoder", is_bidecoder_);
|
||||
chunk_size_= conf.Read("chunk_size", chunk_size_);
|
||||
num_left_chunks_ = conf.Read("left_chunks", num_left_chunks_);
|
||||
|
||||
LOG(INFO) << "Onnx Model Info:";
|
||||
LOG(INFO) << "\tencoder_output_size " << encoder_output_size_;
|
||||
LOG(INFO) << "\tnum_blocks " << num_blocks_;
|
||||
LOG(INFO) << "\thead " << head_;
|
||||
LOG(INFO) << "\tcnn_module_kernel " << cnn_module_kernel_;
|
||||
LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_;
|
||||
LOG(INFO) << "\tright_context " << right_context_;
|
||||
LOG(INFO) << "\tsos " << sos_;
|
||||
LOG(INFO) << "\teos " << eos_;
|
||||
LOG(INFO) << "\tis bidirectional decoder " << is_bidecoder_;
|
||||
LOG(INFO) << "\tchunk_size " << chunk_size_;
|
||||
LOG(INFO) << "\tnum_left_chunks " << num_left_chunks_;
|
||||
|
||||
// 3. Read model nodes
|
||||
LOG(INFO) << "Onnx Encoder:";
|
||||
GetInputOutputInfo(encoder_, &encoder_in_names_, &encoder_out_names_);
|
||||
LOG(INFO) << "Onnx CTC:";
|
||||
GetInputOutputInfo(ctc_, &ctc_in_names_, &ctc_out_names_);
|
||||
LOG(INFO) << "Onnx Rescore:";
|
||||
GetInputOutputInfo(rescore_, &rescore_in_names_, &rescore_out_names_);
|
||||
}
|
||||
|
||||
U2OnnxNnet::U2OnnxNnet(const ModelOptions& opts) : opts_(opts) {
|
||||
LoadModel(opts_.model_path);
|
||||
}
|
||||
|
||||
// shallow copy
|
||||
U2OnnxNnet::U2OnnxNnet(const U2OnnxNnet& other) {
|
||||
// metadata
|
||||
encoder_output_size_ = other.encoder_output_size_;
|
||||
num_blocks_ = other.num_blocks_;
|
||||
head_ = other.head_;
|
||||
cnn_module_kernel_ = other.cnn_module_kernel_;
|
||||
right_context_ = other.right_context_;
|
||||
subsampling_rate_ = other.subsampling_rate_;
|
||||
sos_ = other.sos_;
|
||||
eos_ = other.eos_;
|
||||
is_bidecoder_ = other.is_bidecoder_;
|
||||
chunk_size_ = other.chunk_size_;
|
||||
num_left_chunks_ = other.num_left_chunks_;
|
||||
offset_ = other.offset_;
|
||||
|
||||
// session
|
||||
encoder_ = other.encoder_;
|
||||
ctc_ = other.ctc_;
|
||||
rescore_ = other.rescore_;
|
||||
|
||||
// node names
|
||||
encoder_in_names_ = other.encoder_in_names_;
|
||||
encoder_out_names_ = other.encoder_out_names_;
|
||||
ctc_in_names_ = other.ctc_in_names_;
|
||||
ctc_out_names_ = other.ctc_out_names_;
|
||||
rescore_in_names_ = other.rescore_in_names_;
|
||||
rescore_out_names_ = other.rescore_out_names_;
|
||||
}
|
||||
|
||||
void U2OnnxNnet::GetInputOutputInfo(const std::shared_ptr<fastdeploy::Runtime>& runtime,
|
||||
std::vector<std::string>* in_names, std::vector<std::string>* out_names) {
|
||||
std::vector<fastdeploy::TensorInfo> inputs_info = runtime->GetInputInfos();
|
||||
(*in_names).resize(inputs_info.size());
|
||||
for (int i = 0; i < inputs_info.size(); ++i){
|
||||
fastdeploy::TensorInfo info = inputs_info[i];
|
||||
|
||||
std::stringstream shape;
|
||||
for(int j = 0; j < info.shape.size(); ++j){
|
||||
shape << info.shape[j];
|
||||
shape << " ";
|
||||
}
|
||||
LOG(INFO) << "\tInput " << i << " : name=" << info.name << " type=" << info.dtype
|
||||
<< " dims=" << shape.str();
|
||||
(*in_names)[i] = info.name;
|
||||
}
|
||||
std::vector<fastdeploy::TensorInfo> outputs_info = runtime->GetOutputInfos();
|
||||
(*out_names).resize(outputs_info.size());
|
||||
for (int i = 0; i < outputs_info.size(); ++i){
|
||||
fastdeploy::TensorInfo info = outputs_info[i];
|
||||
|
||||
std::stringstream shape;
|
||||
for(int j = 0; j < info.shape.size(); ++j){
|
||||
shape << info.shape[j];
|
||||
shape << " ";
|
||||
}
|
||||
LOG(INFO) << "\tOutput " << i << " : name=" << info.name << " type=" << info.dtype
|
||||
<< " dims=" << shape.str();
|
||||
(*out_names)[i] = info.name;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<NnetBase> U2OnnxNnet::Clone() const {
|
||||
auto asr_model = std::make_shared<U2OnnxNnet>(*this);
|
||||
// reset inner state for new decoding
|
||||
asr_model->Reset();
|
||||
return asr_model;
|
||||
}
|
||||
|
||||
void U2OnnxNnet::Reset() {
|
||||
offset_ = 0;
|
||||
encoder_outs_.clear();
|
||||
cached_feats_.clear();
|
||||
// Reset att_cache
|
||||
if (num_left_chunks_ > 0) {
|
||||
int required_cache_size = chunk_size_ * num_left_chunks_;
|
||||
offset_ = required_cache_size;
|
||||
att_cache_.resize(num_blocks_ * head_ * required_cache_size *
|
||||
encoder_output_size_ / head_ * 2,
|
||||
0.0);
|
||||
const std::vector<int64_t> att_cache_shape = {num_blocks_, head_, required_cache_size,
|
||||
encoder_output_size_ / head_ * 2};
|
||||
att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data());
|
||||
} else {
|
||||
att_cache_.resize(0, 0.0);
|
||||
const std::vector<int64_t> att_cache_shape = {num_blocks_, head_, 0,
|
||||
encoder_output_size_ / head_ * 2};
|
||||
att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data());
|
||||
}
|
||||
|
||||
// Reset cnn_cache
|
||||
cnn_cache_.resize(
|
||||
num_blocks_ * encoder_output_size_ * (cnn_module_kernel_ - 1), 0.0);
|
||||
const std::vector<int64_t> cnn_cache_shape = {num_blocks_, 1, encoder_output_size_,
|
||||
cnn_module_kernel_ - 1};
|
||||
cnn_cache_ort_.SetExternalData(cnn_cache_shape, fastdeploy::FDDataType::FP32, cnn_cache_.data());
|
||||
}
|
||||
|
||||
void U2OnnxNnet::FeedForward(const std::vector<BaseFloat>& features,
|
||||
const int32& feature_dim,
|
||||
NnetOut* out) {
|
||||
kaldi::Timer timer;
|
||||
|
||||
std::vector<kaldi::BaseFloat> ctc_probs;
|
||||
ForwardEncoderChunkImpl(
|
||||
features, feature_dim, &out->logprobs, &out->vocab_dim);
|
||||
VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. "
|
||||
<< features.size() / feature_dim << " frames.";
|
||||
}
|
||||
|
||||
void U2OnnxNnet::ForwardEncoderChunkImpl(
|
||||
const std::vector<kaldi::BaseFloat>& chunk_feats,
|
||||
const int32& feat_dim,
|
||||
std::vector<kaldi::BaseFloat>* out_prob,
|
||||
int32* vocab_dim) {
|
||||
|
||||
// 1. Prepare onnx required data, splice cached_feature_ and chunk_feats
|
||||
// chunk
|
||||
int num_frames = chunk_feats.size() / feat_dim;
|
||||
VLOG(3) << "num_frames: " << num_frames;
|
||||
VLOG(3) << "feat_dim: " << feat_dim;
|
||||
const int feature_dim = feat_dim;
|
||||
std::vector<float> feats;
|
||||
feats.insert(feats.end(), chunk_feats.begin(), chunk_feats.end());
|
||||
fastdeploy::FDTensor feats_ort;
|
||||
const std::vector<int64_t> feats_shape = {1, num_frames, feature_dim};
|
||||
feats_ort.SetExternalData(feats_shape, fastdeploy::FDDataType::FP32, feats.data());
|
||||
|
||||
// offset
|
||||
int64_t offset_int64 = static_cast<int64_t>(offset_);
|
||||
fastdeploy::FDTensor offset_ort;
|
||||
offset_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &offset_int64);
|
||||
|
||||
// required_cache_size
|
||||
int64_t required_cache_size = chunk_size_ * num_left_chunks_;
|
||||
fastdeploy::FDTensor required_cache_size_ort("");
|
||||
required_cache_size_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &required_cache_size);
|
||||
|
||||
// att_mask
|
||||
fastdeploy::FDTensor att_mask_ort;
|
||||
std::vector<uint8_t> att_mask(required_cache_size + chunk_size_, 1);
|
||||
if (num_left_chunks_ > 0) {
|
||||
int chunk_idx = offset_ / chunk_size_ - num_left_chunks_;
|
||||
if (chunk_idx < num_left_chunks_) {
|
||||
for (int i = 0; i < (num_left_chunks_ - chunk_idx) * chunk_size_; ++i) {
|
||||
att_mask[i] = 0;
|
||||
}
|
||||
}
|
||||
const std::vector<int64_t> att_mask_shape = {1, 1, required_cache_size + chunk_size_};
|
||||
att_mask_ort.SetExternalData(att_mask_shape, fastdeploy::FDDataType::BOOL, reinterpret_cast<bool*>(att_mask.data()));
|
||||
}
|
||||
|
||||
// 2. Encoder chunk forward
|
||||
std::vector<fastdeploy::FDTensor> inputs(encoder_in_names_.size());
|
||||
for (int i = 0; i < encoder_in_names_.size(); ++i) {
|
||||
std::string name = encoder_in_names_[i];
|
||||
if (!strcmp(name.data(), "chunk")) {
|
||||
inputs[i] = std::move(feats_ort);
|
||||
inputs[i].name = "chunk";
|
||||
} else if (!strcmp(name.data(), "offset")) {
|
||||
inputs[i] = std::move(offset_ort);
|
||||
inputs[i].name = "offset";
|
||||
} else if (!strcmp(name.data(), "required_cache_size")) {
|
||||
inputs[i] = std::move(required_cache_size_ort);
|
||||
inputs[i].name = "required_cache_size";
|
||||
} else if (!strcmp(name.data(), "att_cache")) {
|
||||
inputs[i] = std::move(att_cache_ort_);
|
||||
inputs[i].name = "att_cache";
|
||||
} else if (!strcmp(name.data(), "cnn_cache")) {
|
||||
inputs[i] = std::move(cnn_cache_ort_);
|
||||
inputs[i].name = "cnn_cache";
|
||||
} else if (!strcmp(name.data(), "att_mask")) {
|
||||
inputs[i] = std::move(att_mask_ort);
|
||||
inputs[i].name = "att_mask";
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<fastdeploy::FDTensor> ort_outputs;
|
||||
CHECK(encoder_->Infer(inputs, &ort_outputs));
|
||||
|
||||
offset_ += static_cast<int>(ort_outputs[0].shape[1]);
|
||||
att_cache_ort_ = std::move(ort_outputs[1]);
|
||||
cnn_cache_ort_ = std::move(ort_outputs[2]);
|
||||
|
||||
std::vector<fastdeploy::FDTensor> ctc_inputs;
|
||||
ctc_inputs.emplace_back(std::move(ort_outputs[0]));
|
||||
// ctc_inputs[0] = std::move(ort_outputs[0]);
|
||||
ctc_inputs[0].name = ctc_in_names_[0];
|
||||
|
||||
std::vector<fastdeploy::FDTensor> ctc_ort_outputs;
|
||||
assert(ctc_->Infer(ctc_inputs, &ctc_ort_outputs));
|
||||
encoder_outs_.emplace_back(std::move(ctc_inputs[0])); // *****
|
||||
|
||||
float* logp_data = reinterpret_cast<float*>(ctc_ort_outputs[0].Data());
|
||||
|
||||
// Copy to output, (B=1,T,D)
|
||||
std::vector<int64_t> ctc_log_probs_shape = ctc_ort_outputs[0].shape;
|
||||
CHECK_EQ(ctc_log_probs_shape.size(), 3);
|
||||
int B = ctc_log_probs_shape[0];
|
||||
CHECK_EQ(B, 1);
|
||||
int T = ctc_log_probs_shape[1];
|
||||
int D = ctc_log_probs_shape[2];
|
||||
*vocab_dim = D;
|
||||
|
||||
out_prob->resize(T * D);
|
||||
std::memcpy(
|
||||
out_prob->data(), logp_data, T * D * sizeof(kaldi::BaseFloat));
|
||||
return;
|
||||
}
|
||||
|
||||
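// [Editor's note] Data flow of the chunk forward above: the encoder session
// consumes {chunk, offset, required_cache_size, att_cache, cnn_cache,
// att_mask} and returns {encoder_out, new att_cache, new cnn_cache}. The two
// caches are moved back into att_cache_ort_ / cnn_cache_ort_ for the next
// chunk, offset_ is advanced by the number of encoder output frames, and
// encoder_out is fed to the ctc session to obtain per-frame CTC
// log-probabilities. encoder_out is also appended to encoder_outs_ so that
// AttentionRescoring() below can re-score n-best hypotheses later.
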
float U2OnnxNnet::ComputeAttentionScore(const float* prob,
                                        const std::vector<int>& hyp, int eos,
                                        int decode_out_len) {
    float score = 0.0f;
    for (size_t j = 0; j < hyp.size(); ++j) {
        score += *(prob + j * decode_out_len + hyp[j]);
    }
    score += *(prob + hyp.size() * decode_out_len + eos);
    return score;
}

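// [Editor's note] Worked example of the accumulation above (numbers are made
// up for illustration): with decode_out_len D = 5, hyp = {2, 4} and eos = 1,
// the loop reads prob[0 * 5 + 2] and prob[1 * 5 + 4], and the final term adds
// prob[2 * 5 + 1], i.e. the log-probability of emitting <eos> after the last
// token. For prob[2] = -0.2, prob[9] = -0.4, prob[11] = -0.1 the returned
// score is -0.7.
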
void U2OnnxNnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
                                    float reverse_weight,
                                    std::vector<float>* rescoring_score) {
    CHECK(rescoring_score != nullptr);
    int num_hyps = hyps.size();
    rescoring_score->resize(num_hyps, 0.0f);

    if (num_hyps == 0) {
        return;
    }
    // No encoder output
    if (encoder_outs_.size() == 0) {
        return;
    }

    std::vector<int64_t> hyps_lens;
    int max_hyps_len = 0;
    for (size_t i = 0; i < num_hyps; ++i) {
        int length = hyps[i].size() + 1;
        max_hyps_len = std::max(length, max_hyps_len);
        hyps_lens.emplace_back(static_cast<int64_t>(length));
    }

    std::vector<float> rescore_input;
    int encoder_len = 0;
    for (int i = 0; i < encoder_outs_.size(); i++) {
        float* encoder_outs_data = reinterpret_cast<float*>(encoder_outs_[i].Data());
        for (int j = 0; j < encoder_outs_[i].Numel(); j++) {
            rescore_input.emplace_back(encoder_outs_data[j]);
        }
        encoder_len += encoder_outs_[i].shape[1];
    }

    std::vector<int64_t> hyps_pad;

    for (size_t i = 0; i < num_hyps; ++i) {
        const std::vector<int>& hyp = hyps[i];
        hyps_pad.emplace_back(sos_);
        size_t j = 0;
        for (; j < hyp.size(); ++j) {
            hyps_pad.emplace_back(hyp[j]);
        }
        if (j == max_hyps_len - 1) {
            continue;
        }
        for (; j < max_hyps_len - 1; ++j) {
            hyps_pad.emplace_back(0);
        }
    }

    const std::vector<int64_t> hyps_pad_shape = {num_hyps, max_hyps_len};
    const std::vector<int64_t> hyps_lens_shape = {num_hyps};
    const std::vector<int64_t> decode_input_shape = {1, encoder_len, encoder_output_size_};

    fastdeploy::FDTensor hyps_pad_tensor_;
    hyps_pad_tensor_.SetExternalData(
        hyps_pad_shape, fastdeploy::FDDataType::INT64, hyps_pad.data());
    fastdeploy::FDTensor hyps_lens_tensor_;
    hyps_lens_tensor_.SetExternalData(
        hyps_lens_shape, fastdeploy::FDDataType::INT64, hyps_lens.data());
    fastdeploy::FDTensor decode_input_tensor_;
    decode_input_tensor_.SetExternalData(
        decode_input_shape, fastdeploy::FDDataType::FP32, rescore_input.data());

    std::vector<fastdeploy::FDTensor> rescore_inputs(3);

    rescore_inputs[0] = std::move(hyps_pad_tensor_);
    rescore_inputs[0].name = rescore_in_names_[0];
    rescore_inputs[1] = std::move(hyps_lens_tensor_);
    rescore_inputs[1].name = rescore_in_names_[1];
    rescore_inputs[2] = std::move(decode_input_tensor_);
    rescore_inputs[2].name = rescore_in_names_[2];

    std::vector<fastdeploy::FDTensor> rescore_outputs;
    CHECK(rescore_->Infer(rescore_inputs, &rescore_outputs));

    float* decoder_outs_data = reinterpret_cast<float*>(rescore_outputs[0].Data());
    float* r_decoder_outs_data = reinterpret_cast<float*>(rescore_outputs[1].Data());

    int decode_out_len = rescore_outputs[0].shape[2];

    for (size_t i = 0; i < num_hyps; ++i) {
        const std::vector<int>& hyp = hyps[i];
        float score = 0.0f;
        // left-to-right decoder score
        score = ComputeAttentionScore(
            decoder_outs_data + max_hyps_len * decode_out_len * i, hyp, eos_,
            decode_out_len);
        // optional: right-to-left decoder score
        float r_score = 0.0f;
        if (is_bidecoder_ && reverse_weight > 0) {
            std::vector<int> r_hyp(hyp.size());
            std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
            r_score = ComputeAttentionScore(
                r_decoder_outs_data + max_hyps_len * decode_out_len * i, r_hyp, eos_,
                decode_out_len);
        }
        // combine the left-to-right and right-to-left scores
        (*rescoring_score)[i] =
            score * (1 - reverse_weight) + r_score * reverse_weight;
    }
}

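// [Editor's note] The final interpolation follows the bidirectional rescoring
// scheme of the WeNet code this file is adapted from:
//   rescoring_score[i] = (1 - reverse_weight) * score + reverse_weight * r_score
// For example, score = -3.2, r_score = -3.6 and reverse_weight = 0.3 give
// 0.7 * (-3.2) + 0.3 * (-3.6) = -3.32. With reverse_weight == 0 (or a
// unidirectional decoder) the right-to-left branch is skipped and the
// left-to-right score is used unchanged.
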
void U2OnnxNnet::EncoderOuts(
    std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const {}

}  // namespace ppspeech
@ -0,0 +1,97 @@
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.h

#pragma once

#include "base/common.h"
#include "matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h"
#include "nnet/u2_nnet.h"

#include "fastdeploy/runtime.h"

namespace ppspeech {

class U2OnnxNnet : public U2NnetBase {
  public:
    explicit U2OnnxNnet(const ModelOptions& opts);
    U2OnnxNnet(const U2OnnxNnet& other);

    void FeedForward(const std::vector<kaldi::BaseFloat>& features,
                     const int32& feature_dim,
                     NnetOut* out) override;

    void Reset() override;

    bool IsLogProb() override { return true; }

    void Dim();

    void LoadModel(const std::string& model_dir);

    std::shared_ptr<NnetBase> Clone() const override;

    void ForwardEncoderChunkImpl(
        const std::vector<kaldi::BaseFloat>& chunk_feats,
        const int32& feat_dim,
        std::vector<kaldi::BaseFloat>* ctc_probs,
        int32* vocab_dim) override;

    float ComputeAttentionScore(const float* prob, const std::vector<int>& hyp,
                                int eos, int decode_out_len);

    void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
                            float reverse_weight,
                            std::vector<float>* rescoring_score) override;

    void EncoderOuts(
        std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const;

    void GetInputOutputInfo(const std::shared_ptr<fastdeploy::Runtime>& runtime,
                            std::vector<std::string>* in_names,
                            std::vector<std::string>* out_names);

  private:
    ModelOptions opts_;

    int encoder_output_size_ = 0;
    int num_blocks_ = 0;
    int cnn_module_kernel_ = 0;
    int head_ = 0;

    // sessions
    std::shared_ptr<fastdeploy::Runtime> encoder_ = nullptr;
    std::shared_ptr<fastdeploy::Runtime> rescore_ = nullptr;
    std::shared_ptr<fastdeploy::Runtime> ctc_ = nullptr;

    // node names
    std::vector<std::string> encoder_in_names_, encoder_out_names_;
    std::vector<std::string> ctc_in_names_, ctc_out_names_;
    std::vector<std::string> rescore_in_names_, rescore_out_names_;

    // caches
    fastdeploy::FDTensor att_cache_ort_;
    fastdeploy::FDTensor cnn_cache_ort_;
    std::vector<fastdeploy::FDTensor> encoder_outs_;

    std::vector<float> att_cache_;
    std::vector<float> cnn_cache_;
};

}  // namespace ppspeech
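// [Editor's note] A minimal usage sketch for this class; it is not part of
// the header. The ModelOptions configuration step is illustrative (only the
// model_path field is visible elsewhere in this patch):
//
//   ppspeech::ModelOptions opts;                // configured from flags/config
//   opts.model_path = "exp/u2_onnx";            // encoder/ctc/rescore ONNX dir
//   ppspeech::U2OnnxNnet nnet(opts);
//   ppspeech::NnetOut out;
//   nnet.FeedForward(feats, feat_dim, &out);    // feats: T x feat_dim, row-major
//   // out.logprobs now holds T * out.vocab_dim CTC log-probabilities.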
@ -0,0 +1,26 @@
set(srcs)

list(APPEND srcs
    recognizer_controller.cc
    recognizer_controller_impl.cc
    recognizer_instance.cc
    recognizer.cc
)

add_library(recognizer STATIC ${srcs})
target_link_libraries(recognizer PUBLIC decoder)

set(TEST_BINS
    recognizer_batch_main
    recognizer_batch_main2
    recognizer_main
)

foreach(bin_name IN LISTS TEST_BINS)
    add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
    target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
    target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
    target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
    target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
    target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl)
endforeach()
@ -0,0 +1,46 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "recognizer/recognizer.h"
#include "recognizer/recognizer_instance.h"

bool InitRecognizer(const std::string& model_file,
                    const std::string& word_symbol_table_file,
                    const std::string& fst_file,
                    int num_instance) {
    return ppspeech::RecognizerInstance::GetInstance().Init(model_file,
                                                            word_symbol_table_file,
                                                            fst_file,
                                                            num_instance);
}

int GetRecognizerInstanceId() {
    return ppspeech::RecognizerInstance::GetInstance().GetRecognizerInstanceId();
}

void InitDecoder(int instance_id) {
    return ppspeech::RecognizerInstance::GetInstance().InitDecoder(instance_id);
}

void AcceptData(const std::vector<float>& waves, int instance_id) {
    return ppspeech::RecognizerInstance::GetInstance().Accept(waves, instance_id);
}

void SetInputFinished(int instance_id) {
    return ppspeech::RecognizerInstance::GetInstance().SetInputFinished(instance_id);
}

std::string GetFinalResult(int instance_id) {
    return ppspeech::RecognizerInstance::GetInstance().GetResult(instance_id);
}
@ -0,0 +1,28 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

bool InitRecognizer(const std::string& model_file,
                    const std::string& word_symbol_table_file,
                    const std::string& fst_file,
                    int num_instance);
int GetRecognizerInstanceId();
void InitDecoder(int instance_id);
void AcceptData(const std::vector<float>& waves, int instance_id);
void SetInputFinished(int instance_id);
std::string GetFinalResult(int instance_id);
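// [Editor's note] Single-utterance sketch of the flat API declared above, not
// part of the header. It follows the call sequence used by
// recognizer_batch_main2.cc further down in this patch; the paths and the
// all-at-once feeding are illustrative.
#include <string>
#include <vector>

#include "recognizer/recognizer.h"

std::string RecognizeOneUtterance(const std::vector<float>& samples) {
    // One-time setup: model dir, word symbol table, optional TLG fst, 1 worker.
    static bool initialized =
        InitRecognizer("model_dir", "words.txt", /*fst_file=*/"", /*num_instance=*/1);
    if (!initialized) return "";

    int id = -1;
    while (id == -1) {            // spin until a worker becomes free
        id = GetRecognizerInstanceId();
    }
    InitDecoder(id);
    AcceptData(samples, id);      // may also be fed chunk by chunk
    SetInputFinished(id);
    return GetFinalResult(id);    // waits for decoding and rescoring to finish
}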
@ -0,0 +1,172 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "common/base/thread_pool.h"
|
||||
#include "common/utils/file_utils.h"
|
||||
#include "common/utils/strings.h"
|
||||
#include "decoder/param.h"
|
||||
#include "frontend/wave-reader.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/u2_nnet.h"
|
||||
#include "recognizer/recognizer_controller.h"
|
||||
|
||||
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
|
||||
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
||||
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
|
||||
DEFINE_int32(sample_rate, 16000, "sample rate");
|
||||
DEFINE_int32(njob, 3, "njob");
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
void SplitUtt(string wavlist_file,
|
||||
vector<vector<string>>* uttlists,
|
||||
vector<vector<string>>* wavlists,
|
||||
int njob) {
|
||||
vector<string> wavlist;
|
||||
wavlists->resize(njob);
|
||||
uttlists->resize(njob);
|
||||
ppspeech::ReadFileToVector(wavlist_file, &wavlist);
|
||||
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
|
||||
string utt_str = wavlist[idx];
|
||||
vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
|
||||
LOG(INFO) << utt_wav[0];
|
||||
CHECK_EQ(utt_wav.size(), size_t(2));
|
||||
uttlists->at(idx % njob).push_back(utt_wav[0]);
|
||||
wavlists->at(idx % njob).push_back(utt_wav[1]);
|
||||
}
|
||||
}
|
||||
|
||||
void recognizer_func(ppspeech::RecognizerController* recognizer_controller,
|
||||
std::vector<string> wavlist,
|
||||
std::vector<string> uttlist,
|
||||
std::vector<string>* results) {
|
||||
int32 num_done = 0, num_err = 0;
|
||||
double tot_wav_duration = 0.0;
|
||||
double tot_attention_rescore_time = 0.0;
|
||||
double tot_decode_time = 0.0;
|
||||
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
|
||||
if (wavlist.empty()) return;
|
||||
|
||||
results->reserve(wavlist.size());
|
||||
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
|
||||
std::string utt = uttlist[idx];
|
||||
std::string wav_file = wavlist[idx];
|
||||
std::ifstream infile;
|
||||
infile.open(wav_file, std::ifstream::in);
|
||||
kaldi::WaveData wave_data;
|
||||
wave_data.Read(infile);
|
||||
int32 recog_id = -1;
|
||||
while (recog_id == -1) {
|
||||
recog_id = recognizer_controller->GetRecognizerInstanceId();
|
||||
}
|
||||
recognizer_controller->InitDecoder(recog_id);
|
||||
LOG(INFO) << "utt: " << utt;
|
||||
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
|
||||
double dur = wave_data.Duration();
|
||||
tot_wav_duration += dur;
|
||||
|
||||
int32 this_channel = 0;
|
||||
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
|
||||
this_channel);
|
||||
int tot_samples = waveform.Dim();
|
||||
LOG(INFO) << "wav len (sample): " << tot_samples;
|
||||
|
||||
int sample_offset = 0;
|
||||
kaldi::Timer local_timer;
|
||||
|
||||
while (sample_offset < tot_samples) {
|
||||
int cur_chunk_size =
|
||||
std::min(chunk_sample_size, tot_samples - sample_offset);
|
||||
|
||||
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
|
||||
for (int i = 0; i < cur_chunk_size; ++i) {
|
||||
wav_chunk[i] = waveform(sample_offset + i);
|
||||
}
|
||||
|
||||
recognizer_controller->Accept(wav_chunk, recog_id);
|
||||
// no overlap
|
||||
sample_offset += cur_chunk_size;
|
||||
}
|
||||
recognizer_controller->SetInputFinished(recog_id);
|
||||
CHECK(sample_offset == tot_samples);
|
||||
std::string result = recognizer_controller->GetFinalResult(recog_id);
|
||||
if (result.empty()) {
|
||||
// the TokenWriter can not write empty string.
|
||||
++num_err;
|
||||
LOG(INFO) << " the result of " << utt << " is empty";
|
||||
result = " ";
|
||||
}
|
||||
|
||||
tot_decode_time += local_timer.Elapsed();
|
||||
LOG(INFO) << utt << " " << result;
|
||||
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
|
||||
<< " cost: " << local_timer.Elapsed();
|
||||
|
||||
results->push_back(result);
|
||||
++num_done;
|
||||
}
|
||||
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
|
||||
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
|
||||
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
|
||||
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::SetUsageMessage("Usage:");
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
google::InstallFailureSignalHandler();
|
||||
FLAGS_logtostderr = 1;
|
||||
|
||||
int sample_rate = FLAGS_sample_rate;
|
||||
float streaming_chunk = FLAGS_streaming_chunk;
|
||||
int chunk_sample_size = streaming_chunk * sample_rate;
|
||||
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
|
||||
int njob = FLAGS_njob;
|
||||
LOG(INFO) << "sr: " << sample_rate;
|
||||
LOG(INFO) << "chunk size (s): " << streaming_chunk;
|
||||
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
|
||||
|
||||
ppspeech::RecognizerResource resource =
|
||||
ppspeech::RecognizerResource::InitFromFlags();
|
||||
ppspeech::RecognizerController recognizer_controller(njob, resource);
|
||||
ThreadPool threadpool(njob);
|
||||
vector<vector<string>> wavlist;
|
||||
vector<vector<string>> uttlist;
|
||||
vector<vector<string>> resultlist(njob);
|
||||
vector<std::future<void>> futurelist;
|
||||
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
|
||||
for (size_t i = 0; i < njob; ++i) {
|
||||
std::future<void> f = threadpool.enqueue(recognizer_func,
|
||||
&recognizer_controller,
|
||||
wavlist[i],
|
||||
uttlist[i],
|
||||
&resultlist[i]);
|
||||
futurelist.push_back(std::move(f));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < njob; ++i) {
|
||||
futurelist[i].get();
|
||||
}
|
||||
|
||||
for (size_t idx = 0; idx < njob; ++idx) {
|
||||
for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
|
||||
string utt = uttlist[idx][utt_idx];
|
||||
string result = resultlist[idx][utt_idx];
|
||||
result_writer.Write(utt, result);
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
@ -0,0 +1,168 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "common/base/thread_pool.h"
|
||||
#include "common/utils/file_utils.h"
|
||||
#include "common/utils/strings.h"
|
||||
#include "decoder/param.h"
|
||||
#include "frontend/wave-reader.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/u2_nnet.h"
|
||||
#include "recognizer/recognizer.h"
|
||||
|
||||
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
|
||||
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
||||
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
|
||||
DEFINE_int32(sample_rate, 16000, "sample rate");
|
||||
DEFINE_int32(njob, 3, "njob");
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
void SplitUtt(string wavlist_file,
|
||||
vector<vector<string>>* uttlists,
|
||||
vector<vector<string>>* wavlists,
|
||||
int njob) {
|
||||
vector<string> wavlist;
|
||||
wavlists->resize(njob);
|
||||
uttlists->resize(njob);
|
||||
ppspeech::ReadFileToVector(wavlist_file, &wavlist);
|
||||
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
|
||||
string utt_str = wavlist[idx];
|
||||
vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
|
||||
LOG(INFO) << utt_wav[0];
|
||||
CHECK_EQ(utt_wav.size(), size_t(2));
|
||||
uttlists->at(idx % njob).push_back(utt_wav[0]);
|
||||
wavlists->at(idx % njob).push_back(utt_wav[1]);
|
||||
}
|
||||
}
|
||||
|
||||
void recognizer_func(std::vector<string> wavlist,
|
||||
std::vector<string> uttlist,
|
||||
std::vector<string>* results) {
|
||||
int32 num_done = 0, num_err = 0;
|
||||
double tot_wav_duration = 0.0;
|
||||
double tot_attention_rescore_time = 0.0;
|
||||
double tot_decode_time = 0.0;
|
||||
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
|
||||
if (wavlist.empty()) return;
|
||||
|
||||
results->reserve(wavlist.size());
|
||||
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
|
||||
std::string utt = uttlist[idx];
|
||||
std::string wav_file = wavlist[idx];
|
||||
std::ifstream infile;
|
||||
infile.open(wav_file, std::ifstream::in);
|
||||
kaldi::WaveData wave_data;
|
||||
wave_data.Read(infile);
|
||||
int32 recog_id = -1;
|
||||
while (recog_id == -1) {
|
||||
recog_id = GetRecognizerInstanceId();
|
||||
}
|
||||
InitDecoder(recog_id);
|
||||
LOG(INFO) << "utt: " << utt;
|
||||
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
|
||||
double dur = wave_data.Duration();
|
||||
tot_wav_duration += dur;
|
||||
|
||||
int32 this_channel = 0;
|
||||
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
|
||||
this_channel);
|
||||
int tot_samples = waveform.Dim();
|
||||
LOG(INFO) << "wav len (sample): " << tot_samples;
|
||||
|
||||
int sample_offset = 0;
|
||||
kaldi::Timer local_timer;
|
||||
|
||||
while (sample_offset < tot_samples) {
|
||||
int cur_chunk_size =
|
||||
std::min(chunk_sample_size, tot_samples - sample_offset);
|
||||
|
||||
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
|
||||
for (int i = 0; i < cur_chunk_size; ++i) {
|
||||
wav_chunk[i] = waveform(sample_offset + i);
|
||||
}
|
||||
|
||||
AcceptData(wav_chunk, recog_id);
|
||||
// no overlap
|
||||
sample_offset += cur_chunk_size;
|
||||
}
|
||||
SetInputFinished(recog_id);
|
||||
CHECK(sample_offset == tot_samples);
|
||||
std::string result = GetFinalResult(recog_id);
|
||||
if (result.empty()) {
|
||||
// the TokenWriter can not write empty string.
|
||||
++num_err;
|
||||
LOG(INFO) << " the result of " << utt << " is empty";
|
||||
result = " ";
|
||||
}
|
||||
|
||||
tot_decode_time += local_timer.Elapsed();
|
||||
LOG(INFO) << utt << " " << result;
|
||||
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
|
||||
<< " cost: " << local_timer.Elapsed();
|
||||
|
||||
results->push_back(result);
|
||||
++num_done;
|
||||
}
|
||||
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
|
||||
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
|
||||
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
|
||||
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::SetUsageMessage("Usage:");
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
google::InstallFailureSignalHandler();
|
||||
FLAGS_logtostderr = 1;
|
||||
|
||||
int sample_rate = FLAGS_sample_rate;
|
||||
float streaming_chunk = FLAGS_streaming_chunk;
|
||||
int chunk_sample_size = streaming_chunk * sample_rate;
|
||||
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
|
||||
int njob = FLAGS_njob;
|
||||
LOG(INFO) << "sr: " << sample_rate;
|
||||
LOG(INFO) << "chunk size (s): " << streaming_chunk;
|
||||
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
|
||||
|
||||
InitRecognizer(FLAGS_model_path, FLAGS_word_symbol_table, FLAGS_graph_path, njob);
|
||||
ThreadPool threadpool(njob);
|
||||
vector<vector<string>> wavlist;
|
||||
vector<vector<string>> uttlist;
|
||||
vector<vector<string>> resultlist(njob);
|
||||
vector<std::future<void>> futurelist;
|
||||
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
|
||||
for (size_t i = 0; i < njob; ++i) {
|
||||
std::future<void> f = threadpool.enqueue(recognizer_func,
|
||||
wavlist[i],
|
||||
uttlist[i],
|
||||
&resultlist[i]);
|
||||
futurelist.push_back(std::move(f));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < njob; ++i) {
|
||||
futurelist[i].get();
|
||||
}
|
||||
|
||||
for (size_t idx = 0; idx < njob; ++idx) {
|
||||
for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
|
||||
string utt = uttlist[idx][utt_idx];
|
||||
string result = resultlist[idx][utt_idx];
|
||||
result_writer.Write(utt, result);
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
@ -0,0 +1,70 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "recognizer/recognizer_controller.h"
#include "nnet/u2_nnet.h"

namespace ppspeech {

RecognizerController::RecognizerController(int num_worker, RecognizerResource resource) {
    recognizer_workers.resize(num_worker);
    for (int i = 0; i < num_worker; ++i) {
        recognizer_workers[i].reset(new ppspeech::RecognizerControllerImpl(resource));
        waiting_workers.push(i);
    }
}

int RecognizerController::GetRecognizerInstanceId() {
    // Check and pop under one lock so two callers cannot both claim the last
    // free worker.
    std::unique_lock<std::mutex> lock(mutex_);
    if (waiting_workers.empty()) {
        return -1;
    }
    int idx = waiting_workers.front();
    waiting_workers.pop();
    return idx;
}

RecognizerController::~RecognizerController() {
    for (size_t i = 0; i < recognizer_workers.size(); ++i) {
        recognizer_workers[i]->WaitFinished();
    }
}

void RecognizerController::InitDecoder(int idx) {
    recognizer_workers[idx]->InitDecoder();
}

std::string RecognizerController::GetFinalResult(int idx) {
    recognizer_workers[idx]->WaitDecoderFinished();
    recognizer_workers[idx]->AttentionRescoring();
    std::string result = recognizer_workers[idx]->GetFinalResult();
    {
        std::unique_lock<std::mutex> lock(mutex_);
        waiting_workers.push(idx);
    }
    return result;
}

void RecognizerController::Accept(std::vector<float> data, int idx) {
    recognizer_workers[idx]->Accept(data);
}

void RecognizerController::SetInputFinished(int idx) {
    recognizer_workers[idx]->SetInputFinished();
}

}  // namespace ppspeech
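// [Editor's note] GetRecognizerInstanceId() returns -1 when every worker is
// busy; the batch tools in this patch simply poll it:
//
//   int recog_id = -1;
//   while (recog_id == -1) {
//       recog_id = recognizer_controller->GetRecognizerInstanceId();
//   }
//
// A condition variable on waiting_workers would avoid the busy wait, but the
// polling loop above is what the current callers rely on.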
@ -0,0 +1,42 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <vector>

#include "recognizer/recognizer_controller_impl.h"

namespace ppspeech {

class RecognizerController {
  public:
    explicit RecognizerController(int num_worker, RecognizerResource resource);
    ~RecognizerController();
    int GetRecognizerInstanceId();
    void InitDecoder(int idx);
    void Accept(std::vector<float> data, int idx);
    void SetInputFinished(int idx);
    std::string GetFinalResult(int idx);

  private:
    std::queue<int> waiting_workers;
    std::mutex mutex_;
    std::vector<std::unique_ptr<ppspeech::RecognizerControllerImpl>> recognizer_workers;

    DISALLOW_COPY_AND_ASSIGN(RecognizerController);
};

}  // namespace ppspeech
@ -0,0 +1,89 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "decoder/common.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "nnet/u2_nnet.h"
#include "nnet/nnet_producer.h"
#ifdef USE_ONNX
#include "nnet/u2_onnx_nnet.h"
#endif
#include "nnet/decodable.h"
#include "recognizer/recognizer_resource.h"

#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

namespace ppspeech {

class RecognizerControllerImpl {
  public:
    explicit RecognizerControllerImpl(const RecognizerResource& resource);
    ~RecognizerControllerImpl();
    void Accept(std::vector<float> data);
    void InitDecoder();
    void SetInputFinished();
    std::string GetFinalResult();
    std::string GetPartialResult();
    void Rescoring();
    void Reset();
    void WaitDecoderFinished();
    void WaitFinished();
    void AttentionRescoring();
    bool DecodedSomething() const {
        return !result_.empty() && !result_[0].sentence.empty();
    }
    int FrameShiftInMs() const {
        return 1;  // TODO: return the real frontend frame shift
    }

  private:
    static void RunNnetEvaluation(RecognizerControllerImpl* me);
    void RunNnetEvaluationInternal();
    static void RunDecoder(RecognizerControllerImpl* me);
    void RunDecoderInternal();
    void UpdateResult(bool finish = false);

    std::shared_ptr<Decodable> decodable_;
    std::unique_ptr<DecoderBase> decoder_;
    std::shared_ptr<NnetProducer> nnet_producer_;

    // e2e unit symbol table
    std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
    std::vector<DecodeResult> result_;

    RecognizerResource opts_;
    bool abort_ = false;
    // global decoded frame offset
    int global_frame_offset_;
    // cur decoded frame num
    int num_frames_;
    // timestamp gap between words in a sentence
    const int time_stamp_gap_ = 100;
    bool input_finished_;

    std::mutex nnet_mutex_;
    std::mutex decoder_mutex_;
    std::condition_variable nnet_condition_;
    std::condition_variable decoder_condition_;
    std::thread nnet_thread_;
    std::thread decoder_thread_;

    DISALLOW_COPY_AND_ASSIGN(RecognizerControllerImpl);
};

}  // namespace ppspeech
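// [Editor's note] The member layout suggests one nnet-evaluation thread and
// one decoder thread per instance, fed through NnetProducer. A plausible
// wiring (the defining recognizer_controller_impl.cc is not part of this
// hunk, so this is a sketch, not the author's code):
//
//   nnet_thread_    = std::thread(RunNnetEvaluation, this);  // fills decodable_
//   decoder_thread_ = std::thread(RunDecoder, this);         // drains it
//
// WaitDecoderFinished()/WaitFinished() would then signal or join these
// threads via nnet_condition_ / decoder_condition_ before results are read.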
@ -0,0 +1,66 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "recognizer/recognizer_instance.h"


namespace ppspeech {

RecognizerInstance& RecognizerInstance::GetInstance() {
    static RecognizerInstance instance;
    return instance;
}

bool RecognizerInstance::Init(const std::string& model_file,
                              const std::string& word_symbol_table_file,
                              const std::string& fst_file,
                              int num_instance) {
    RecognizerResource resource = RecognizerResource::InitFromFlags();
    resource.model_opts.model_path = model_file;
    //resource.vocab_path = word_symbol_table_file;
    if (!fst_file.empty()) {
        resource.decoder_opts.tlg_decoder_opts.fst_path = fst_file;
        resource.decoder_opts.tlg_decoder_opts.word_symbol_table =
            word_symbol_table_file;
    } else {
        resource.decoder_opts.ctc_prefix_search_opts.word_symbol_table =
            word_symbol_table_file;
    }
    recognizer_controller_ = std::make_unique<RecognizerController>(num_instance, resource);
    return true;
}

void RecognizerInstance::InitDecoder(int idx) {
    recognizer_controller_->InitDecoder(idx);
    return;
}

int RecognizerInstance::GetRecognizerInstanceId() {
    return recognizer_controller_->GetRecognizerInstanceId();
}

void RecognizerInstance::Accept(const std::vector<float>& waves, int idx) const {
    recognizer_controller_->Accept(waves, idx);
    return;
}

void RecognizerInstance::SetInputFinished(int idx) const {
    recognizer_controller_->SetInputFinished(idx);
    return;
}

std::string RecognizerInstance::GetResult(int idx) const {
    return recognizer_controller_->GetFinalResult(idx);
}

}  // namespace ppspeech
@ -0,0 +1,42 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "base/common.h"
#include "recognizer/recognizer_controller.h"

namespace ppspeech {

class RecognizerInstance {
  public:
    static RecognizerInstance& GetInstance();
    RecognizerInstance() {}
    ~RecognizerInstance() {}
    bool Init(const std::string& model_file,
              const std::string& word_symbol_table_file,
              const std::string& fst_file,
              int num_instance);
    int GetRecognizerInstanceId();
    void InitDecoder(int idx);
    void Accept(const std::vector<float>& waves, int idx) const;
    void SetInputFinished(int idx) const;
    std::string GetResult(int idx) const;

  private:
    std::unique_ptr<RecognizerController> recognizer_controller_;
};

}  // namespace ppspeech
@ -0,0 +1 @@
#add_subdirectory(websocket)
@ -0,0 +1,3 @@
# add_definitions("-DUSE_PADDLE_INFERENCE_BACKEND")
add_definitions("-DUSE_ORT_BACKEND")
add_subdirectory(nnet)
@ -0,0 +1,11 @@
set(srcs
    panns_nnet.cc
    panns_interface.cc
)

add_library(cls SHARED ${srcs})
target_link_libraries(cls PRIVATE ${FASTDEPLOY_LIBS} kaldi-matrix kaldi-base frontend utils)

set(bin_name panns_nnet_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_link_libraries(${bin_name} gflags glog cls)
@ -0,0 +1,79 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "audio_classification/nnet/panns_interface.h"
|
||||
|
||||
#include "audio_classification/nnet/panns_nnet.h"
|
||||
#include "common/base/config.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
void* ClsCreateInstance(const char* conf_path) {
|
||||
Config conf(conf_path);
|
||||
// cls init
|
||||
ppspeech::ClsNnetConf cls_nnet_conf;
|
||||
cls_nnet_conf.wav_normal_ = conf.Read("wav_normal", true);
|
||||
cls_nnet_conf.wav_normal_type_ =
|
||||
conf.Read("wav_normal_type", std::string("linear"));
|
||||
cls_nnet_conf.wav_norm_mul_factor_ = conf.Read("wav_norm_mul_factor", 1.0);
|
||||
cls_nnet_conf.model_file_path_ = conf.Read("model_path", std::string(""));
|
||||
cls_nnet_conf.param_file_path_ = conf.Read("param_path", std::string(""));
|
||||
cls_nnet_conf.dict_file_path_ = conf.Read("dict_path", std::string(""));
|
||||
cls_nnet_conf.num_cpu_thread_ = conf.Read("num_cpu_thread", 12);
|
||||
cls_nnet_conf.samp_freq = conf.Read("samp_freq", 32000);
|
||||
cls_nnet_conf.frame_length_ms = conf.Read("frame_length_ms", 32);
|
||||
cls_nnet_conf.frame_shift_ms = conf.Read("frame_shift_ms", 10);
|
||||
cls_nnet_conf.num_bins = conf.Read("num_bins", 64);
|
||||
cls_nnet_conf.low_freq = conf.Read("low_freq", 50);
|
||||
cls_nnet_conf.high_freq = conf.Read("high_freq", 14000);
|
||||
cls_nnet_conf.dither = conf.Read("dither", 0.0);
|
||||
|
||||
ppspeech::ClsNnet* cls_model = new ppspeech::ClsNnet();
|
||||
int ret = cls_model->Init(cls_nnet_conf);
|
||||
return static_cast<void*>(cls_model);
|
||||
}
|
||||
|
||||
int ClsDestroyInstance(void* instance) {
|
||||
ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
|
||||
if (cls_model != NULL) {
|
||||
delete cls_model;
|
||||
cls_model = NULL;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ClsFeedForward(void* instance,
|
||||
const char* wav_path,
|
||||
int topk,
|
||||
char* result,
|
||||
int result_max_len) {
|
||||
ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
|
||||
if (cls_model == NULL) {
|
||||
printf("instance is null\n");
|
||||
return -1;
|
||||
}
|
||||
int ret = cls_model->Forward(wav_path, topk, result, result_max_len);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ClsReset(void* instance) {
|
||||
ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
|
||||
if (cls_model == NULL) {
|
||||
printf("instance is null\n");
|
||||
return -1;
|
||||
}
|
||||
cls_model->Reset();
|
||||
return 0;
|
||||
}
|
||||
} // namespace ppspeech
|
@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
void* ClsCreateInstance(const char* conf_path);
|
||||
int ClsDestroyInstance(void* instance);
|
||||
int ClsFeedForward(void* instance,
|
||||
const char* wav_path,
|
||||
int topk,
|
||||
char* result,
|
||||
int result_max_len);
|
||||
int ClsReset(void* instance);
|
||||
} // namespace ppspeech
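// [Editor's note] Typical call sequence for this C-style interface, following
// panns_nnet_main.cc further down in this patch (the config path, wav path
// and buffer size are illustrative):
//
//   void* cls = ppspeech::ClsCreateInstance("cls.conf");
//   char result[1024] = {0};
//   ppspeech::ClsFeedForward(cls, "test.wav", /*topk=*/3, result, sizeof(result));
//   ppspeech::ClsReset(cls);          // clear per-utterance state before reuse
//   ppspeech::ClsDestroyInstance(cls);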
|
@ -0,0 +1,227 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "audio_classification/nnet/panns_nnet.h"
|
||||
#ifdef WITH_PROFILING
|
||||
#include "kaldi/base/timer.h"
|
||||
#endif
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
ClsNnet::ClsNnet() {
|
||||
// wav_reader_ = NULL;
|
||||
runtime_ = NULL;
|
||||
}
|
||||
|
||||
void ClsNnet::Reset() {
|
||||
// wav_reader_->Clear();
|
||||
ss_.str("");
|
||||
}
|
||||
|
||||
int ClsNnet::Init(const ClsNnetConf& conf) {
|
||||
conf_ = conf;
|
||||
// init fbank opts
|
||||
fbank_opts_.frame_opts.samp_freq = conf.samp_freq;
|
||||
fbank_opts_.frame_opts.frame_length_ms = conf.frame_length_ms;
|
||||
fbank_opts_.frame_opts.frame_shift_ms = conf.frame_shift_ms;
|
||||
fbank_opts_.mel_opts.num_bins = conf.num_bins;
|
||||
fbank_opts_.mel_opts.low_freq = conf.low_freq;
|
||||
fbank_opts_.mel_opts.high_freq = conf.high_freq;
|
||||
fbank_opts_.frame_opts.dither = conf.dither;
|
||||
fbank_opts_.use_log_fbank = false;
|
||||
|
||||
// init dict
|
||||
if (conf.dict_file_path_ != "") {
|
||||
ReadFileToVector(conf.dict_file_path_, &dict_);
|
||||
}
|
||||
|
||||
// init model
|
||||
fastdeploy::RuntimeOption runtime_option;
|
||||
|
||||
#ifdef USE_PADDLE_INFERENCE_BACKEND
|
||||
runtime_option.SetModelPath(conf.model_file_path_,
|
||||
conf.param_file_path_,
|
||||
fastdeploy::ModelFormat::PADDLE);
|
||||
runtime_option.UsePaddleInferBackend();
|
||||
#elif defined(USE_ORT_BACKEND)
|
||||
runtime_option.SetModelPath(
|
||||
conf.model_file_path_, "", fastdeploy::ModelFormat::ONNX); // onnx
|
||||
runtime_option.UseOrtBackend(); // onnx
|
||||
#elif defined(USE_PADDLE_LITE_BACKEND)
|
||||
runtime_option.SetModelPath(conf.model_file_path_,
|
||||
conf.param_file_path_,
|
||||
fastdeploy::ModelFormat::PADDLE);
|
||||
runtime_option.UseLiteBackend();
|
||||
#endif
|
||||
|
||||
runtime_option.SetCpuThreadNum(conf.num_cpu_thread_);
|
||||
// runtime_option.DeletePaddleBackendPass("simplify_with_basic_ops_pass");
|
||||
runtime_ = std::unique_ptr<fastdeploy::Runtime>(new fastdeploy::Runtime());
|
||||
if (!runtime_->Init(runtime_option)) {
|
||||
std::cerr << "--- Init FastDeploy Runtime Failed! "
|
||||
<< "\n--- Model: " << conf.model_file_path_ << std::endl;
|
||||
return -1;
|
||||
} else {
|
||||
std::cout << "--- Init FastDeploy Runtime Done! "
|
||||
<< "\n--- Model: " << conf.model_file_path_ << std::endl;
|
||||
}
|
||||
|
||||
Reset();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ClsNnet::Forward(const char* wav_path,
|
||||
int topk,
|
||||
char* result,
|
||||
int result_max_len) {
|
||||
#ifdef WITH_PROFILING
|
||||
kaldi::Timer timer;
|
||||
timer.Reset();
|
||||
#endif
|
||||
// read wav
|
||||
std::ifstream infile(wav_path, std::ifstream::in);
|
||||
kaldi::WaveData wave_data;
|
||||
wave_data.Read(infile);
|
||||
int32 this_channel = 0;
|
||||
kaldi::Matrix<float> wavform_kaldi = wave_data.Data();
|
||||
// only get channel 0
|
||||
int wavform_len = wavform_kaldi.NumCols();
|
||||
std::vector<float> wavform(wavform_kaldi.Data(),
|
||||
wavform_kaldi.Data() + wavform_len);
|
||||
WaveformFloatNormal(&wavform);
|
||||
WaveformNormal(&wavform,
|
||||
conf_.wav_normal_,
|
||||
conf_.wav_normal_type_,
|
||||
conf_.wav_norm_mul_factor_);
|
||||
#ifdef PPS_DEBUG
|
||||
{
|
||||
std::ofstream fp("cls.wavform", std::ios::out);
|
||||
for (int i = 0; i < wavform.size(); ++i) {
|
||||
fp << std::setprecision(18) << wavform[i] << " ";
|
||||
}
|
||||
fp << "\n";
|
||||
}
|
||||
#endif
|
||||
#ifdef WITH_PROFILING
|
||||
printf("wav read consume: %fs\n", timer.Elapsed());
|
||||
#endif
|
||||
|
||||
#ifdef WITH_PROFILING
|
||||
timer.Reset();
|
||||
#endif
|
||||
|
||||
std::vector<float> feats;
|
||||
std::unique_ptr<ppspeech::FrontendInterface> data_source(
|
||||
new ppspeech::DataCache());
|
||||
ppspeech::Fbank fbank(fbank_opts_, std::move(data_source));
|
||||
fbank.Accept(wavform);
|
||||
fbank.SetFinished();
|
||||
fbank.Read(&feats);
|
||||
|
||||
int feat_dim = fbank_opts_.mel_opts.num_bins;
|
||||
int num_frames = feats.size() / feat_dim;
|
||||
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
for (int j = 0; j < feat_dim; ++j) {
|
||||
feats[i * feat_dim + j] = PowerTodb(feats[i * feat_dim + j]);
|
||||
}
|
||||
}
|
||||
#ifdef PPS_DEBUG
|
||||
{
|
||||
std::ofstream fp("cls.feat", std::ios::out);
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
for (int j = 0; j < feat_dim; ++j) {
|
||||
fp << std::setprecision(18) << feats[i * feat_dim + j] << " ";
|
||||
}
|
||||
fp << "\n";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef WITH_PROFILING
|
||||
printf("extract fbank consume: %fs\n", timer.Elapsed());
|
||||
#endif
|
||||
|
||||
// infer
|
||||
std::vector<float> model_out;
|
||||
#ifdef WITH_PROFILING
|
||||
timer.Reset();
|
||||
#endif
|
||||
ModelForward(feats.data(), num_frames, feat_dim, &model_out);
|
||||
#ifdef WITH_PROFILING
|
||||
printf("fast deploy infer consume: %fs\n", timer.Elapsed());
|
||||
#endif
|
||||
#ifdef PPS_DEBUG
|
||||
{
|
||||
std::ofstream fp("cls.logits", std::ios::out);
|
||||
for (int i = 0; i < model_out.size(); ++i) {
|
||||
fp << std::setprecision(18) << model_out[i] << "\n";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// construct result str
|
||||
ss_ << "{";
|
||||
GetTopkResult(topk, model_out);
|
||||
ss_ << "}";
|
||||
|
||||
if (result_max_len <= ss_.str().size()) {
|
||||
printf("result_max_len is shorter than the result length; output will be truncated\n");
|
||||
}
|
||||
snprintf(result, result_max_len, "%s", ss_.str().c_str());
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ClsNnet::ModelForward(float* features,
|
||||
const int num_frames,
|
||||
const int feat_dim,
|
||||
std::vector<float>* model_out) {
|
||||
// init input tensor shape
|
||||
fastdeploy::TensorInfo info = runtime_->GetInputInfo(0);
|
||||
info.shape = {1, num_frames, feat_dim};
|
||||
|
||||
std::vector<fastdeploy::FDTensor> input_tensors(1);
|
||||
std::vector<fastdeploy::FDTensor> output_tensors(1);
|
||||
|
||||
input_tensors[0].SetExternalData({1, num_frames, feat_dim},
|
||||
fastdeploy::FDDataType::FP32,
|
||||
static_cast<void*>(features));
|
||||
|
||||
// get input name
|
||||
input_tensors[0].name = info.name;
|
||||
|
||||
runtime_->Infer(input_tensors, &output_tensors);
|
||||
|
||||
// output_tensors[0].PrintInfo();
|
||||
std::vector<int64_t> output_shape = output_tensors[0].Shape();
|
||||
model_out->resize(output_shape[0] * output_shape[1]);
|
||||
memcpy(static_cast<void*>(model_out->data()),
|
||||
output_tensors[0].Data(),
|
||||
output_shape[0] * output_shape[1] * sizeof(float));
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ClsNnet::GetTopkResult(int k, const std::vector<float>& model_out) {
|
||||
std::vector<float> values;
|
||||
std::vector<int> indics;
|
||||
TopK(model_out, k, &values, &indics);
|
||||
for (int i = 0; i < k; ++i) {
|
||||
if (i != 0) {
|
||||
ss_ << ",";
|
||||
}
|
||||
ss_ << "\"" << dict_[indics[i]] << "\":\"" << values[i] << "\"";
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace ppspeech
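// [Editor's note] GetTopkResult() above serializes the top-k labels into the
// caller's buffer as a JSON-like map, e.g. (values made up for illustration):
//   {"Speech":"0.9531","Music":"0.0212","Animal":"0.0087"}
// If result_max_len is too small, snprintf() truncates the string rather than
// overflowing the buffer.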
|
@ -0,0 +1,74 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common/frontend/data_cache.h"
|
||||
#include "common/frontend/fbank.h"
|
||||
#include "common/frontend/feature-fbank.h"
|
||||
#include "common/frontend/frontend_itf.h"
|
||||
#include "common/frontend/wave-reader.h"
|
||||
#include "common/utils/audio_process.h"
|
||||
#include "common/utils/file_utils.h"
|
||||
#include "fastdeploy/runtime.h"
|
||||
#include "kaldi/util/kaldi-io.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
|
||||
namespace ppspeech {
|
||||
struct ClsNnetConf {
|
||||
// wav
|
||||
bool wav_normal_;
|
||||
std::string wav_normal_type_;
|
||||
float wav_norm_mul_factor_;
|
||||
// model
|
||||
std::string model_file_path_;
|
||||
std::string param_file_path_;
|
||||
std::string dict_file_path_;
|
||||
int num_cpu_thread_;
|
||||
// fbank
|
||||
float samp_freq;
|
||||
float frame_length_ms;
|
||||
float frame_shift_ms;
|
||||
int num_bins;
|
||||
float low_freq;
|
||||
float high_freq;
|
||||
float dither;
|
||||
};
|
||||
|
||||
class ClsNnet {
|
||||
public:
|
||||
ClsNnet();
|
||||
int Init(const ClsNnetConf& conf);
|
||||
int Forward(const char* wav_path,
|
||||
int topk,
|
||||
char* result,
|
||||
int result_max_len);
|
||||
void Reset();
|
||||
|
||||
private:
|
||||
int ModelForward(float* features,
|
||||
const int num_frames,
|
||||
const int feat_dim,
|
||||
std::vector<float>* model_out);
|
||||
int ModelForwardStream(std::vector<float>* feats);
|
||||
int GetTopkResult(int k, const std::vector<float>& model_out);
|
||||
|
||||
ClsNnetConf conf_;
|
||||
knf::FbankOptions fbank_opts_;
|
||||
std::unique_ptr<fastdeploy::Runtime> runtime_;
|
||||
std::vector<std::string> dict_;
|
||||
std::stringstream ss_;
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,51 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
#include "gflags/gflags.h"
|
||||
#include "glog/logging.h"
|
||||
#include "audio_classification/nnet/panns_interface.h"
|
||||
|
||||
DEFINE_string(conf_path, "", "config path");
|
||||
DEFINE_string(scp_path, "", "wav scp path");
|
||||
DEFINE_string(topk, "", "print topk results");
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::SetUsageMessage("Usage:");
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
google::InstallFailureSignalHandler();
|
||||
FLAGS_logtostderr = 1;
|
||||
CHECK_GT(FLAGS_conf_path.size(), 0);
|
||||
CHECK_GT(FLAGS_scp_path.size(), 0);
|
||||
CHECK_GT(FLAGS_topk.size(), 0);
|
||||
void* instance = ppspeech::ClsCreateInstance(FLAGS_conf_path.c_str());
|
||||
int ret = 0;
|
||||
// read wav
|
||||
std::ifstream ifs(FLAGS_scp_path);
|
||||
std::string line = "";
|
||||
int topk = std::atoi(FLAGS_topk.c_str());
|
||||
while (getline(ifs, line)) {
|
||||
// read wav
|
||||
char result[1024] = {0};
|
||||
ret = ppspeech::ClsFeedForward(
|
||||
instance, line.c_str(), topk, result, 1024);
|
||||
printf("%s %s\n", line.c_str(), result);
|
||||
ret = ppspeech::ClsReset(instance);
|
||||
}
|
||||
ret = ppspeech::ClsDestroyInstance(instance);
|
||||
return 0;
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
if(ANDROID)
|
||||
else() #Unix
|
||||
add_subdirectory(glog)
|
||||
endif()
|
@ -1,8 +1,8 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_executable(glog_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_main.cc)
|
||||
target_link_libraries(glog_main glog)
|
||||
target_link_libraries(glog_main extern_glog)
|
||||
|
||||
|
||||
add_executable(glog_logtostderr_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_main.cc)
|
||||
target_link_libraries(glog_logtostderr_main glog)
|
||||
target_link_libraries(glog_logtostderr_main extern_glog)
|
@ -0,0 +1,19 @@
|
||||
include_directories(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../
|
||||
)
|
||||
add_subdirectory(base)
|
||||
add_subdirectory(utils)
|
||||
add_subdirectory(matrix)
|
||||
|
||||
include_directories(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/frontend
|
||||
)
|
||||
add_subdirectory(frontend)
|
||||
|
||||
add_library(common INTERFACE)
|
||||
target_link_libraries(common INTERFACE base utils kaldi-matrix frontend)
|
||||
install(TARGETS base DESTINATION lib)
|
||||
install(TARGETS utils DESTINATION lib)
|
||||
install(TARGETS kaldi-matrix DESTINATION lib)
|
||||
install(TARGETS frontend DESTINATION lib)
|
@ -0,0 +1,43 @@


if(WITH_ASR)
    add_compile_options(-DWITH_ASR)
    set(PPS_FLAGS_LIB "fst/flags.h")
else()
    set(PPS_FLAGS_LIB "gflags/gflags.h")
endif()

if(ANDROID)
    set(PPS_GLOG_LIB "base/log_impl.h")
else() #UNIX
    if(WITH_ASR)
        set(PPS_GLOG_LIB "fst/log.h")
    else()
        set(PPS_GLOG_LIB "glog/logging.h")
    endif()
endif()

configure_file(
    ${CMAKE_CURRENT_SOURCE_DIR}/flags.h.in
    ${CMAKE_CURRENT_SOURCE_DIR}/flags.h @ONLY
)
message(STATUS "Generated ${CMAKE_CURRENT_SOURCE_DIR}/flags.h")

configure_file(
    ${CMAKE_CURRENT_SOURCE_DIR}/log.h.in
    ${CMAKE_CURRENT_SOURCE_DIR}/log.h @ONLY
)
message(STATUS "Generated ${CMAKE_CURRENT_SOURCE_DIR}/log.h")


if(ANDROID)
    set(csrc
        log_impl.cc
        glog_utils.cc
    )
    add_library(base ${csrc})
    target_link_libraries(base gflags)
else() # UNIX
    set(csrc)
    add_library(base INTERFACE)
endif()
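The flags.h.in and log.h.in templates consumed by configure_file() above are not part of this diff. A minimal sketch of what such templates would look like, assuming they do nothing more than forward to the backend header selected via PPS_FLAGS_LIB and PPS_GLOG_LIB, is:

// flags.h.in -- hypothetical template contents (not in this diff).
// configure_file() substitutes @PPS_FLAGS_LIB@ with "fst/flags.h" or
// "gflags/gflags.h" according to the options above.
#pragma once
#include "@PPS_FLAGS_LIB@"

// log.h.in -- same idea for the logging backend: @PPS_GLOG_LIB@ becomes
// "base/log_impl.h", "fst/log.h", or "glog/logging.h".
#pragma once
#include "@PPS_GLOG_LIB@"

The generated flags.h and log.h then give the rest of the codebase a single include point regardless of which flag and logging backend was chosen.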
@ -0,0 +1,343 @@
// Copyright (c) code is from
// https://blog.csdn.net/huixingshao/article/details/45969887.

#pragma once

#include <cctype>   // toupper
#include <cstdio>   // printf
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>

using namespace std;

#ifdef _MSC_VER
#pragma region ParseIniFile
#endif

/*
 * \brief Generic configuration Class
 *
 */
class Config {
    // Data
  protected:
    std::string m_Delimiter;  //!< separator between key and value
    std::string m_Comment;    //!< separator between value and comments
    std::map<std::string, std::string>
        m_Contents;  //!< extracted keys and values

    typedef std::map<std::string, std::string>::iterator mapi;
    typedef std::map<std::string, std::string>::const_iterator mapci;

    // Methods
  public:
    Config(std::string filename,
           std::string delimiter = "=",
           std::string comment = "#");
    Config();
    template <class T>
    T Read(const std::string& in_key) const;  //!< Search for key and read
                                              //!< value, or take the optional
                                              //!< default; call as Read<T>
    template <class T>
    T Read(const std::string& in_key, const T& in_value) const;
    template <class T>
    bool ReadInto(T* out_var, const std::string& in_key) const;
    template <class T>
    bool ReadInto(T* out_var,
                  const std::string& in_key,
                  const T& in_value) const;
    bool FileExist(std::string filename);
    void ReadFile(std::string filename,
                  std::string delimiter = "=",
                  std::string comment = "#");

    // Check whether key exists in configuration
    bool KeyExists(const std::string& in_key) const;

    // Modify keys and values
    template <class T>
    void Add(const std::string& in_key, const T& in_value);
    void Remove(const std::string& in_key);

    // Check or change configuration syntax
    std::string GetDelimiter() const { return m_Delimiter; }
    std::string GetComment() const { return m_Comment; }
    std::string SetDelimiter(const std::string& in_s) {
        std::string old = m_Delimiter;
        m_Delimiter = in_s;
        return old;
    }
    std::string SetComment(const std::string& in_s) {
        std::string old = m_Comment;
        m_Comment = in_s;
        return old;
    }

    // Write or read configuration
    friend std::ostream& operator<<(std::ostream& os, const Config& cf);
    friend std::istream& operator>>(std::istream& is, Config& cf);

  protected:
    template <class T>
    static std::string T_as_string(const T& t);
    template <class T>
    static T string_as_T(const std::string& s);
    static void Trim(std::string* inout_s);


    // Exception types
  public:
    struct File_not_found {
        std::string filename;
        explicit File_not_found(const std::string& filename_ = std::string())
            : filename(filename_) {}
    };
    struct Key_not_found {  // thrown only by the T Read(key) variant
        std::string key;
        explicit Key_not_found(const std::string& key_ = std::string())
            : key(key_) {}
    };
};

/* static */
template <class T>
std::string Config::T_as_string(const T& t) {
    // Convert from a T to a string
    // Type T must support << operator
    std::ostringstream ost;
    ost << t;
    return ost.str();
}


/* static */
template <class T>
T Config::string_as_T(const std::string& s) {
    // Convert from a string to a T
    // Type T must support >> operator
    T t;
    std::istringstream ist(s);
    ist >> t;
    return t;
}


/* static */
template <>
inline std::string Config::string_as_T<std::string>(const std::string& s) {
    // Convert from a string to a string
    // In other words, do nothing
    return s;
}


/* static */
template <>
inline bool Config::string_as_T<bool>(const std::string& s) {
    // Convert from a string to a bool
    // Interpret "false", "F", "no", "n", "0" as false
    // Interpret "true", "T", "yes", "y", "1", "-1", or anything else as true
    bool b = true;
    std::string sup = s;
    for (std::string::iterator p = sup.begin(); p != sup.end(); ++p)
        *p = toupper(*p);  // make string all caps
    if (sup == std::string("FALSE") || sup == std::string("F") ||
        sup == std::string("NO") || sup == std::string("N") ||
        sup == std::string("0") || sup == std::string("NONE"))
        b = false;
    return b;
}


template <class T>
T Config::Read(const std::string& key) const {
    // Read the value corresponding to key
    mapci p = m_Contents.find(key);
    if (p == m_Contents.end()) throw Key_not_found(key);
    return string_as_T<T>(p->second);
}


template <class T>
T Config::Read(const std::string& key, const T& value) const {
    // Return the value corresponding to key or given default value
    // if key is not found
    mapci p = m_Contents.find(key);
    if (p == m_Contents.end()) {
        printf("%s = %s(default)\n", key.c_str(), T_as_string(value).c_str());
        return value;
    } else {
        printf("%s = %s\n", key.c_str(), T_as_string(p->second).c_str());
        return string_as_T<T>(p->second);
    }
}


template <class T>
bool Config::ReadInto(T* var, const std::string& key) const {
    // Get the value corresponding to key and store in var
    // Return true if key is found
    // Otherwise leave var untouched
    mapci p = m_Contents.find(key);
    bool found = (p != m_Contents.end());
    if (found) *var = string_as_T<T>(p->second);
    return found;
}


template <class T>
bool Config::ReadInto(T* var, const std::string& key, const T& value) const {
    // Get the value corresponding to key and store in var
    // Return true if key is found
    // Otherwise set var to given default
    mapci p = m_Contents.find(key);
    bool found = (p != m_Contents.end());
    if (found)
        *var = string_as_T<T>(p->second);
    else
        *var = value;  // was `var = value;`, which only reassigned the pointer
    return found;
}


template <class T>
void Config::Add(const std::string& in_key, const T& value) {
    // Add a key with given value
    std::string v = T_as_string(value);
    std::string key = in_key;
    Trim(&key);
    Trim(&v);
    m_Contents[key] = v;
    return;
}

Config::Config(string filename, string delimiter, string comment)
    : m_Delimiter(delimiter), m_Comment(comment) {
    // Construct a Config, getting keys and values from given file

    std::ifstream in(filename.c_str());

    if (!in) throw File_not_found(filename);

    in >> (*this);
}


Config::Config() : m_Delimiter(string(1, '=')), m_Comment(string(1, '#')) {
    // Construct a Config without a file; empty
}


bool Config::KeyExists(const string& key) const {
    // Indicate whether key is found
    mapci p = m_Contents.find(key);
    return (p != m_Contents.end());
}


/* static */
void Config::Trim(string* inout_s) {
    // Remove leading and trailing whitespace
    static const char whitespace[] = " \n\t\v\r\f";
    inout_s->erase(0, inout_s->find_first_not_of(whitespace));
    inout_s->erase(inout_s->find_last_not_of(whitespace) + 1U);
}


std::ostream& operator<<(std::ostream& os, const Config& cf) {
    // Save a Config to os
    for (Config::mapci p = cf.m_Contents.begin(); p != cf.m_Contents.end();
         ++p) {
        os << p->first << " " << cf.m_Delimiter << " ";
        os << p->second << std::endl;
    }
    return os;
}

void Config::Remove(const string& key) {
    // Remove key and its value
    // only erase when the key exists (erasing end() is undefined behavior)
    mapi p = m_Contents.find(key);
    if (p != m_Contents.end()) m_Contents.erase(p);
    return;
}

std::istream& operator>>(std::istream& is, Config& cf) {
    // Load a Config from is
    // Read in keys and values, keeping internal whitespace
    typedef string::size_type pos;
    const string& delim = cf.m_Delimiter;  // separator
    const string& comm = cf.m_Comment;     // comment
    const pos skip = delim.length();       // length of separator

    string nextline = "";  // might need to read ahead to see where value ends

    while (is || nextline.length() > 0) {
        // Read an entire line at a time
        string line;
        if (nextline.length() > 0) {
            line = nextline;  // we read ahead; use it now
            nextline = "";
        } else {
            std::getline(is, line);
        }

        // Ignore comments
        line = line.substr(0, line.find(comm));

        // Parse the line if it contains a delimiter
        pos delimPos = line.find(delim);
        if (delimPos < string::npos) {
            // Extract the key
            string key = line.substr(0, delimPos);
            line.replace(0, delimPos + skip, "");

            // See if value continues on the next line
            // Stop at blank line, next line with a key, end of stream,
            // or end of file sentry
            bool terminate = false;
            while (!terminate && is) {
                std::getline(is, nextline);
                terminate = true;

                string nlcopy = nextline;
                Config::Trim(&nlcopy);
                if (nlcopy == "") continue;

                nextline = nextline.substr(0, nextline.find(comm));
                if (nextline.find(delim) != string::npos) continue;

                nlcopy = nextline;
                Config::Trim(&nlcopy);
                if (nlcopy != "") line += "\n";
                line += nextline;
                terminate = false;
            }

            // Store key and value
            Config::Trim(&key);
            Config::Trim(&line);
            cf.m_Contents[key] = line;  // overwrites if key is repeated
        }
    }

    return is;
}

bool Config::FileExist(std::string filename) {
    bool exist = false;
    std::ifstream in(filename.c_str());
    if (in) exist = true;
    return exist;
}

void Config::ReadFile(string filename, string delimiter, string comment) {
    m_Delimiter = delimiter;
    m_Comment = comment;
    std::ifstream in(filename.c_str());

    if (!in) throw File_not_found(filename);

    in >> (*this);
}

#ifdef _MSC_VER
#pragma endregion ParseIniFile
#endif
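A short usage sketch for the Config parser above; the file name and keys are made up for illustration, and the snippet assumes the class declaration is available through its header:

// Usage sketch for the Config class above (illustrative file name and keys).
#include <iostream>
#include <string>

int demo_read_config() {
    try {
        Config cfg("demo.conf");  // lines of "key = value"; '#' starts a comment

        // Read<T>(key) throws Config::Key_not_found when the key is missing.
        std::string model = cfg.Read<std::string>("model_path");

        // The two-argument overload falls back to the default instead.
        int topk = cfg.Read<int>("topk", 5);

        // ReadInto() reports presence via its return value.
        bool use_gpu = false;
        cfg.ReadInto(&use_gpu, "use_gpu", false);

        std::cout << model << " topk=" << topk << " gpu=" << use_gpu << std::endl;
    } catch (const Config::File_not_found& e) {
        std::cerr << "config not found: " << e.filename << std::endl;
        return -1;
    }
    return 0;
}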
@ -0,0 +1,12 @@

#include "base/glog_utils.h"

namespace google {
void InitGoogleLogging(const char* name) {
    LOG(INFO) << "dummy InitGoogleLogging.";
}

void InstallFailureSignalHandler() {
    LOG(INFO) << "dummy InstallFailureSignalHandler.";
}
}  // namespace google
@ -0,0 +1,9 @@
#pragma once

#include "base/common.h"

namespace google {
void InitGoogleLogging(const char* name);

void InstallFailureSignalHandler();
}  // namespace google
@ -0,0 +1,105 @@
#include "base/log.h"

DEFINE_int32(logtostderr, 0, "logging to stderr");

namespace ppspeech {

static char __progname[] = "paddlespeech";

namespace log {

std::mutex LogMessage::lock_;
std::string LogMessage::s_debug_logfile_("");
std::string LogMessage::s_info_logfile_("");
std::string LogMessage::s_warning_logfile_("");
std::string LogMessage::s_error_logfile_("");
std::string LogMessage::s_fatal_logfile_("");

void LogMessage::get_curr_proc_info(std::string* pid, std::string* proc_name) {
    std::stringstream ss;
    ss << getpid();
    ss >> *pid;
    *proc_name = ::ppspeech::__progname;
}

LogMessage::LogMessage(const char* file,
                       int line,
                       Severity level,
                       bool verbose,
                       bool out_to_file /* = false */)
    : level_(level), verbose_(verbose), out_to_file_(out_to_file) {
    if (FLAGS_logtostderr == 0) {
        stream_ = static_cast<std::ostream*>(&std::cout);
    } else if (FLAGS_logtostderr == 1) {
        stream_ = static_cast<std::ostream*>(&std::cerr);
    } else if (out_to_file_) {
        // logfile
        lock_.lock();
        init(file, line);
    } else {
        // fall back to stderr so stream_ is never left uninitialized
        stream_ = static_cast<std::ostream*>(&std::cerr);
    }
}

LogMessage::~LogMessage() {
    stream() << std::endl;

    if (out_to_file_) {
        lock_.unlock();
    }

    if (verbose_ && level_ == FATAL) {
        std::abort();
    }
}

std::ostream* LogMessage::nullstream() {
    thread_local static std::ofstream os;
    thread_local static bool flag_set = false;
    if (!flag_set) {
        os.setstate(std::ios_base::badbit);
        flag_set = true;
    }
    return &os;
}

void LogMessage::init(const char* file, int line) {
    time_t t = time(0);
    char tmp[100];
    strftime(tmp, sizeof(tmp), "%Y%m%d-%H%M%S", localtime(&t));

    if (s_info_logfile_.empty()) {
        std::string pid;
        std::string proc_name;
        get_curr_proc_info(&pid, &proc_name);

        s_debug_logfile_ =
            std::string("log." + proc_name + ".log.DEBUG." + tmp + "." + pid);
        s_info_logfile_ =
            std::string("log." + proc_name + ".log.INFO." + tmp + "." + pid);
        s_warning_logfile_ =
            std::string("log." + proc_name + ".log.WARNING." + tmp + "." + pid);
        s_error_logfile_ =
            std::string("log." + proc_name + ".log.ERROR." + tmp + "." + pid);
        s_fatal_logfile_ =
            std::string("log." + proc_name + ".log.FATAL." + tmp + "." + pid);
    }

    thread_local static std::ofstream ofs;
    if (level_ == DEBUG) {
        ofs.open(s_debug_logfile_.c_str(), std::ios::out | std::ios::app);
    } else if (level_ == INFO) {
        ofs.open(s_info_logfile_.c_str(), std::ios::out | std::ios::app);
    } else if (level_ == WARNING) {
        ofs.open(s_warning_logfile_.c_str(), std::ios::out | std::ios::app);
    } else if (level_ == ERROR) {
        ofs.open(s_error_logfile_.c_str(), std::ios::out | std::ios::app);
    } else {
        ofs.open(s_fatal_logfile_.c_str(), std::ios::out | std::ios::app);
    }

    stream_ = &ofs;

    stream() << tmp << " " << file << " line " << line << "; ";
    stream() << std::flush;
}
}  // namespace log
}  // namespace ppspeech
@ -0,0 +1,173 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// modified from https://github.com/Dounm/dlog
// modified from
// https://android.googlesource.com/platform/art/+/806defa/src/logging.h

#pragma once

#include <stdlib.h>
#include <unistd.h>

#include <fstream>
#include <iostream>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>

#include "base/common.h"
#include "base/macros.h"
#ifndef WITH_GLOG
#include "base/glog_utils.h"
#endif

DECLARE_int32(logtostderr);

namespace ppspeech {

namespace log {

enum Severity {
    DEBUG,
    INFO,
    WARNING,
    ERROR,
    FATAL,
    NUM_SEVERITIES,
};

class LogMessage {
  public:
    static void get_curr_proc_info(std::string* pid, std::string* proc_name);

    LogMessage(const char* file,
               int line,
               Severity level,
               bool verbose,
               bool out_to_file = false);

    ~LogMessage();

    std::ostream& stream() { return verbose_ ? *stream_ : *nullstream(); }

  private:
    void init(const char* file, int line);
    std::ostream* nullstream();

  private:
    std::ostream* stream_;
    std::ostream* null_stream_;
    Severity level_;
    bool verbose_;
    bool out_to_file_;

    static std::mutex lock_;  // stream write lock
    static std::string s_debug_logfile_;
    static std::string s_info_logfile_;
    static std::string s_warning_logfile_;
    static std::string s_error_logfile_;
    static std::string s_fatal_logfile_;

    DISALLOW_COPY_AND_ASSIGN(LogMessage);
};


}  // namespace log

}  // namespace ppspeech


#ifndef PPS_DEBUG
#define DLOG_INFO \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, false)
#define DLOG_WARNING \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, false)
#define DLOG_ERROR \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, false)
#define DLOG_FATAL \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, false)
#else
#define DLOG_INFO \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, true)
#define DLOG_WARNING \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, true)
#define DLOG_ERROR \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, true)
#define DLOG_FATAL \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, true)
#endif


// LOG_DEBUG was missing but is referenced by LOG_0 and LOG(DEBUG); define it
// following the same pattern as the other severities.
#define LOG_DEBUG \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::DEBUG, true)
#define LOG_INFO \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, true)
#define LOG_WARNING \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, true)
#define LOG_ERROR \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, true)
#define LOG_FATAL \
    ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, true)


#define LOG_0 LOG_DEBUG
#define LOG_1 LOG_INFO
#define LOG_2 LOG_WARNING
#define LOG_3 LOG_ERROR
#define LOG_4 LOG_FATAL

#define LOG(level) LOG_##level.stream()

#define DLOG(level) DLOG_##level.stream()

#define VLOG(verboselevel) LOG(verboselevel)

#define CHECK(exp) \
    ppspeech::log::LogMessage( \
        __FILE__, __LINE__, ppspeech::log::FATAL, !(exp)) \
        .stream() \
        << "Check Failed: " #exp

#define CHECK_EQ(x, y) CHECK((x) == (y))
#define CHECK_NE(x, y) CHECK((x) != (y))
#define CHECK_LE(x, y) CHECK((x) <= (y))
#define CHECK_LT(x, y) CHECK((x) < (y))
#define CHECK_GE(x, y) CHECK((x) >= (y))
#define CHECK_GT(x, y) CHECK((x) > (y))
#ifdef PPS_DEBUG
#define DCHECK(x) CHECK(x)
#define DCHECK_EQ(x, y) CHECK_EQ(x, y)
#define DCHECK_NE(x, y) CHECK_NE(x, y)
#define DCHECK_LE(x, y) CHECK_LE(x, y)
#define DCHECK_LT(x, y) CHECK_LT(x, y)
#define DCHECK_GE(x, y) CHECK_GE(x, y)
#define DCHECK_GT(x, y) CHECK_GT(x, y)
#else
#define DCHECK(condition) \
    while (false) CHECK(condition)
#define DCHECK_EQ(val1, val2) \
    while (false) CHECK_EQ(val1, val2)
#define DCHECK_NE(val1, val2) \
    while (false) CHECK_NE(val1, val2)
#define DCHECK_LE(val1, val2) \
    while (false) CHECK_LE(val1, val2)
#define DCHECK_LT(val1, val2) \
    while (false) CHECK_LT(val1, val2)
#define DCHECK_GE(val1, val2) \
    while (false) CHECK_GE(val1, val2)
#define DCHECK_GT(val1, val2) \
    while (false) CHECK_GT(val1, val2)
#define DCHECK_STREQ(str1, str2) \
    while (false) CHECK_STREQ(str1, str2)
#endif
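A quick usage sketch for the macros above. It assumes the translation unit includes the generated base/log.h, which resolves either to this header or to glog depending on the build options; both provide LOG, DLOG, and the CHECK family used here:

// Usage sketch for the LOG/DLOG/CHECK macros defined above.
#include "base/log.h"

int score_utterance(int topk) {
    // A failing CHECK constructs a FATAL LogMessage with verbose = !(exp),
    // prints "Check Failed: ..." and aborts in the destructor.
    CHECK_GT(topk, 0);

    LOG(INFO) << "scoring with topk=" << topk;

    // In this header, DLOG(...) is silenced unless PPS_DEBUG is defined;
    // glog applies its own DLOG semantics when that backend is selected.
    DLOG(INFO) << "debug-only details";
    return 0;
}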