diff --git a/.mergify.yml b/.mergify.yml index 5cb1f486..0f182b51 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -136,7 +136,7 @@ pull_request_rules: add: ["Docker"] - name: "auto add label=Deployment" conditions: - - files~=^speechx/ + - files~=^runtime/ actions: label: add: ["Deployment"] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 53fc6ba0..6afa7c9c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,8 +3,12 @@ repos: rev: v0.16.0 hooks: - id: yapf - files: \.py$ - exclude: (?=third_party).*(\.py)$ + name: yapf + language: python + entry: yapf + args: [-i, -vv] + types: [python] + exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$ - repo: https://github.com/pre-commit/pre-commit-hooks rev: a11d9314b22d8f8c7556443875b731ef05965464 @@ -31,7 +35,7 @@ repos: - --ignore=E501,E228,E226,E261,E266,E128,E402,W503 - --builtins=G,request - --jobs=1 - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$ + exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$ - repo : https://github.com/Lucas-C/pre-commit-hooks rev: v1.0.1 @@ -53,16 +57,16 @@ repos: entry: bash .pre-commit-hooks/clang-format.hook -i language: system files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$ - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party/kaldi-native-fbank/csrc|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ + exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders|runtime/engine/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ - id: cpplint name: cpplint description: Static code analysis of C/C++ files language: python files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$ - exclude: 
(?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party/kaldi-native-fbank/csrc|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ + exclude: (?=runtime/engine/kaldi|runtime/engine/common/matrix|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders|runtime/engine/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent - repo: https://github.com/asottile/reorder_python_imports rev: v2.4.0 hooks: - id: reorder-python-imports - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h\.hpp|\.py)$ + exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h\.hpp|\.py)$ diff --git a/README.md b/README.md index 281960a2..9ed82311 100644 --- a/README.md +++ b/README.md @@ -193,7 +193,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision - 👑 2022.11.18: Add [Whisper CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/pull/2640), support multi language recognition and translation. - 🔥 2022.11.18: Add [Wav2vec2 CLI and Demos](./demos/speech_ssl), Support ASR and Feature Extraction. - 🎉 2022.11.17: Add [male voice for TTS](https://github.com/PaddlePaddle/PaddleSpeech/pull/2660). -- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](./speechx/examples/u2pp_ol/wenetspeech). +- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/runtime/examples/u2pp_ol/wenetspeech). - 👑 2022.11.01: Add [Adversarial Loss](https://arxiv.org/pdf/1907.04448.pdf) for [Chinese English mixed TTS](./examples/zh_en_tts/tts3). 
- 🔥 2022.10.26: Add [Prosody Prediction](./examples/other/rhy) for TTS. - 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend. diff --git a/speechx/.clang-format b/runtime/.clang-format similarity index 100% rename from speechx/.clang-format rename to runtime/.clang-format diff --git a/runtime/.gitignore b/runtime/.gitignore new file mode 100644 index 00000000..a654dae4 --- /dev/null +++ b/runtime/.gitignore @@ -0,0 +1,7 @@ +engine/common/base/flags.h +engine/common/base/log.h + +tools/valgrind* +*log +fc_patch/* +test diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt new file mode 100644 index 00000000..092c8b25 --- /dev/null +++ b/runtime/CMakeLists.txt @@ -0,0 +1,211 @@ +# >=3.17 support -DCMAKE_FIND_DEBUG_MODE=ON +cmake_minimum_required(VERSION 3.17 FATAL_ERROR) + +set(CMAKE_PROJECT_INCLUDE_BEFORE "${CMAKE_CURRENT_SOURCE_DIR}/cmake/EnableCMP0077.cmake") + +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") + +include(system) + +project(paddlespeech VERSION 0.1) + +set(PPS_VERSION_MAJOR 1) +set(PPS_VERSION_MINOR 0) +set(PPS_VERSION_PATCH 0) +set(PPS_VERSION "${PPS_VERSION_MAJOR}.${PPS_VERSION_MINOR}.${PPS_VERSION_PATCH}") + +# compiler option +# Keep the same with openfst, -fPIC or -fpic +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ldl") +SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb") +SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall") + +set(CMAKE_VERBOSE_MAKEFILE ON) +set(CMAKE_FIND_DEBUG_MODE OFF) +set(PPS_CXX_STANDARD 14) + +# set std-14 +set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD}) + +# Ninja Generator will set CMAKE_BUILD_TYPE to Debug +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" FORCE) +endif() + +# find_* e.g. 
find_library work when Cross-Compiling +if(ANDROID) + set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH) +endif() + +if(BUILD_IN_MACOS) + add_definitions("-DOS_MACOSX") +endif() + +# install dir into `build/install` +set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/install) + +include(FetchContent) +include(ExternalProject) + +# fc_patch dir +set(FETCHCONTENT_QUIET off) +get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}") +set(FETCHCONTENT_BASE_DIR ${fc_patch}) + +############################################################################### +# Option Configurations +############################################################################### +# https://github.com/google/brotli/pull/655 +option(BUILD_SHARED_LIBS "Build shared libraries" ON) + +option(WITH_PPS_DEBUG "debug option" OFF) +if (WITH_PPS_DEBUG) + add_definitions("-DPPS_DEBUG") +endif() + +option(WITH_ASR "build asr" ON) +option(WITH_CLS "build cls" ON) +option(WITH_VAD "build vad" ON) + +option(WITH_GPU "NNet using GPU." 
OFF) + +option(WITH_PROFILING "enable c++ profling" OFF) +option(WITH_TESTING "unit test" ON) + +option(WITH_ONNX "u2 support onnx runtime" OFF) + +############################################################################### +# Include Third Party +############################################################################### +include(gflags) + +include(glog) + +include(pybind) + +#onnx +if(WITH_ONNX) + add_definitions(-DUSE_ONNX) +endif() + +# gtest +if(WITH_TESTING) + include(gtest) # download, build, install gtest +endif() + +# fastdeploy +include(fastdeploy) + +if(WITH_ASR) + # openfst + include(openfst) + add_dependencies(openfst gflags extern_glog) +endif() + +############################################################################### +# Find Package +############################################################################### +# https://github.com/Kitware/CMake/blob/v3.1.0/Modules/FindThreads.cmake#L207 +find_package(Threads REQUIRED) + +if(WITH_ASR) + # https://cmake.org/cmake/help/latest/module/FindPython3.html#module:FindPython3 + find_package(Python3 COMPONENTS Interpreter Development) + find_package(pybind11 CONFIG) + + if(Python3_FOUND) + message(STATUS "Python3_FOUND = ${Python3_FOUND}") + message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}") + message(STATUS "Python3_LIBRARIES = ${Python3_LIBRARIES}") + message(STATUS "Python3_INCLUDE_DIRS = ${Python3_INCLUDE_DIRS}") + message(STATUS "Python3_LINK_OPTIONS = ${Python3_LINK_OPTIONS}") + set(PYTHON_LIBRARIES ${Python3_LIBRARIES} CACHE STRING "python lib" FORCE) + set(PYTHON_INCLUDE_DIR ${Python3_INCLUDE_DIRS} CACHE STRING "python inc" FORCE) + endif() + + message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}") + message(STATUS "PYTHON_INCLUDE_DIR = ${PYTHON_INCLUDE_DIR}") + include_directories(${PYTHON_INCLUDE_DIR}) + + if(pybind11_FOUND) + message(STATUS "pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}") + message(STATUS "pybind11_LIBRARIES=${pybind11_LIBRARIES}") + message(STATUS 
"pybind11_DEFINITIONS=${pybind11_DEFINITIONS}") + endif() + + + # paddle libpaddle.so + # paddle include and link option + # -L/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so + set(EXECUTE_COMMAND "import os" + "import paddle" + "include_dir = paddle.sysconfig.get_include()" + "paddle_dir=os.path.split(include_dir)[0]" + "libs_dir=os.path.join(paddle_dir, 'libs')" + "fluid_dir=os.path.join(paddle_dir, 'fluid')" + "out=' '.join([\"-L\" + libs_dir, \"-L\" + fluid_dir])" + "out += \" -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so\"; print(out)" + ) + execute_process( + COMMAND python -c "${EXECUTE_COMMAND}" + OUTPUT_VARIABLE PADDLE_LINK_FLAGS + RESULT_VARIABLE SUCESS) + + message(STATUS PADDLE_LINK_FLAGS= ${PADDLE_LINK_FLAGS}) + string(STRIP ${PADDLE_LINK_FLAGS} PADDLE_LINK_FLAGS) + + # paddle compile option + # -I/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/include + set(EXECUTE_COMMAND "import paddle" + "include_dir = paddle.sysconfig.get_include()" + "print(f\"-I{include_dir}\")" + ) + execute_process( + COMMAND python -c "${EXECUTE_COMMAND}" + OUTPUT_VARIABLE PADDLE_COMPILE_FLAGS) + message(STATUS PADDLE_COMPILE_FLAGS= ${PADDLE_COMPILE_FLAGS}) + string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS) + + # for LD_LIBRARY_PATH + # set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/) + set(EXECUTE_COMMAND "import os" + "import paddle" + "include_dir=paddle.sysconfig.get_include()" + "paddle_dir=os.path.split(include_dir)[0]" + "libs_dir=os.path.join(paddle_dir, 'libs')" + "fluid_dir=os.path.join(paddle_dir, 'fluid')" + "out=':'.join([libs_dir, fluid_dir]); print(out)" + ) + execute_process( + COMMAND python -c "${EXECUTE_COMMAND}" + OUTPUT_VARIABLE 
PADDLE_LIB_DIRS) + message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS}) +endif() + +include(summary) + +############################################################################### +# Add local library +############################################################################### +set(ENGINE_ROOT ${CMAKE_SOURCE_DIR}/engine) + +add_subdirectory(engine) + + +############################################################################### +# CPack library +############################################################################### +# build a CPack driven installer package +include (InstallRequiredSystemLibraries) +set(CPACK_PACKAGE_NAME "paddlespeech_library") +set(CPACK_PACKAGE_VENDOR "paddlespeech") +set(CPACK_PACKAGE_VERSION_MAJOR 1) +set(CPACK_PACKAGE_VERSION_MINOR 0) +set(CPACK_PACKAGE_VERSION_PATCH 0) +set(CPACK_PACKAGE_DESCRIPTION "paddlespeech library") +set(CPACK_PACKAGE_CONTACT "paddlespeech@baidu.com") +set(CPACK_SOURCE_GENERATOR "TGZ") +include (CPack) diff --git a/speechx/README.md b/runtime/README.md similarity index 92% rename from speechx/README.md rename to runtime/README.md index 5d4b5845..40aa9444 100644 --- a/speechx/README.md +++ b/runtime/README.md @@ -1,4 +1,3 @@ -# SpeechX -- All in One Speech Task Inference ## Environment @@ -9,7 +8,7 @@ We develop under: * gcc/g++/gfortran - 8.2.0 * cmake - 3.16.0 -> Please use `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build speechx. +> Please use `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build engine. > We make sure all things work fun under docker, and recommend using it to develop and deploy. @@ -33,7 +32,7 @@ docker run --privileged --net=host --ipc=host -it --rm -v /path/to/paddlespeech bash tools/venv.sh ``` -2. Build `speechx` and `examples`. +2. Build `engine` and `examples`. For now we are using feature under `develop` branch of paddle, so we need to install `paddlepaddle` nightly build version. 
For example: @@ -113,3 +112,11 @@ apt-get install gfortran-8 4. `Undefined reference to '_gfortran_concat_string'` using gcc 8.2, gfortran 8.2. + +5. `./boost/python/detail/wrap_python.hpp:57:11: fatal error: pyconfig.h: No such file or directory` + +``` +apt-get install python3-dev +``` + +for more info please see [here](https://github.com/okfn/piati/issues/65). diff --git a/runtime/build.sh b/runtime/build.sh new file mode 100755 index 00000000..68889010 --- /dev/null +++ b/runtime/build.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +set -xe + +BUILD_ROOT=build/Linux +BUILD_DIR=${BUILD_ROOT}/x86_64 + +mkdir -p ${BUILD_DIR} + +BUILD_TYPE=Release +#BUILD_TYPE=Debug +BUILD_SO=OFF +BUILD_ONNX=ON +BUILD_ASR=ON +BUILD_CLS=ON +BUILD_VAD=ON +PPS_DEBUG=OFF +FASTDEPLOY_INSTALL_DIR="" + +# the build script had verified in the paddlepaddle docker image. +# please follow the instruction below to install PaddlePaddle image. +# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html +#cmake -B build -DBUILD_SHARED_LIBS=OFF -DWITH_ASR=OFF -DWITH_CLS=OFF -DWITH_VAD=ON -DFASTDEPLOY_INSTALL_DIR=/workspace/zhanghui/paddle/FastDeploy/build/Android/arm64-v8a-api-21/install +cmake -B ${BUILD_DIR} \ + -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ + -DBUILD_SHARED_LIBS=${BUILD_SO} \ + -DWITH_ONNX=${BUILD_ONNX} \ + -DWITH_ASR=${BUILD_ASR} \ + -DWITH_CLS=${BUILD_CLS} \ + -DWITH_VAD=${BUILD_VAD} \ + -DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \ + -DWITH_PPS_DEBUG=${PPS_DEBUG} + +cmake --build ${BUILD_DIR} -j diff --git a/runtime/build_android.sh b/runtime/build_android.sh new file mode 100755 index 00000000..ce78e67c --- /dev/null +++ b/runtime/build_android.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +set -ex + +ANDROID_NDK=/mnt/masimeng/workspace/software/android-ndk-r25b/ + +# Setting up Android toolchanin +ANDROID_ABI=arm64-v8a # 'arm64-v8a', 'armeabi-v7a' +ANDROID_PLATFORM="android-21" # API >= 21 +ANDROID_STL=c++_shared # 'c++_shared', 'c++_static' 
+ANDROID_TOOLCHAIN=clang # 'clang' only +TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake + +# Create build directory +BUILD_ROOT=build/Android +BUILD_DIR=${BUILD_ROOT}/${ANDROID_ABI}-api-21 +FASTDEPLOY_INSTALL_DIR="/mnt/masimeng/workspace/FastDeploy/build/Android/arm64-v8a-api-21/install" + +mkdir -p ${BUILD_DIR} +cd ${BUILD_DIR} + +# CMake configuration with Android toolchain +cmake -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} \ + -DCMAKE_BUILD_TYPE=MinSizeRel \ + -DANDROID_ABI=${ANDROID_ABI} \ + -DANDROID_NDK=${ANDROID_NDK} \ + -DANDROID_PLATFORM=${ANDROID_PLATFORM} \ + -DANDROID_STL=${ANDROID_STL} \ + -DANDROID_TOOLCHAIN=${ANDROID_TOOLCHAIN} \ + -DBUILD_SHARED_LIBS=OFF \ + -DWITH_ASR=OFF \ + -DWITH_CLS=OFF \ + -DWITH_VAD=ON \ + -DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \ + -DCMAKE_FIND_DEBUG_MODE=OFF \ + -Wno-dev ../../.. + +# Build FastDeploy Android C++ SDK +make diff --git a/runtime/build_ios.sh b/runtime/build_ios.sh new file mode 100644 index 00000000..74f76bf6 --- /dev/null +++ b/runtime/build_ios.sh @@ -0,0 +1,91 @@ +# https://www.jianshu.com/p/33672fb819f5 + +PATH="/Applications/CMake.app/Contents/bin":"$PATH" +tools_dir=$1 +ios_toolchain_cmake=${tools_dir}/"/ios-cmake-4.2.0/ios.toolchain.cmake" +fastdeploy_dir=${tools_dir}"/fastdeploy-ort-mac-build/" +build_targets=("OS64") +build_type_array=("Release") + +#static_name="libocr" +#lib_name="libocr" + +# Switch to workpath +current_path=`cd $(dirname $0);pwd` +work_path=${current_path}/ +build_path=${current_path}/build/ +output_path=${current_path}/output/ +cd ${work_path} + +# Clean +rm -rf ${build_path} +rm -rf ${output_path} + +if [ "$1"x = "clean"x ]; then + exit 0 +fi + +# Build Every Target +for target in "${build_targets[@]}" +do + for build_type in "${build_type_array[@]}" + do + echo -e "\033[1;36;40mBuilding ${build_type} ${target} ... \033[0m" + target_build_path=${build_path}/${target}/${build_type}/ + mkdir -p ${target_build_path} + + cd ${target_build_path} + if [ $? 
-ne 0 ];then + echo -e "\033[1;31;40mcd ${target_build_path} failed \033[0m" + exit -1 + fi + + if [ ${target} == "OS64" ];then + fastdeploy_install_dir=${fastdeploy_dir}/arm64 + else + fastdeploy_install_dir="" + echo "fastdeploy_install_dir is null" + exit -1 + fi + + cmake -DCMAKE_TOOLCHAIN_FILE=${ios_toolchain_cmake} \ + -DBUILD_IN_MACOS=ON \ + -DBUILD_SHARED_LIBS=OFF \ + -DWITH_ASR=OFF \ + -DWITH_CLS=OFF \ + -DWITH_VAD=ON \ + -DFASTDEPLOY_INSTALL_DIR=${fastdeploy_install_dir} \ + -DPLATFORM=${target} ../../../ + + cmake --build . --config ${build_type} + + mkdir output + cp engine/vad/interface/libpps_vad_interface.a output + cp engine/vad/interface/vad_interface_main.app/vad_interface_main output + cp ${fastdeploy_install_dir}/lib/libfastdeploy.dylib output + cp ${fastdeploy_install_dir}/third_libs/install/onnxruntime/lib/libonnxruntime.dylib output + + done +done + +## combine all ios libraries +#DEVROOT=/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/ +#LIPO_TOOL=${DEVROOT}/usr/bin/lipo +#LIBRARY_PATH=${build_path} +#LIBRARY_OUTPUT_PATH=${output_path}/IOS +#mkdir -p ${LIBRARY_OUTPUT_PATH} +# +#${LIPO_TOOL} \ +# -arch i386 ${LIBRARY_PATH}/ios_x86/Release/${lib_name}.a \ +# -arch x86_64 ${LIBRARY_PATH}/ios_x86_64/Release/${lib_name}.a \ +# -arch armv7 ${LIBRARY_PATH}/ios_armv7/Release/${lib_name}.a \ +# -arch armv7s ${LIBRARY_PATH}/ios_armv7s/Release/${lib_name}.a \ +# -arch arm64 ${LIBRARY_PATH}/ios_armv8/Release/${lib_name}.a \ +# -output ${LIBRARY_OUTPUT_PATH}/${lib_name}.a -create +# +#cp ${work_path}/lib/houyi/lib/ios/libhouyi_score.a ${LIBRARY_OUTPUT_PATH}/ +#cp ${work_path}/interface/ocr-interface.h ${output_path} +#cp ${work_path}/version/release.v ${output_path} +# +#echo -e "\033[1;36;40mBuild All Target Success At:\n${output_path}\033[0m" +#exit 0 diff --git a/speechx/cmake/EnableCMP0048.cmake b/runtime/cmake/EnableCMP0048.cmake similarity index 100% rename from speechx/cmake/EnableCMP0048.cmake rename to 
runtime/cmake/EnableCMP0048.cmake diff --git a/runtime/cmake/EnableCMP0077.cmake b/runtime/cmake/EnableCMP0077.cmake new file mode 100644 index 00000000..a7deaffb --- /dev/null +++ b/runtime/cmake/EnableCMP0077.cmake @@ -0,0 +1 @@ +cmake_policy(SET CMP0077 NEW) diff --git a/speechx/cmake/FindGFortranLibs.cmake b/runtime/cmake/FindGFortranLibs.cmake similarity index 100% rename from speechx/cmake/FindGFortranLibs.cmake rename to runtime/cmake/FindGFortranLibs.cmake diff --git a/speechx/cmake/absl.cmake b/runtime/cmake/absl.cmake similarity index 100% rename from speechx/cmake/absl.cmake rename to runtime/cmake/absl.cmake diff --git a/speechx/cmake/boost.cmake b/runtime/cmake/boost.cmake similarity index 100% rename from speechx/cmake/boost.cmake rename to runtime/cmake/boost.cmake diff --git a/speechx/cmake/eigen.cmake b/runtime/cmake/eigen.cmake similarity index 100% rename from speechx/cmake/eigen.cmake rename to runtime/cmake/eigen.cmake diff --git a/runtime/cmake/fastdeploy.cmake b/runtime/cmake/fastdeploy.cmake new file mode 100644 index 00000000..e095cd4c --- /dev/null +++ b/runtime/cmake/fastdeploy.cmake @@ -0,0 +1,116 @@ +include(FetchContent) + +set(EXTERNAL_PROJECT_LOG_ARGS + LOG_DOWNLOAD 1 # Wrap download in script to log output + LOG_UPDATE 1 # Wrap update in script to log output + LOG_PATCH 1 + LOG_CONFIGURE 1# Wrap configure in script to log output + LOG_BUILD 1 # Wrap build in script to log output + LOG_INSTALL 1 + LOG_TEST 1 # Wrap test in script to log output + LOG_MERGED_STDOUTERR 1 + LOG_OUTPUT_ON_FAILURE 1 +) + +if(NOT FASTDEPLOY_INSTALL_DIR) + if(ANDROID) + FetchContent_Declare( + fastdeploy + URL https://bj.bcebos.com/fastdeploy/release/android/fastdeploy-android-1.0.4-shared.tgz + URL_HASH MD5=2a15301158e9eb157a4f11283689e7ba + ${EXTERNAL_PROJECT_LOG_ARGS} + ) + add_definitions("-DUSE_PADDLE_LITE_BAKEND") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE") + 
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -g0 -O3 -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE") + else() # Linux + FetchContent_Declare( + fastdeploy + URL https://paddlespeech.bj.bcebos.com/speechx/fastdeploy/fastdeploy-1.0.5-x86_64-onnx.tar.gz + URL_HASH MD5=33900d986ea71aa78635e52f0733227c + ${EXTERNAL_PROJECT_LOG_ARGS} + ) + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -msse -msse2") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -msse -msse2 -mavx -O3") + endif() + + FetchContent_MakeAvailable(fastdeploy) + + set(FASTDEPLOY_INSTALL_DIR ${fc_patch}/fastdeploy-src) +endif() + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) + +# fix compiler flags conflict, since fastdeploy using c++11 for project +# this line must after `include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)` +set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD}) + +include_directories(${FASTDEPLOY_INCS}) + +# install fastdeploy and dependents lib +# install_fastdeploy_libraries(${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}) +# No dynamic libs need to install while using +# FastDeploy static lib. 
+if(ANDROID AND WITH_ANDROID_STATIC_LIB) + return() +endif() + +set(DYN_LIB_SUFFIX "*.so*") +if(WIN32) + set(DYN_LIB_SUFFIX "*.dll") +elseif(APPLE) + set(DYN_LIB_SUFFIX "*.dylib*") +endif() + +if(FastDeploy_DIR) + set(DYN_SEARCH_DIR ${FastDeploy_DIR}) +elseif(FASTDEPLOY_INSTALL_DIR) + set(DYN_SEARCH_DIR ${FASTDEPLOY_INSTALL_DIR}) +else() + message(FATAL_ERROR "Please set FastDeploy_DIR/FASTDEPLOY_INSTALL_DIR before call install_fastdeploy_libraries.") +endif() + +file(GLOB_RECURSE ALL_NEED_DYN_LIBS ${DYN_SEARCH_DIR}/lib/${DYN_LIB_SUFFIX}) +file(GLOB_RECURSE ALL_DEPS_DYN_LIBS ${DYN_SEARCH_DIR}/third_libs/${DYN_LIB_SUFFIX}) + +if(ENABLE_VISION) + # OpenCV + if(ANDROID) + file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${DYN_LIB_SUFFIX}) + else() + file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_DIR}/../../${DYN_LIB_SUFFIX}) + endif() + + list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_OPENCV_DYN_LIBS}) + + if(WIN32) + file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/x64/vc15/bin/${DYN_LIB_SUFFIX}) + install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib) + elseif(ANDROID AND (NOT WITH_ANDROID_OPENCV_STATIC)) + file(GLOB OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${ANDROID_ABI}/${DYN_LIB_SUFFIX}) + install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib) + else() # linux/mac + file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/lib/${DYN_LIB_SUFFIX}) + install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib) + endif() + + # FlyCV + if(ENABLE_FLYCV) + file(GLOB_RECURSE ALL_FLYCV_DYN_LIBS ${FLYCV_LIB_DIR}/${DYN_LIB_SUFFIX}) + list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_FLYCV_DYN_LIBS}) + if(ANDROID AND (NOT WITH_ANDROID_FLYCV_STATIC)) + install(FILES ${ALL_FLYCV_DYN_LIBS} DESTINATION lib) + endif() + endif() +endif() + +if(ENABLE_OPENVINO_BACKEND) + # need plugins.xml for openvino backend + set(OPENVINO_RUNTIME_BIN_DIR ${OPENVINO_DIR}/bin) + file(GLOB OPENVINO_PLUGIN_XML ${OPENVINO_RUNTIME_BIN_DIR}/*.xml) + install(FILES ${OPENVINO_PLUGIN_XML} DESTINATION lib) +endif() + +# Install other libraries 
+install(FILES ${ALL_NEED_DYN_LIBS} DESTINATION lib) +install(FILES ${ALL_DEPS_DYN_LIBS} DESTINATION lib) diff --git a/runtime/cmake/gflags.cmake b/runtime/cmake/gflags.cmake new file mode 100644 index 00000000..aa0248ba --- /dev/null +++ b/runtime/cmake/gflags.cmake @@ -0,0 +1,14 @@ +include(FetchContent) + +FetchContent_Declare( + gflags + URL https://paddleaudio.bj.bcebos.com/build/gflag-2.2.2.zip + URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5 +) +FetchContent_MakeAvailable(gflags) + +# openfst need +include_directories(${gflags_BINARY_DIR}/include) +link_directories(${gflags_BINARY_DIR}) + +#install(FILES ${gflags_BINARY_DIR}/libgflags_nothreads.a DESTINATION lib) diff --git a/runtime/cmake/glog.cmake b/runtime/cmake/glog.cmake new file mode 100644 index 00000000..6c38963a --- /dev/null +++ b/runtime/cmake/glog.cmake @@ -0,0 +1,35 @@ +include(FetchContent) + +if(ANDROID) +else() # UNIX + add_definitions(-DWITH_GLOG) + FetchContent_Declare( + glog + URL https://paddleaudio.bj.bcebos.com/build/glog-0.4.0.zip + URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc + CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} + -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DCMAKE_CXX_FLAGS=${GLOG_CMAKE_CXX_FLAGS} + -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} + -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG} + -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} + -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG} + -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE} + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DWITH_GFLAGS=OFF + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + ${EXTERNAL_OPTIONAL_ARGS} + ) + set(BUILD_TESTING OFF) + FetchContent_MakeAvailable(glog) + include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src) +endif() + + +if(ANDROID) + add_library(extern_glog INTERFACE) + add_dependencies(extern_glog gflags) +else() # UNIX + add_library(extern_glog ALIAS glog) + add_dependencies(glog gflags) +endif() \ No newline at end 
of file diff --git a/runtime/cmake/gtest.cmake b/runtime/cmake/gtest.cmake new file mode 100644 index 00000000..a311721f --- /dev/null +++ b/runtime/cmake/gtest.cmake @@ -0,0 +1,27 @@ + +include(FetchContent) + +if(ANDROID) +else() # UNIX + FetchContent_Declare( + gtest + URL https://paddleaudio.bj.bcebos.com/build/gtest-release-1.11.0.zip + URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a + ) + FetchContent_MakeAvailable(gtest) + + include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src) +endif() + + + +if(ANDROID) + add_library(extern_gtest INTERFACE) +else() # UNIX + add_dependencies(gtest gflags extern_glog) + add_library(extern_gtest ALIAS gtest) +endif() + +if(WITH_TESTING) + enable_testing() +endif() diff --git a/speechx/cmake/kenlm.cmake b/runtime/cmake/kenlm.cmake similarity index 100% rename from speechx/cmake/kenlm.cmake rename to runtime/cmake/kenlm.cmake diff --git a/speechx/cmake/libsndfile.cmake b/runtime/cmake/libsndfile.cmake similarity index 100% rename from speechx/cmake/libsndfile.cmake rename to runtime/cmake/libsndfile.cmake diff --git a/speechx/cmake/openblas.cmake b/runtime/cmake/openblas.cmake similarity index 100% rename from speechx/cmake/openblas.cmake rename to runtime/cmake/openblas.cmake diff --git a/speechx/cmake/openfst.cmake b/runtime/cmake/openfst.cmake similarity index 69% rename from speechx/cmake/openfst.cmake rename to runtime/cmake/openfst.cmake index 07c33a74..42299c88 100644 --- a/speechx/cmake/openfst.cmake +++ b/runtime/cmake/openfst.cmake @@ -1,8 +1,8 @@ -include(FetchContent) set(openfst_PREFIX_DIR ${fc_patch}/openfst) set(openfst_SOURCE_DIR ${fc_patch}/openfst-src) set(openfst_BINARY_DIR ${fc_patch}/openfst-build) +include(FetchContent) # openfst Acknowledgments: #Cyril Allauzen, Michael Riley, Johan Schalkwyk, Wojciech Skut and Mehryar Mohri, #"OpenFst: A General and Efficient Weighted Finite-State Transducer Library", @@ -10,18 +10,33 @@ set(openfst_BINARY_DIR 
${fc_patch}/openfst-build) #Application of Automata, (CIAA 2007), volume 4783 of Lecture Notes in #Computer Science, pages 11-23. Springer, 2007. http://www.openfst.org. +set(EXTERNAL_PROJECT_LOG_ARGS + LOG_DOWNLOAD 1 # Wrap download in script to log output + LOG_UPDATE 1 # Wrap update in script to log output + LOG_CONFIGURE 1# Wrap configure in script to log output + LOG_BUILD 1 # Wrap build in script to log output + LOG_TEST 1 # Wrap test in script to log output + LOG_INSTALL 1 # Wrap install in script to log output +) + ExternalProject_Add(openfst URL https://paddleaudio.bj.bcebos.com/build/openfst_1.7.2.zip URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6 + ${EXTERNAL_PROJECT_LOG_ARGS} PREFIX ${openfst_PREFIX_DIR} SOURCE_DIR ${openfst_SOURCE_DIR} BINARY_DIR ${openfst_BINARY_DIR} + BUILD_ALWAYS 0 CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR} "CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}" "LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}" - "LIBS=-lgflags_nothreads -lglog -lpthread" + "LIBS=-lgflags_nothreads -lglog -lpthread -fPIC" COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR} BUILD_COMMAND make -j 4 ) link_directories(${openfst_PREFIX_DIR}/lib) include_directories(${openfst_PREFIX_DIR}/include) + + +message(STATUS "OpenFST inc dir: ${openfst_PREFIX_DIR}/include") +message(STATUS "OpenFST lib dir: ${openfst_PREFIX_DIR}/lib") diff --git a/speechx/cmake/paddleinference.cmake b/runtime/cmake/paddleinference.cmake similarity index 100% rename from speechx/cmake/paddleinference.cmake rename to runtime/cmake/paddleinference.cmake diff --git a/runtime/cmake/pybind.cmake b/runtime/cmake/pybind.cmake new file mode 100644 index 00000000..0ce1f57f --- /dev/null +++ b/runtime/cmake/pybind.cmake @@ -0,0 +1,42 @@ +#the pybind11 is from:https://github.com/pybind/pybind11 +# Copyright (c) 2016 Wenzel Jakob , All 
rights reserved. + +SET(PYBIND_ZIP "v2.10.0.zip") +SET(LOCAL_PYBIND_ZIP ${FETCHCONTENT_BASE_DIR}/${PYBIND_ZIP}) +SET(PYBIND_SRC ${FETCHCONTENT_BASE_DIR}/pybind11) +SET(DOWNLOAD_URL "https://paddleaudio.bj.bcebos.com/build/v2.10.0.zip") +SET(PYBIND_TIMEOUT 600 CACHE STRING "Timeout in seconds when downloading pybind.") + +IF(NOT EXISTS ${LOCAL_PYBIND_ZIP}) + FILE(DOWNLOAD ${DOWNLOAD_URL} + ${LOCAL_PYBIND_ZIP} + TIMEOUT ${PYBIND_TIMEOUT} + STATUS ERR + SHOW_PROGRESS + ) + + IF(ERR EQUAL 0) + MESSAGE(STATUS "download pybind success") + ELSE() + MESSAGE(FATAL_ERROR "download pybind fail") + ENDIF() +ENDIF() + +IF(NOT EXISTS ${PYBIND_SRC}) + EXECUTE_PROCESS( + COMMAND ${CMAKE_COMMAND} -E tar xfz ${LOCAL_PYBIND_ZIP} + WORKING_DIRECTORY ${FETCHCONTENT_BASE_DIR} + RESULT_VARIABLE tar_result + ) + + file(RENAME ${FETCHCONTENT_BASE_DIR}/pybind11-2.10.0 ${PYBIND_SRC}) + + IF (tar_result MATCHES 0) + MESSAGE(STATUS "unzip pybind success") + ELSE() + MESSAGE(FATAL_ERROR "unzip pybind fail") + ENDIF() + +ENDIF() + +include_directories(${PYBIND_SRC}/include) diff --git a/runtime/cmake/summary.cmake b/runtime/cmake/summary.cmake new file mode 100644 index 00000000..95ee324a --- /dev/null +++ b/runtime/cmake/summary.cmake @@ -0,0 +1,64 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +function(pps_summary) + message(STATUS "") + message(STATUS "*************PaddleSpeech Building Summary**********") + message(STATUS " PPS_VERSION : ${PPS_VERSION}") + message(STATUS " CMake version : ${CMAKE_VERSION}") + message(STATUS " CMake command : ${CMAKE_COMMAND}") + message(STATUS " UNIX : ${UNIX}") + message(STATUS " ANDROID : ${ANDROID}") + message(STATUS " System : ${CMAKE_SYSTEM_NAME}") + message(STATUS " C++ compiler : ${CMAKE_CXX_COMPILER}") + message(STATUS " C++ compiler version : ${CMAKE_CXX_COMPILER_VERSION}") + message(STATUS " CXX flags : ${CMAKE_CXX_FLAGS}") + message(STATUS " Build type : ${CMAKE_BUILD_TYPE}") + message(STATUS " BUILD_SHARED_LIBS : ${BUILD_SHARED_LIBS}") + get_directory_property(tmp DIRECTORY ${PROJECT_SOURCE_DIR} COMPILE_DEFINITIONS) + message(STATUS " Compile definitions : ${tmp}") + message(STATUS " CMAKE_PREFIX_PATH : ${CMAKE_PREFIX_PATH}") + message(STATUS " CMAKE_CURRENT_BINARY_DIR : ${CMAKE_CURRENT_BINARY_DIR}") + message(STATUS " CMAKE_INSTALL_PREFIX : ${CMAKE_INSTALL_PREFIX}") + message(STATUS " CMAKE_INSTALL_LIBDIR : ${CMAKE_INSTALL_LIBDIR}") + message(STATUS " CMAKE_MODULE_PATH : ${CMAKE_MODULE_PATH}") + message(STATUS " CMAKE_SYSTEM_NAME : ${CMAKE_SYSTEM_NAME}") + message(STATUS "") + + message(STATUS " WITH_ASR : ${WITH_ASR}") + message(STATUS " WITH_CLS : ${WITH_CLS}") + message(STATUS " WITH_VAD : ${WITH_VAD}") + message(STATUS " WITH_GPU : ${WITH_GPU}") + message(STATUS " WITH_TESTING : ${WITH_TESTING}") + message(STATUS " WITH_PROFILING : ${WITH_PROFILING}") + message(STATUS " FASTDEPLOY_INSTALL_DIR : ${FASTDEPLOY_INSTALL_DIR}") + message(STATUS " FASTDEPLOY_INCS : ${FASTDEPLOY_INCS}") + message(STATUS " FASTDEPLOY_LIBS : ${FASTDEPLOY_LIBS}") + if(WITH_GPU) + message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}") + endif() + + if(ANDROID) + message(STATUS " ANDROID_ABI : ${ANDROID_ABI}") + message(STATUS " ANDROID_PLATFORM : ${ANDROID_PLATFORM}") + message(STATUS " ANDROID_NDK : ${ANDROID_NDK}") + 
message(STATUS " ANDROID_NDK_VERSION : ${CMAKE_ANDROID_NDK_VERSION}") + endif() + if (WITH_ASR) + message(STATUS " Python executable : ${PYTHON_EXECUTABLE}") + message(STATUS " Python includes : ${PYTHON_INCLUDE_DIR}") + endif() +endfunction() + +pps_summary() \ No newline at end of file diff --git a/runtime/cmake/system.cmake b/runtime/cmake/system.cmake new file mode 100644 index 00000000..580e07bb --- /dev/null +++ b/runtime/cmake/system.cmake @@ -0,0 +1,106 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Detects the OS and sets appropriate variables. +# CMAKE_SYSTEM_NAME only give us a coarse-grained name of the OS CMake is +# building for, but the host processor name like centos is necessary +# in some scenes to distinguish system for customization. +# +# for instance, protobuf libs path is /lib64 +# on CentOS, but /lib on other systems. + +if(UNIX AND NOT APPLE) + # except apple from nix*Os family + set(LINUX TRUE) +endif() + +if(WIN32) + set(HOST_SYSTEM "win32") +else() + if(APPLE) + set(HOST_SYSTEM "macosx") + exec_program( + sw_vers ARGS + -productVersion + OUTPUT_VARIABLE HOST_SYSTEM_VERSION) + string(REGEX MATCH "[0-9]+.[0-9]+" MACOS_VERSION "${HOST_SYSTEM_VERSION}") + if(NOT DEFINED $ENV{MACOSX_DEPLOYMENT_TARGET}) + # Set cache variable - end user may change this during ccmake or cmake-gui configure. 
+ set(CMAKE_OSX_DEPLOYMENT_TARGET + ${MACOS_VERSION} + CACHE + STRING + "Minimum OS X version to target for deployment (at runtime); newer APIs weak linked. Set to empty string for default value." + ) + endif() + set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security") + else() + + if(EXISTS "/etc/issue") + file(READ "/etc/issue" LINUX_ISSUE) + if(LINUX_ISSUE MATCHES "CentOS") + set(HOST_SYSTEM "centos") + elseif(LINUX_ISSUE MATCHES "Debian") + set(HOST_SYSTEM "debian") + elseif(LINUX_ISSUE MATCHES "Ubuntu") + set(HOST_SYSTEM "ubuntu") + elseif(LINUX_ISSUE MATCHES "Red Hat") + set(HOST_SYSTEM "redhat") + elseif(LINUX_ISSUE MATCHES "Fedora") + set(HOST_SYSTEM "fedora") + endif() + + string(REGEX MATCH "(([0-9]+)\\.)+([0-9]+)" HOST_SYSTEM_VERSION + "${LINUX_ISSUE}") + endif() + + if(EXISTS "/etc/redhat-release") + file(READ "/etc/redhat-release" LINUX_ISSUE) + if(LINUX_ISSUE MATCHES "CentOS") + set(HOST_SYSTEM "centos") + endif() + endif() + + if(NOT HOST_SYSTEM) + set(HOST_SYSTEM ${CMAKE_SYSTEM_NAME}) + endif() + + endif() +endif() + +# query number of logical cores +cmake_host_system_information(RESULT CPU_CORES QUERY NUMBER_OF_LOGICAL_CORES) + +mark_as_advanced(HOST_SYSTEM CPU_CORES) + +message( + STATUS + "Found Paddle host system: ${HOST_SYSTEM}, version: ${HOST_SYSTEM_VERSION}") +message(STATUS "Found Paddle host system's CPU: ${CPU_CORES} cores") + +# external dependencies log output +set(EXTERNAL_PROJECT_LOG_ARGS + LOG_DOWNLOAD + 0 # Wrap download in script to log output + LOG_UPDATE + 1 # Wrap update in script to log output + LOG_CONFIGURE + 1 # Wrap configure in script to log output + LOG_BUILD + 0 # Wrap build in script to log output + LOG_TEST + 1 # Wrap test in script to log output + LOG_INSTALL + 0 # Wrap install in script to log output +) \ No newline at end of file diff --git a/speechx/docker/.gitkeep b/runtime/docker/.gitkeep similarity index 100% rename from speechx/docker/.gitkeep rename to runtime/docker/.gitkeep diff --git 
a/runtime/engine/CMakeLists.txt b/runtime/engine/CMakeLists.txt new file mode 100644 index 00000000..d64df648 --- /dev/null +++ b/runtime/engine/CMakeLists.txt @@ -0,0 +1,22 @@ +project(speechx LANGUAGES CXX) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/kaldi) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/common) + +add_subdirectory(kaldi) +add_subdirectory(common) + +if(WITH_ASR) + add_subdirectory(asr) +endif() + +if(WITH_CLS) + add_subdirectory(audio_classification) +endif() + +if(WITH_VAD) + add_subdirectory(vad) +endif() + +add_subdirectory(codelab) diff --git a/runtime/engine/asr/CMakeLists.txt b/runtime/engine/asr/CMakeLists.txt new file mode 100644 index 00000000..ff4cdecb --- /dev/null +++ b/runtime/engine/asr/CMakeLists.txt @@ -0,0 +1,11 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +project(ASR LANGUAGES CXX) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/server) + +add_subdirectory(decoder) +add_subdirectory(recognizer) +add_subdirectory(nnet) +add_subdirectory(server) diff --git a/runtime/engine/asr/decoder/CMakeLists.txt b/runtime/engine/asr/decoder/CMakeLists.txt new file mode 100644 index 00000000..2a20f446 --- /dev/null +++ b/runtime/engine/asr/decoder/CMakeLists.txt @@ -0,0 +1,24 @@ +set(srcs) +list(APPEND srcs + ctc_prefix_beam_search_decoder.cc + ctc_tlg_decoder.cc +) + +add_library(decoder STATIC ${srcs}) +target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder) + +# test +set(TEST_BINS + ctc_prefix_beam_search_decoder_main + ctc_tlg_decoder_main +) + +foreach(bin_name IN LISTS TEST_BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_compile_options(${bin_name} PRIVATE 
${PADDLE_COMPILE_FLAGS}) + target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl) +endforeach() + diff --git a/speechx/speechx/decoder/common.h b/runtime/engine/asr/decoder/common.h similarity index 100% rename from speechx/speechx/decoder/common.h rename to runtime/engine/asr/decoder/common.h diff --git a/speechx/speechx/decoder/ctc_beam_search_opt.h b/runtime/engine/asr/decoder/ctc_beam_search_opt.h similarity index 52% rename from speechx/speechx/decoder/ctc_beam_search_opt.h rename to runtime/engine/asr/decoder/ctc_beam_search_opt.h index f4a81b3a..4c145370 100644 --- a/speechx/speechx/decoder/ctc_beam_search_opt.h +++ b/runtime/engine/asr/decoder/ctc_beam_search_opt.h @@ -22,51 +22,22 @@ namespace ppspeech { struct CTCBeamSearchOptions { // common int blank; - - // ds2 - std::string dict_file; - std::string lm_path; - int beam_size; - BaseFloat alpha; - BaseFloat beta; - BaseFloat cutoff_prob; - int cutoff_top_n; - int num_proc_bsearch; + std::string word_symbol_table; // u2 int first_beam_size; int second_beam_size; + CTCBeamSearchOptions() : blank(0), - dict_file("vocab.txt"), - lm_path(""), - beam_size(300), - alpha(1.9f), - beta(5.0), - cutoff_prob(0.99f), - cutoff_top_n(40), - num_proc_bsearch(10), + word_symbol_table("vocab.txt"), first_beam_size(10), second_beam_size(10) {} void Register(kaldi::OptionsItf* opts) { - std::string module = "Ds2BeamSearchConfig: "; - opts->Register("dict", &dict_file, module + "vocab file path."); - opts->Register( - "lm-path", &lm_path, module + "ngram language model path."); - opts->Register("alpha", &alpha, module + "alpha"); - opts->Register("beta", &beta, module + "beta"); - opts->Register("beam-size", - &beam_size, - module + "beam size for beam search method"); - opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs"); - opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff 
top n"); - opts->Register( - "num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch"); - + std::string module = "CTCBeamSearchOptions: "; + opts->Register("word_symbol_table", &word_symbol_table, module + "vocab file path."); opts->Register("blank", &blank, "blank id, default is 0."); - - module = "U2BeamSearchConfig: "; opts->Register( "first-beam-size", &first_beam_size, module + "first beam size."); opts->Register("second-beam-size", diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc similarity index 96% rename from speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc rename to runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc index 07e8e560..bf912af2 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc +++ b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc @@ -17,13 +17,12 @@ #include "decoder/ctc_prefix_beam_search_decoder.h" -#include "absl/strings/str_join.h" #include "base/common.h" #include "decoder/ctc_beam_search_opt.h" #include "decoder/ctc_prefix_beam_search_score.h" #include "utils/math.h" -#ifdef USE_PROFILING +#ifdef WITH_PROFILING #include "paddle/fluid/platform/profiler.h" using paddle::platform::RecordEvent; using paddle::platform::TracerEventType; @@ -31,11 +30,10 @@ using paddle::platform::TracerEventType; namespace ppspeech { -CTCPrefixBeamSearch::CTCPrefixBeamSearch(const std::string& vocab_path, - const CTCBeamSearchOptions& opts) +CTCPrefixBeamSearch::CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts) : opts_(opts) { unit_table_ = std::shared_ptr( - fst::SymbolTable::ReadText(vocab_path)); + fst::SymbolTable::ReadText(opts.word_symbol_table)); CHECK(unit_table_ != nullptr); Reset(); @@ -66,7 +64,6 @@ void CTCPrefixBeamSearch::Reset() { void CTCPrefixBeamSearch::InitDecoder() { Reset(); } - void CTCPrefixBeamSearch::AdvanceDecode( const std::shared_ptr& decodable) { double search_cost = 0.0; @@ -78,21 
+75,21 @@ void CTCPrefixBeamSearch::AdvanceDecode( bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); feat_nnet_cost += timer.Elapsed(); if (flag == false) { - VLOG(3) << "decoder advance decode exit." << frame_prob.size(); + VLOG(2) << "decoder advance decode exit." << frame_prob.size(); break; } timer.Reset(); std::vector> likelihood; - likelihood.push_back(frame_prob); + likelihood.push_back(std::move(frame_prob)); AdvanceDecoding(likelihood); search_cost += timer.Elapsed(); - VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_; + VLOG(1) << "num_frame_decoded_: " << num_frame_decoded_; } - VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost + VLOG(2) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost << " sec."; - VLOG(1) << "AdvanceDecode search cost: " << search_cost << " sec."; + VLOG(2) << "AdvanceDecode search cost: " << search_cost << " sec."; } static bool PrefixScoreCompare( @@ -105,7 +102,7 @@ static bool PrefixScoreCompare( void CTCPrefixBeamSearch::AdvanceDecoding( const std::vector>& logp) { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event("CtcPrefixBeamSearch::AdvanceDecoding", TracerEventType::UserDefined, 1); diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h similarity index 94% rename from speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h rename to runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h index 5013246a..391b4073 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h +++ b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h @@ -27,8 +27,7 @@ namespace ppspeech { class ContextGraph; class CTCPrefixBeamSearch : public DecoderBase { public: - CTCPrefixBeamSearch(const std::string& vocab_path, - const CTCBeamSearchOptions& opts); + CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts); ~CTCPrefixBeamSearch() {} SearchType Type() const { return 
SearchType::kPrefixBeamSearch; } @@ -45,7 +44,7 @@ class CTCPrefixBeamSearch : public DecoderBase { void FinalizeSearch(); - const std::shared_ptr VocabTable() const { + const std::shared_ptr WordSymbolTable() const override { return unit_table_; } @@ -57,7 +56,6 @@ class CTCPrefixBeamSearch : public DecoderBase { } const std::vector>& Times() const { return times_; } - protected: std::string GetBestPath() override; std::vector> GetNBestPath() override; diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc similarity index 86% rename from speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc rename to runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc index c59b1f2e..0935c6e6 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc +++ b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -12,18 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "absl/strings/str_split.h" #include "base/common.h" #include "decoder/ctc_prefix_beam_search_decoder.h" -#include "frontend/audio/data_cache.h" +#include "frontend/data_cache.h" #include "fst/symbol-table.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" +#include "nnet/nnet_producer.h" #include "nnet/u2_nnet.h" DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); -DEFINE_string(vocab_path, "", "vocab path"); +DEFINE_string(word_symbol_table, "", "vocab path"); DEFINE_string(model_path, "", "paddle nnet model"); @@ -40,7 +40,7 @@ using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; -// test ds2 online decoder by feeding speech feature +// test u2 online decoder by feeding speech feature int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -52,10 +52,10 @@ int main(int argc, char* argv[]) { CHECK_NE(FLAGS_result_wspecifier, ""); CHECK_NE(FLAGS_feature_rspecifier, ""); - CHECK_NE(FLAGS_vocab_path, ""); + CHECK_NE(FLAGS_word_symbol_table, ""); CHECK_NE(FLAGS_model_path, ""); LOG(INFO) << "model path: " << FLAGS_model_path; - LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path; + LOG(INFO) << "Reading vocab table " << FLAGS_word_symbol_table; kaldi::SequentialBaseFloatMatrixReader feature_reader( FLAGS_feature_rspecifier); @@ -70,15 +70,18 @@ int main(int argc, char* argv[]) { // decodeable std::shared_ptr raw_data = std::make_shared(); + std::shared_ptr nnet_producer = + std::make_shared(nnet, raw_data, 1.0); std::shared_ptr decodable = - std::make_shared(nnet, raw_data); + std::make_shared(nnet_producer); // decoder ppspeech::CTCBeamSearchOptions opts; opts.blank = 0; opts.first_beam_size = 10; opts.second_beam_size = 10; - ppspeech::CTCPrefixBeamSearch decoder(FLAGS_vocab_path, opts); + opts.word_symbol_table = FLAGS_word_symbol_table; + ppspeech::CTCPrefixBeamSearch 
decoder(opts); int32 chunk_size = FLAGS_receptive_field_length + @@ -122,15 +125,14 @@ int main(int argc, char* argv[]) { } - kaldi::Vector feature_chunk(this_chunk_size * - feat_dim); + std::vector feature_chunk(this_chunk_size * + feat_dim); int32 start = chunk_idx * chunk_stride; for (int row_id = 0; row_id < this_chunk_size; ++row_id) { kaldi::SubVector feat_row(feature, start); - kaldi::SubVector feature_chunk_row( - feature_chunk.Data() + row_id * feat_dim, feat_dim); - - feature_chunk_row.CopyFromVec(feat_row); + std::memcpy(feature_chunk.data() + row_id * feat_dim, + feat_row.Data(), + feat_dim * sizeof(kaldi::BaseFloat)); ++start; } diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_score.h b/runtime/engine/asr/decoder/ctc_prefix_beam_search_score.h similarity index 100% rename from speechx/speechx/decoder/ctc_prefix_beam_search_score.h rename to runtime/engine/asr/decoder/ctc_prefix_beam_search_score.h diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/runtime/engine/asr/decoder/ctc_tlg_decoder.cc similarity index 62% rename from speechx/speechx/decoder/ctc_tlg_decoder.cc rename to runtime/engine/asr/decoder/ctc_tlg_decoder.cc index 2c2b6d3c..51ded499 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.cc +++ b/runtime/engine/asr/decoder/ctc_tlg_decoder.cc @@ -13,12 +13,14 @@ // limitations under the License. 
#include "decoder/ctc_tlg_decoder.h" + namespace ppspeech { -TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { - fst_.reset(fst::Fst::Read(opts.fst_path)); +TLGDecoder::TLGDecoder(TLGDecoderOptions opts) : opts_(opts) { + fst_ = opts.fst_ptr; CHECK(fst_ != nullptr); + CHECK(!opts.word_symbol_table.empty()); word_symbol_table_.reset( fst::SymbolTable::ReadText(opts.word_symbol_table)); @@ -29,6 +31,11 @@ TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { void TLGDecoder::Reset() { decoder_->InitDecoding(); + hypotheses_.clear(); + likelihood_.clear(); + olabels_.clear(); + times_.clear(); + num_frame_decoded_ = 0; return; } @@ -68,14 +75,52 @@ std::string TLGDecoder::GetPartialResult() { return words; } +void TLGDecoder::FinalizeSearch() { + decoder_->FinalizeDecoding(); + kaldi::CompactLattice clat; + decoder_->GetLattice(&clat, true); + kaldi::Lattice lat, nbest_lat; + fst::ConvertLattice(clat, &lat); + fst::ShortestPath(lat, &nbest_lat, opts_.nbest); + std::vector nbest_lats; + fst::ConvertNbestToVector(nbest_lat, &nbest_lats); + + hypotheses_.clear(); + hypotheses_.reserve(nbest_lats.size()); + likelihood_.clear(); + likelihood_.reserve(nbest_lats.size()); + times_.clear(); + times_.reserve(nbest_lats.size()); + for (auto lat : nbest_lats) { + kaldi::LatticeWeight weight; + std::vector hypothese; + std::vector time; + std::vector alignment; + std::vector words_id; + fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight); + int idx = 0; + for (; idx < alignment.size() - 1; ++idx) { + if (alignment[idx] == 0) continue; + if (alignment[idx] != alignment[idx + 1]) { + hypothese.push_back(alignment[idx] - 1); + time.push_back(idx); // fake time, todo later + } + } + hypothese.push_back(alignment[idx] - 1); + time.push_back(idx); // fake time, todo later + hypotheses_.push_back(hypothese); + times_.push_back(time); + olabels_.push_back(words_id); + likelihood_.push_back(-(weight.Value2() + weight.Value1())); + } +} + std::string TLGDecoder::GetFinalBestPath() 
{ if (num_frame_decoded_ == 0) { // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call // BestPathEnd if no frames were decoded.") return std::string(""); } - - decoder_->FinalizeDecoding(); kaldi::Lattice lat; kaldi::LatticeWeight weight; std::vector alignment; diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/runtime/engine/asr/decoder/ctc_tlg_decoder.h similarity index 67% rename from speechx/speechx/decoder/ctc_tlg_decoder.h rename to runtime/engine/asr/decoder/ctc_tlg_decoder.h index 8be69dad..80896361 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.h +++ b/runtime/engine/asr/decoder/ctc_tlg_decoder.h @@ -18,13 +18,14 @@ #include "decoder/decoder_itf.h" #include "kaldi/decoder/lattice-faster-online-decoder.h" #include "util/parse-options.h" +#include "utils/file_utils.h" - -DECLARE_string(graph_path); DECLARE_string(word_symbol_table); +DECLARE_string(graph_path); DECLARE_int32(max_active); DECLARE_double(beam); DECLARE_double(lattice_beam); +DECLARE_int32(nbest); namespace ppspeech { @@ -33,17 +34,27 @@ struct TLGDecoderOptions { // todo remove later, add into decode resource std::string word_symbol_table; std::string fst_path; + std::shared_ptr> fst_ptr; + int nbest; + + TLGDecoderOptions() : word_symbol_table(""), fst_path(""), fst_ptr(nullptr), nbest(10) {} static TLGDecoderOptions InitFromFlags() { TLGDecoderOptions decoder_opts; decoder_opts.word_symbol_table = FLAGS_word_symbol_table; decoder_opts.fst_path = FLAGS_graph_path; LOG(INFO) << "fst path: " << decoder_opts.fst_path; - LOG(INFO) << "fst symbole table: " << decoder_opts.word_symbol_table; + LOG(INFO) << "symbole table: " << decoder_opts.word_symbol_table; + + if (!decoder_opts.fst_path.empty()) { + CHECK(FileExists(decoder_opts.fst_path)); + decoder_opts.fst_ptr.reset(fst::Fst::Read(FLAGS_graph_path)); + } decoder_opts.opts.max_active = FLAGS_max_active; decoder_opts.opts.beam = FLAGS_beam; decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; + decoder_opts.nbest = 
FLAGS_nbest; LOG(INFO) << "LatticeFasterDecoder max active: " << decoder_opts.opts.max_active; LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam; @@ -59,20 +70,38 @@ class TLGDecoder : public DecoderBase { explicit TLGDecoder(TLGDecoderOptions opts); ~TLGDecoder() = default; - void InitDecoder(); - void Reset(); + void InitDecoder() override; + void Reset() override; void AdvanceDecode( - const std::shared_ptr& decodable); + const std::shared_ptr& decodable) override; void Decode(); std::string GetFinalBestPath() override; std::string GetPartialResult() override; + const std::shared_ptr WordSymbolTable() const override { + return word_symbol_table_; + } + int DecodeLikelihoods(const std::vector>& probs, const std::vector& nbest_words); + void FinalizeSearch() override; + const std::vector>& Inputs() const override { + return hypotheses_; + } + const std::vector>& Outputs() const override { + return olabels_; + } // outputs_; } + const std::vector& Likelihood() const override { + return likelihood_; + } + const std::vector>& Times() const override { + return times_; + } + protected: std::string GetBestPath() override { CHECK(false); @@ -90,10 +119,17 @@ class TLGDecoder : public DecoderBase { private: void AdvanceDecoding(kaldi::DecodableInterface* decodable); + int num_frame_decoded_; + std::vector> hypotheses_; + std::vector> olabels_; + std::vector likelihood_; + std::vector> times_; + std::shared_ptr decoder_; std::shared_ptr> fst_; std::shared_ptr word_symbol_table_; + TLGDecoderOptions opts_; }; -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/decoder/nnet_logprob_decoder_main.cc b/runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc similarity index 50% rename from speechx/speechx/decoder/nnet_logprob_decoder_main.cc rename to runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc index e0acbe77..dcd18b81 100644 --- a/speechx/speechx/decoder/nnet_logprob_decoder_main.cc +++ 
b/runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc @@ -14,21 +14,24 @@ // todo refactor, repalce with gtest -#include "base/flags.h" -#include "base/log.h" -#include "decoder/ctc_beam_search_decoder.h" +#include "base/common.h" +#include "decoder/ctc_tlg_decoder.h" +#include "decoder/param.h" +#include "frontend/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" +#include "nnet/nnet_producer.h" + + +DEFINE_string(nnet_prob_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); -DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier"); -DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); -DEFINE_string(lm_path, "lm.klm", "language model"); using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; -// test decoder by feeding nnet posterior probability +// test TLG decoder by feeding speech feature. int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -36,41 +39,51 @@ int main(int argc, char* argv[]) { google::InstallFailureSignalHandler(); FLAGS_logtostderr = 1; - kaldi::SequentialBaseFloatMatrixReader likelihood_reader( - FLAGS_nnet_prob_respecifier); - std::string dict_file = FLAGS_dict_file; - std::string lm_path = FLAGS_lm_path; - LOG(INFO) << "dict path: " << dict_file; - LOG(INFO) << "lm path: " << lm_path; + kaldi::SequentialBaseFloatMatrixReader nnet_prob_reader( + FLAGS_nnet_prob_rspecifier); + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); int32 num_done = 0, num_err = 0; - ppspeech::CTCBeamSearchOptions opts; - opts.dict_file = dict_file; - opts.lm_path = lm_path; - ppspeech::CTCBeamSearch decoder(opts); + ppspeech::TLGDecoderOptions opts = + ppspeech::TLGDecoderOptions::InitFromFlags(); + opts.opts.beam = 15.0; + opts.opts.lattice_beam = 7.5; + ppspeech::TLGDecoder decoder(opts); + + ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); + 
std::shared_ptr nnet_producer = + std::make_shared(nullptr, nullptr, 1.0); std::shared_ptr decodable( - new ppspeech::Decodable(nullptr, nullptr)); + new ppspeech::Decodable(nnet_producer, FLAGS_acoustic_scale)); decoder.InitDecoder(); + kaldi::Timer timer; - for (; !likelihood_reader.Done(); likelihood_reader.Next()) { - string utt = likelihood_reader.Key(); - const kaldi::Matrix likelihood = likelihood_reader.Value(); - LOG(INFO) << "process utt: " << utt; - LOG(INFO) << "rows: " << likelihood.NumRows(); - LOG(INFO) << "cols: " << likelihood.NumCols(); - decodable->Acceptlikelihood(likelihood); + for (; !nnet_prob_reader.Done(); nnet_prob_reader.Next()) { + string utt = nnet_prob_reader.Key(); + kaldi::Matrix prob = nnet_prob_reader.Value(); + decodable->Acceptlikelihood(prob); decoder.AdvanceDecode(decodable); std::string result; result = decoder.GetFinalBestPath(); - KALDI_LOG << " the result of " << utt << " is " << result; decodable->Reset(); decoder.Reset(); + if (result.empty()) { + // the TokenWriter can not write empty string. + ++num_err; + KALDI_LOG << " the result of " << utt << " is empty"; + continue; + } + KALDI_LOG << " the result of " << utt << " is " << result; + result_writer.Write(utt, result); ++num_done; } + double elapsed = timer.Elapsed(); + KALDI_LOG << " cost:" << elapsed << " s"; + KALDI_LOG << "Done " << num_done << " utterances, " << num_err << " with errors."; return (num_done != 0 ? 0 : 1); diff --git a/speechx/speechx/decoder/decoder_itf.h b/runtime/engine/asr/decoder/decoder_itf.h similarity index 79% rename from speechx/speechx/decoder/decoder_itf.h rename to runtime/engine/asr/decoder/decoder_itf.h index 2289b317..cb7717e8 100644 --- a/speechx/speechx/decoder/decoder_itf.h +++ b/runtime/engine/asr/decoder/decoder_itf.h @@ -1,4 +1,3 @@ - // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +15,7 @@ #pragma once #include "base/common.h" +#include "fst/symbol-table.h" #include "kaldi/decoder/decodable-itf.h" namespace ppspeech { @@ -41,6 +41,14 @@ class DecoderInterface { virtual std::string GetPartialResult() = 0; + virtual const std::shared_ptr WordSymbolTable() const = 0; + virtual void FinalizeSearch() = 0; + + virtual const std::vector>& Inputs() const = 0; + virtual const std::vector>& Outputs() const = 0; + virtual const std::vector& Likelihood() const = 0; + virtual const std::vector>& Times() const = 0; + protected: // virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0; diff --git a/speechx/speechx/decoder/param.h b/runtime/engine/asr/decoder/param.h similarity index 73% rename from speechx/speechx/decoder/param.h rename to runtime/engine/asr/decoder/param.h index ebdd7119..0cad75bf 100644 --- a/speechx/speechx/decoder/param.h +++ b/runtime/engine/asr/decoder/param.h @@ -15,8 +15,6 @@ #pragma once #include "base/common.h" -#include "decoder/ctc_beam_search_decoder.h" -#include "decoder/ctc_tlg_decoder.h" // feature DEFINE_bool(use_fbank, false, "False for fbank; or linear feature"); @@ -37,36 +35,22 @@ DEFINE_int32(subsampling_rate, "two CNN(kernel=3) module downsampling rate."); DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); - // nnet -DEFINE_string(vocab_path, "", "nnet vocab path."); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); -DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); -DEFINE_string( - model_input_names, - "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box", - "model input names"); -DEFINE_string(model_output_names, - "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0", - "model output names"); -DEFINE_string(model_cache_names, - "chunk_state_h_box,chunk_state_c_box", - "model cache names"); -DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", 
"model cache shapes"); - +#ifdef USE_ONNX +DEFINE_bool(with_onnx_model, false, "True mean the model path is onnx model path"); +#endif // decoder DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); - -DEFINE_string(graph_path, "TLG", "decoder graph"); -DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); +DEFINE_string(graph_path, "", "decoder graph"); +DEFINE_string(word_symbol_table, "", "word symbol table"); DEFINE_int32(max_active, 7500, "max active"); DEFINE_double(beam, 15.0, "decoder beam"); DEFINE_double(lattice_beam, 7.5, "decoder beam"); - +DEFINE_double(blank_threshold, 0.98, "blank skip threshold"); // DecodeOptions flags -// DEFINE_int32(chunk_size, -1, "decoding chunk size"); DEFINE_int32(num_left_chunks, -1, "left chunks in decoding"); DEFINE_double(ctc_weight, 0.5, diff --git a/runtime/engine/asr/nnet/CMakeLists.txt b/runtime/engine/asr/nnet/CMakeLists.txt new file mode 100644 index 00000000..1adcbfeb --- /dev/null +++ b/runtime/engine/asr/nnet/CMakeLists.txt @@ -0,0 +1,21 @@ +set(srcs decodable.cc nnet_producer.cc) + +list(APPEND srcs u2_nnet.cc) +if(WITH_ONNX) + list(APPEND srcs u2_onnx_nnet.cc) +endif() +add_library(nnet STATIC ${srcs}) +target_link_libraries(nnet utils) +if(WITH_ONNX) + target_link_libraries(nnet ${FASTDEPLOY_LIBS}) +endif() + +target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS}) +target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + +# test bin +#set(bin_name u2_nnet_main) +#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) +#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) +#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) \ No newline at end of file diff --git a/speechx/speechx/nnet/decodable.cc b/runtime/engine/asr/nnet/decodable.cc similarity index 54% rename from speechx/speechx/nnet/decodable.cc rename 
to runtime/engine/asr/nnet/decodable.cc index 5fe2b984..a140c376 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/runtime/engine/asr/nnet/decodable.cc @@ -21,29 +21,25 @@ using kaldi::Matrix; using kaldi::Vector; using std::vector; -Decodable::Decodable(const std::shared_ptr& nnet, - const std::shared_ptr& frontend, +Decodable::Decodable(const std::shared_ptr& nnet_producer, kaldi::BaseFloat acoustic_scale) - : frontend_(frontend), - nnet_(nnet), + : nnet_producer_(nnet_producer), frame_offset_(0), frames_ready_(0), acoustic_scale_(acoustic_scale) {} // for debug void Decodable::Acceptlikelihood(const Matrix& likelihood) { - nnet_out_cache_ = likelihood; - frames_ready_ += likelihood.NumRows(); + nnet_producer_->Acceptlikelihood(likelihood); } - // return the size of frame have computed. int32 Decodable::NumFramesReady() const { return frames_ready_; } // frame idx is from 0 to frame_ready_ -1; bool Decodable::IsLastFrame(int32 frame) { - bool flag = EnsureFrameHaveComputed(frame); + EnsureFrameHaveComputed(frame); return frame >= frames_ready_; } @@ -64,32 +60,10 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) { bool Decodable::AdvanceChunk() { kaldi::Timer timer; - // read feats - Vector features; - if (frontend_ == NULL || frontend_->Read(&features) == false) { - // no feat or frontend_ not init. - VLOG(3) << "decodable exit;"; - return false; - } - CHECK_GE(frontend_->Dim(), 0); - VLOG(1) << "AdvanceChunk feat cost: " << timer.Elapsed() << " sec."; - VLOG(2) << "Forward in " << features.Dim() / frontend_->Dim() << " feats."; - - // forward feats - NnetOut out; - nnet_->FeedForward(features, frontend_->Dim(), &out); - int32& vocab_dim = out.vocab_dim; - Vector& logprobs = out.logprobs; - - VLOG(2) << "Forward out " << logprobs.Dim() / vocab_dim - << " decoder frames."; - // cache nnet outupts - nnet_out_cache_.Resize(logprobs.Dim() / vocab_dim, vocab_dim); - nnet_out_cache_.CopyRowsFromVec(logprobs); - - // update state, decoding frame. 
+ bool flag = nnet_producer_->Read(&framelikelihood_); + if (flag == false) return false; frame_offset_ = frames_ready_; - frames_ready_ += nnet_out_cache_.NumRows(); + frames_ready_ += 1; VLOG(1) << "AdvanceChunk feat + forward cost: " << timer.Elapsed() << " sec."; return true; @@ -101,17 +75,17 @@ bool Decodable::AdvanceChunk(kaldi::Vector* logprobs, return false; } - int nrows = nnet_out_cache_.NumRows(); - CHECK(nrows == (frames_ready_ - frame_offset_)); - if (nrows <= 0) { + if (framelikelihood_.empty()) { LOG(WARNING) << "No new nnet out in cache."; return false; } - logprobs->Resize(nnet_out_cache_.NumRows() * nnet_out_cache_.NumCols()); - logprobs->CopyRowsFromMat(nnet_out_cache_); - - *vocab_dim = nnet_out_cache_.NumCols(); + size_t dim = framelikelihood_.size(); + logprobs->Resize(framelikelihood_.size()); + std::memcpy(logprobs->Data(), + framelikelihood_.data(), + dim * sizeof(kaldi::BaseFloat)); + *vocab_dim = framelikelihood_.size(); return true; } @@ -122,19 +96,8 @@ bool Decodable::FrameLikelihood(int32 frame, vector* likelihood) { return false; } - int nrows = nnet_out_cache_.NumRows(); - CHECK(nrows == (frames_ready_ - frame_offset_)); - int vocab_size = nnet_out_cache_.NumCols(); - likelihood->resize(vocab_size); - - for (int32 idx = 0; idx < vocab_size; ++idx) { - (*likelihood)[idx] = - nnet_out_cache_(frame - frame_offset_, idx) * acoustic_scale_; - - VLOG(4) << "nnet out: " << frame << " offset:" << frame_offset_ << " " - << nnet_out_cache_.NumRows() - << " logprob: " << nnet_out_cache_(frame - frame_offset_, idx); - } + CHECK_EQ(1, (frames_ready_ - frame_offset_)); + *likelihood = framelikelihood_; return true; } @@ -143,37 +106,31 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return false; } - CHECK_LE(index, nnet_out_cache_.NumCols()); + CHECK_LE(index, framelikelihood_.size()); CHECK_LE(frame, frames_ready_); // the nnet output is prob ranther than log prob // the index - 1, because the ilabel BaseFloat logprob = 0.0; 
int32 frame_idx = frame - frame_offset_; - BaseFloat nnet_out = nnet_out_cache_(frame_idx, TokenId2NnetId(index)); - if (nnet_->IsLogProb()) { - logprob = nnet_out; - } else { - logprob = std::log(nnet_out + std::numeric_limits::epsilon()); - } - CHECK(!std::isnan(logprob) && !std::isinf(logprob)); + CHECK_EQ(frame_idx, 0); + logprob = framelikelihood_[TokenId2NnetId(index)]; return acoustic_scale_ * logprob; } void Decodable::Reset() { - if (frontend_ != nullptr) frontend_->Reset(); - if (nnet_ != nullptr) nnet_->Reset(); + if (nnet_producer_ != nullptr) nnet_producer_->Reset(); frame_offset_ = 0; frames_ready_ = 0; - nnet_out_cache_.Resize(0, 0); + framelikelihood_.clear(); } void Decodable::AttentionRescoring(const std::vector>& hyps, float reverse_weight, std::vector* rescoring_score) { kaldi::Timer timer; - nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score); + nnet_producer_->AttentionRescoring(hyps, reverse_weight, rescoring_score); VLOG(1) << "Attention Rescoring cost: " << timer.Elapsed() << " sec."; } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/nnet/decodable.h b/runtime/engine/asr/nnet/decodable.h similarity index 81% rename from speechx/speechx/nnet/decodable.h rename to runtime/engine/asr/nnet/decodable.h index dd7b329e..f6448670 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/runtime/engine/asr/nnet/decodable.h @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#pragma once + #include "base/common.h" -#include "frontend/audio/frontend_itf.h" #include "kaldi/decoder/decodable-itf.h" -#include "kaldi/matrix/kaldi-matrix.h" +#include "matrix/kaldi-matrix.h" #include "nnet/nnet_itf.h" +#include "nnet/nnet_producer.h" namespace ppspeech { @@ -24,12 +26,9 @@ struct DecodableOpts; class Decodable : public kaldi::DecodableInterface { public: - explicit Decodable(const std::shared_ptr& nnet, - const std::shared_ptr& frontend, + explicit Decodable(const std::shared_ptr& nnet_producer, kaldi::BaseFloat acoustic_scale = 1.0); - // void Init(DecodableOpts config); - // nnet logprob output, used by wfst virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); @@ -57,23 +56,17 @@ class Decodable : public kaldi::DecodableInterface { void Reset(); - bool IsInputFinished() const { return frontend_->IsFinished(); } + bool IsInputFinished() const { return nnet_producer_->IsFinished(); } bool EnsureFrameHaveComputed(int32 frame); int32 TokenId2NnetId(int32 token_id); - std::shared_ptr Nnet() { return nnet_; } - // for offline test void Acceptlikelihood(const kaldi::Matrix& likelihood); private: - std::shared_ptr frontend_; - std::shared_ptr nnet_; - - // nnet outputs' cache - kaldi::Matrix nnet_out_cache_; + std::shared_ptr nnet_producer_; // the frame is nnet prob frame rather than audio feature frame // nnet frame subsample the feature frame @@ -85,6 +78,7 @@ class Decodable : public kaldi::DecodableInterface { // so use subsampled_frame int32 current_log_post_subsampled_offset_; int32 num_chunk_computed_; + std::vector framelikelihood_; kaldi::BaseFloat acoustic_scale_; }; diff --git a/speechx/speechx/nnet/nnet_itf.h b/runtime/engine/asr/nnet/nnet_itf.h similarity index 70% rename from speechx/speechx/nnet/nnet_itf.h rename to runtime/engine/asr/nnet/nnet_itf.h index a504cce5..ac105d11 100644 --- a/speechx/speechx/nnet/nnet_itf.h +++ b/runtime/engine/asr/nnet/nnet_itf.h @@ -15,7 +15,6 @@ #include "base/basic_types.h" #include 
"kaldi/base/kaldi-types.h" -#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/util/options-itf.h" DECLARE_int32(subsampling_rate); @@ -25,26 +24,20 @@ DECLARE_string(model_input_names); DECLARE_string(model_output_names); DECLARE_string(model_cache_names); DECLARE_string(model_cache_shapes); +#ifdef USE_ONNX +DECLARE_bool(with_onnx_model); +#endif namespace ppspeech { struct ModelOptions { // common int subsample_rate{1}; - int thread_num{1}; // predictor thread pool size for ds2; bool use_gpu{false}; std::string model_path; - - std::string param_path; - - // ds2 for inference - std::string input_names{}; - std::string output_names{}; - std::string cache_names{}; - std::string cache_shape{}; - bool switch_ir_optim{false}; - bool enable_fc_padding{false}; - bool enable_profile{false}; +#ifdef USE_ONNX + bool with_onnx_model{false}; +#endif static ModelOptions InitFromFlags() { ModelOptions opts; @@ -52,26 +45,17 @@ struct ModelOptions { LOG(INFO) << "subsampling rate: " << opts.subsample_rate; opts.model_path = FLAGS_model_path; LOG(INFO) << "model path: " << opts.model_path; - - opts.param_path = FLAGS_param_path; - LOG(INFO) << "param path: " << opts.param_path; - - LOG(INFO) << "DS2 param: "; - opts.cache_names = FLAGS_model_cache_names; - LOG(INFO) << " cache names: " << opts.cache_names; - opts.cache_shape = FLAGS_model_cache_shapes; - LOG(INFO) << " cache shape: " << opts.cache_shape; - opts.input_names = FLAGS_model_input_names; - LOG(INFO) << " input names: " << opts.input_names; - opts.output_names = FLAGS_model_output_names; - LOG(INFO) << " output names: " << opts.output_names; +#ifdef USE_ONNX + opts.with_onnx_model = FLAGS_with_onnx_model; + LOG(INFO) << "with onnx model: " << opts.with_onnx_model; +#endif return opts; } }; struct NnetOut { // nnet out. maybe logprob or prob. Almost time this is logprob. - kaldi::Vector logprobs; + std::vector logprobs; int32 vocab_dim; // nnet state. Only using in Attention model. 
@@ -89,7 +73,7 @@ class NnetInterface { // nnet do not cache feats, feats cached by frontend. // nnet cache model state, i.e. encoder_outs, att_cache, cnn_cache, // frame_offset. - virtual void FeedForward(const kaldi::Vector& features, + virtual void FeedForward(const std::vector& features, const int32& feature_dim, NnetOut* out) = 0; @@ -105,14 +89,14 @@ class NnetInterface { // using to get encoder outs. e.g. seq2seq with Attention model. virtual void EncoderOuts( - std::vector>* encoder_out) const = 0; + std::vector>* encoder_out) const = 0; }; class NnetBase : public NnetInterface { public: int SubsamplingRate() const { return subsampling_rate_; } - + virtual std::shared_ptr Clone() const = 0; protected: int subsampling_rate_{1}; }; diff --git a/runtime/engine/asr/nnet/nnet_producer.cc b/runtime/engine/asr/nnet/nnet_producer.cc new file mode 100644 index 00000000..529fae65 --- /dev/null +++ b/runtime/engine/asr/nnet/nnet_producer.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "nnet/nnet_producer.h" + +#include "matrix/kaldi-matrix.h" + +namespace ppspeech { + +using kaldi::BaseFloat; +using std::vector; + +NnetProducer::NnetProducer(std::shared_ptr nnet, + std::shared_ptr frontend, + float blank_threshold) + : nnet_(nnet), frontend_(frontend), blank_threshold_(blank_threshold) { + Reset(); +} + +void NnetProducer::Accept(const std::vector& inputs) { + frontend_->Accept(inputs); +} + +void NnetProducer::Acceptlikelihood( + const kaldi::Matrix& likelihood) { + std::vector prob; + prob.resize(likelihood.NumCols()); + for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) { + for (size_t col = 0; col < likelihood.NumCols(); ++col) { + prob[col] = likelihood(idx, col); + } + cache_.push_back(prob); + } +} + +bool NnetProducer::Read(std::vector* nnet_prob) { + bool flag = cache_.pop(nnet_prob); + return flag; +} + +bool NnetProducer::Compute() { + vector features; + if (frontend_ == NULL || frontend_->Read(&features) == false) { + // no feat or frontend_ not init. 
+ if (frontend_ != NULL && frontend_->IsFinished() == true) { + finished_ = true; + } + return false; + } + CHECK_GE(frontend_->Dim(), 0); + VLOG(1) << "Forward in " << features.size() / frontend_->Dim() << " feats."; + + NnetOut out; + nnet_->FeedForward(features, frontend_->Dim(), &out); + int32& vocab_dim = out.vocab_dim; + size_t nframes = out.logprobs.size() / vocab_dim; + VLOG(1) << "Forward out " << nframes << " decoder frames."; + for (size_t idx = 0; idx < nframes; ++idx) { + std::vector logprob( + out.logprobs.data() + idx * vocab_dim, + out.logprobs.data() + (idx + 1) * vocab_dim); + // process blank prob + float blank_prob = std::exp(logprob[0]); + if (blank_prob > blank_threshold_) { + last_frame_logprob_ = logprob; + is_last_frame_skip_ = true; + continue; + } else { + int cur_max = std::max_element(logprob.begin(), logprob.end()) - logprob.begin(); + if (cur_max == last_max_elem_ && cur_max != 0 && is_last_frame_skip_) { + cache_.push_back(last_frame_logprob_); + last_max_elem_ = cur_max; + } + last_max_elem_ = cur_max; + is_last_frame_skip_ = false; + cache_.push_back(logprob); + } + } + return true; + } + + void NnetProducer::AttentionRescoring(const std::vector>& hyps, + float reverse_weight, + std::vector* rescoring_score) { + nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score); + } + + } // namespace ppspeech diff --git a/runtime/engine/asr/nnet/nnet_producer.h b/runtime/engine/asr/nnet/nnet_producer.h new file mode 100644 index 00000000..21aee067 --- /dev/null +++ b/runtime/engine/asr/nnet/nnet_producer.h @@ -0,0 +1,77 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "base/common.h" +#include "base/safe_queue.h" +#include "frontend/frontend_itf.h" +#include "nnet/nnet_itf.h" + +namespace ppspeech { + +class NnetProducer { + public: + explicit NnetProducer(std::shared_ptr nnet, + std::shared_ptr frontend, + float blank_threshold); + // Feed feats or waves + void Accept(const std::vector& inputs); + + void Acceptlikelihood(const kaldi::Matrix& likelihood); + + // nnet + bool Read(std::vector* nnet_prob); + + bool Empty() const { return cache_.empty(); } + + void SetInputFinished() { + LOG(INFO) << "set finished"; + frontend_->SetFinished(); + } + + // the compute thread exit + bool IsFinished() const { + return (frontend_->IsFinished() && finished_); + } + + ~NnetProducer() {} + + void Reset() { + if (frontend_ != NULL) frontend_->Reset(); + if (nnet_ != NULL) nnet_->Reset(); + cache_.clear(); + finished_ = false; + } + + void AttentionRescoring(const std::vector>& hyps, + float reverse_weight, + std::vector* rescoring_score); + + bool Compute(); + private: + + std::shared_ptr frontend_; + std::shared_ptr nnet_; + SafeQueue> cache_; + std::vector last_frame_logprob_; + bool is_last_frame_skip_ = false; + int last_max_elem_ = -1; + float blank_threshold_ = 0.0; + bool finished_; + + DISALLOW_COPY_AND_ASSIGN(NnetProducer); +}; + +} // namespace ppspeech diff --git a/speechx/speechx/nnet/u2_nnet.cc b/runtime/engine/asr/nnet/u2_nnet.cc similarity index 87% rename from speechx/speechx/nnet/u2_nnet.cc rename to runtime/engine/asr/nnet/u2_nnet.cc index 7707406a..9a09514e 100644 --- 
a/speechx/speechx/nnet/u2_nnet.cc +++ b/runtime/engine/asr/nnet/u2_nnet.cc @@ -17,12 +17,13 @@ // https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/asr_model.cc #include "nnet/u2_nnet.h" +#include -#ifdef USE_PROFILING +#ifdef WITH_PROFILING #include "paddle/fluid/platform/profiler.h" using paddle::platform::RecordEvent; using paddle::platform::TracerEventType; -#endif // end USE_PROFILING +#endif // end WITH_PROFILING namespace ppspeech { @@ -30,7 +31,7 @@ namespace ppspeech { void U2Nnet::LoadModel(const std::string& model_path_w_prefix) { paddle::jit::utils::InitKernelSignatureMap(); -#ifdef USE_GPU +#ifdef WITH_GPU dev_ = phi::GPUPlace(); #else dev_ = phi::CPUPlace(); @@ -62,12 +63,12 @@ void U2Nnet::LoadModel(const std::string& model_path_w_prefix) { } void U2Nnet::Warmup() { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event("warmup", TracerEventType::UserDefined, 1); #endif { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event( "warmup-encoder-ctc", TracerEventType::UserDefined, 1); #endif @@ -91,7 +92,7 @@ void U2Nnet::Warmup() { } { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event("warmup-decoder", TracerEventType::UserDefined, 1); #endif auto hyps = @@ -101,10 +102,10 @@ void U2Nnet::Warmup() { auto encoder_out = paddle::ones( {1, 20, 512}, paddle::DataType::FLOAT32, phi::CPUPlace()); - std::vector inputs{ + std::vector inputs{ hyps, hyps_lens, encoder_out}; - std::vector outputs = + std::vector outputs = forward_attention_decoder_(inputs); } @@ -118,27 +119,46 @@ U2Nnet::U2Nnet(const ModelOptions& opts) : opts_(opts) { // shallow copy U2Nnet::U2Nnet(const U2Nnet& other) { // copy meta - right_context_ = other.right_context_; - subsampling_rate_ = other.subsampling_rate_; - sos_ = other.sos_; - eos_ = other.eos_; - is_bidecoder_ = other.is_bidecoder_; chunk_size_ = other.chunk_size_; num_left_chunks_ = other.num_left_chunks_; - - forward_encoder_chunk_ = other.forward_encoder_chunk_; - 
forward_attention_decoder_ = other.forward_attention_decoder_; - ctc_activation_ = other.ctc_activation_; - offset_ = other.offset_; // copy model ptr - model_ = other.model_; + // model_ = other.model_->Clone(); + // hack, fix later + #ifdef WITH_GPU + dev_ = phi::GPUPlace(); + #else + dev_ = phi::CPUPlace(); + #endif + paddle::jit::Layer model = paddle::jit::Load(other.opts_.model_path, dev_); + model_ = std::make_shared(std::move(model)); + ctc_activation_ = model_->Function("ctc_activation"); + subsampling_rate_ = model_->Attribute("subsampling_rate"); + right_context_ = model_->Attribute("right_context"); + sos_ = model_->Attribute("sos_symbol"); + eos_ = model_->Attribute("eos_symbol"); + is_bidecoder_ = model_->Attribute("is_bidirectional_decoder"); + + forward_encoder_chunk_ = model_->Function("forward_encoder_chunk"); + forward_attention_decoder_ = model_->Function("forward_attention_decoder"); + ctc_activation_ = model_->Function("ctc_activation"); + CHECK(forward_encoder_chunk_.IsValid()); + CHECK(forward_attention_decoder_.IsValid()); + CHECK(ctc_activation_.IsValid()); + + LOG(INFO) << "Paddle Model Info: "; + LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_; + LOG(INFO) << "\tright context " << right_context_; + LOG(INFO) << "\tsos " << sos_; + LOG(INFO) << "\teos " << eos_; + LOG(INFO) << "\tis bidecoder " << is_bidecoder_ << std::endl; + // ignore inner states } -std::shared_ptr U2Nnet::Copy() const { +std::shared_ptr U2Nnet::Clone() const { auto asr_model = std::make_shared(*this); // reset inner state for new decoding asr_model->Reset(); @@ -154,6 +174,7 @@ void U2Nnet::Reset() { std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32)); encoder_outs_.clear(); + VLOG(1) << "FeedForward cost: " << cost_time_ << " sec. 
"; VLOG(3) << "u2nnet reset"; } @@ -165,23 +186,18 @@ void U2Nnet::FeedEncoderOuts(const paddle::Tensor& encoder_out) { } -void U2Nnet::FeedForward(const kaldi::Vector& features, +void U2Nnet::FeedForward(const std::vector& features, const int32& feature_dim, NnetOut* out) { kaldi::Timer timer; - std::vector chunk_feats(features.Data(), - features.Data() + features.Dim()); std::vector ctc_probs; ForwardEncoderChunkImpl( - chunk_feats, feature_dim, &ctc_probs, &out->vocab_dim); - - out->logprobs.Resize(ctc_probs.size(), kaldi::kSetZero); - std::memcpy(out->logprobs.Data(), - ctc_probs.data(), - ctc_probs.size() * sizeof(kaldi::BaseFloat)); - VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. " - << chunk_feats.size() / feature_dim << " frames."; + features, feature_dim, &out->logprobs, &out->vocab_dim); + float forward_chunk_time = timer.Elapsed(); + VLOG(1) << "FeedForward cost: " << forward_chunk_time << " sec. " + << features.size() / feature_dim << " frames."; + cost_time_ += forward_chunk_time; } @@ -190,7 +206,7 @@ void U2Nnet::ForwardEncoderChunkImpl( const int32& feat_dim, std::vector* out_prob, int32* vocab_dim) { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event( "ForwardEncoderChunkImpl", TracerEventType::UserDefined, 1); #endif @@ -210,7 +226,7 @@ void U2Nnet::ForwardEncoderChunkImpl( // not cache feature in nnet CHECK_EQ(cached_feats_.size(), 0); - // CHECK_EQ(std::is_same::value, true); + CHECK_EQ((std::is_same::value), true); std::memcpy(feats_ptr, chunk_feats.data(), chunk_feats.size() * sizeof(kaldi::BaseFloat)); @@ -218,7 +234,7 @@ void U2Nnet::ForwardEncoderChunkImpl( VLOG(3) << "feats shape: " << feats.shape()[0] << ", " << feats.shape()[1] << ", " << feats.shape()[2]; -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("feat", std::ios_base::app | std::ios_base::out); path << offset_; @@ -237,7 +253,7 @@ void U2Nnet::ForwardEncoderChunkImpl( #endif // Endocer chunk forward -#ifdef USE_GPU +#ifdef WITH_GPU feats = 
feats.copy_to(paddle::GPUPlace(), /*blocking*/ false); att_cache_ = att_cache_.copy_to(paddle::GPUPlace()), /*blocking*/ false; cnn_cache_ = cnn_cache_.copy_to(Paddle::GPUPlace(), /*blocking*/ false); @@ -254,7 +270,7 @@ void U2Nnet::ForwardEncoderChunkImpl( std::vector outputs = forward_encoder_chunk_(inputs); CHECK_EQ(outputs.size(), 3); -#ifdef USE_GPU +#ifdef WITH_GPU paddle::Tensor chunk_out = outputs[0].copy_to(paddle::CPUPlace()); att_cache_ = outputs[1].copy_to(paddle::CPUPlace()); cnn_cache_ = outputs[2].copy_to(paddle::CPUPlace()); @@ -264,7 +280,7 @@ void U2Nnet::ForwardEncoderChunkImpl( cnn_cache_ = outputs[2]; #endif -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("encoder_logits", std::ios_base::app | std::ios_base::out); @@ -294,7 +310,7 @@ void U2Nnet::ForwardEncoderChunkImpl( encoder_outs_.push_back(chunk_out); VLOG(2) << "encoder_outs_ size: " << encoder_outs_.size(); -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("encoder_logits_list", std::ios_base::app | std::ios_base::out); @@ -313,7 +329,7 @@ void U2Nnet::ForwardEncoderChunkImpl( } #endif // end TEST_DEBUG -#ifdef USE_GPU +#ifdef WITH_GPU #error "Not implementation." 
@@ -327,7 +343,7 @@ void U2Nnet::ForwardEncoderChunkImpl( CHECK_EQ(outputs.size(), 1); paddle::Tensor ctc_log_probs = outputs[0]; -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("encoder_logprob", std::ios_base::app | std::ios_base::out); @@ -349,7 +365,7 @@ void U2Nnet::ForwardEncoderChunkImpl( } #endif // end TEST_DEBUG -#endif // end USE_GPU +#endif // end WITH_GPU // Copy to output, (B=1,T,D) std::vector ctc_log_probs_shape = ctc_log_probs.shape(); @@ -366,7 +382,7 @@ void U2Nnet::ForwardEncoderChunkImpl( std::memcpy( out_prob->data(), ctc_log_probs_ptr, T * D * sizeof(kaldi::BaseFloat)); -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("encoder_logits_list_ctc", std::ios_base::app | std::ios_base::out); @@ -415,7 +431,7 @@ float U2Nnet::ComputePathScore(const paddle::Tensor& prob, void U2Nnet::AttentionRescoring(const std::vector>& hyps, float reverse_weight, std::vector* rescoring_score) { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event("AttentionRescoring", TracerEventType::UserDefined, 1); #endif CHECK(rescoring_score != nullptr); @@ -457,7 +473,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } } -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("encoder_logits_concat", std::ios_base::app | std::ios_base::out); @@ -481,7 +497,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, paddle::Tensor encoder_out = paddle::concat(encoder_outs_, 1); VLOG(2) << "encoder_outs_ size: " << encoder_outs_.size(); -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("encoder_out0", std::ios_base::app | std::ios_base::out); @@ -500,7 +516,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } #endif // end TEST_DEBUG -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("encoder_out", std::ios_base::app | std::ios_base::out); @@ -519,7 +535,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } #endif // end TEST_DEBUG - std::vector inputs{ + std::vector 
inputs{ hyps_tensor, hyps_lens, encoder_out}; std::vector outputs = forward_attention_decoder_(inputs); CHECK_EQ(outputs.size(), 2); @@ -531,7 +547,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, CHECK_EQ(probs_shape[0], num_hyps); CHECK_EQ(probs_shape[1], max_hyps_len); -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("decoder_logprob", std::ios_base::app | std::ios_base::out); @@ -549,7 +565,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } #endif // end TEST_DEBUG -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("hyps_lens", std::ios_base::app | std::ios_base::out); @@ -565,7 +581,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } #endif // end TEST_DEBUG -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("hyps_tensor", std::ios_base::app | std::ios_base::out); @@ -590,7 +606,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } else { // dump r_probs CHECK_EQ(r_probs_shape.size(), 1); - CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0]; + //CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0]; } // compute rescoring score @@ -600,15 +616,15 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, VLOG(2) << "split prob: " << probs_v.size() << " " << probs_v[0].shape().size() << " 0: " << probs_v[0].shape()[0] << ", " << probs_v[0].shape()[1] << ", " << probs_v[0].shape()[2]; - CHECK(static_cast(probs_v.size()) == num_hyps) - << ": is " << probs_v.size() << " expect: " << num_hyps; + //CHECK(static_cast(probs_v.size()) == num_hyps) + // << ": is " << probs_v.size() << " expect: " << num_hyps; std::vector r_probs_v; if (is_bidecoder_ && reverse_weight > 0) { r_probs_v = paddle::experimental::split_with_num(r_probs, num_hyps, 0); - CHECK(static_cast(r_probs_v.size()) == num_hyps) - << "r_probs_v size: is " << r_probs_v.size() - << " expect: " << num_hyps; + //CHECK(static_cast(r_probs_v.size()) == num_hyps) + // << "r_probs_v size: is " << r_probs_v.size() + // << " 
expect: " << num_hyps; } for (int i = 0; i < num_hyps; ++i) { @@ -638,7 +654,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, void U2Nnet::EncoderOuts( - std::vector>* encoder_out) const { + std::vector>* encoder_out) const { // list of (B=1,T,D) int size = encoder_outs_.size(); VLOG(3) << "encoder_outs_ size: " << size; @@ -650,18 +666,18 @@ void U2Nnet::EncoderOuts( const int& B = shape[0]; const int& T = shape[1]; const int& D = shape[2]; - CHECK(B == 1) << "Only support batch one."; + //CHECK(B == 1) << "Only support batch one."; VLOG(3) << "encoder out " << i << " shape: (" << B << "," << T << "," << D << ")"; const float* this_tensor_ptr = item.data(); for (int j = 0; j < T; j++) { const float* cur = this_tensor_ptr + j * D; - kaldi::Vector out(D); - std::memcpy(out.Data(), cur, D * sizeof(kaldi::BaseFloat)); + std::vector out(D); + std::memcpy(out.data(), cur, D * sizeof(kaldi::BaseFloat)); encoder_out->emplace_back(out); } } } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/nnet/u2_nnet.h b/runtime/engine/asr/nnet/u2_nnet.h similarity index 91% rename from speechx/speechx/nnet/u2_nnet.h rename to runtime/engine/asr/nnet/u2_nnet.h index 23cc0ea3..dba5c55e 100644 --- a/speechx/speechx/nnet/u2_nnet.h +++ b/runtime/engine/asr/nnet/u2_nnet.h @@ -18,7 +18,7 @@ #pragma once #include "base/common.h" -#include "kaldi/matrix/kaldi-matrix.h" +#include "matrix/kaldi-matrix.h" #include "nnet/nnet_itf.h" #include "paddle/extension.h" #include "paddle/jit/all.h" @@ -42,7 +42,7 @@ class U2NnetBase : public NnetBase { num_left_chunks_ = num_left_chunks; } - virtual std::shared_ptr Copy() const = 0; + virtual std::shared_ptr Clone() const = 0; protected: virtual void ForwardEncoderChunkImpl( @@ -76,7 +76,7 @@ class U2Nnet : public U2NnetBase { explicit U2Nnet(const ModelOptions& opts); U2Nnet(const U2Nnet& other); - void FeedForward(const kaldi::Vector& features, + void FeedForward(const std::vector& 
features, const int32& feature_dim, NnetOut* out) override; @@ -91,7 +91,7 @@ class U2Nnet : public U2NnetBase { std::shared_ptr model() const { return model_; } - std::shared_ptr Copy() const override; + std::shared_ptr Clone() const override; void ForwardEncoderChunkImpl( const std::vector& chunk_feats, @@ -111,10 +111,10 @@ class U2Nnet : public U2NnetBase { void FeedEncoderOuts(const paddle::Tensor& encoder_out); void EncoderOuts( - std::vector>* encoder_out) const; + std::vector>* encoder_out) const; + ModelOptions opts_; // hack, fix later private: - ModelOptions opts_; phi::Place dev_; std::shared_ptr model_{nullptr}; @@ -127,6 +127,7 @@ class U2Nnet : public U2NnetBase { paddle::jit::Function forward_encoder_chunk_; paddle::jit::Function forward_attention_decoder_; paddle::jit::Function ctc_activation_; + float cost_time_ = 0.0; }; } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/nnet/u2_nnet_main.cc b/runtime/engine/asr/nnet/u2_nnet_main.cc similarity index 99% rename from speechx/speechx/nnet/u2_nnet_main.cc rename to runtime/engine/asr/nnet/u2_nnet_main.cc index 53fc5554..e60ae7e8 100644 --- a/speechx/speechx/nnet/u2_nnet_main.cc +++ b/runtime/engine/asr/nnet/u2_nnet_main.cc @@ -15,8 +15,8 @@ #include "base/common.h" #include "decoder/param.h" -#include "frontend/audio/assembler.h" -#include "frontend/audio/data_cache.h" +#include "frontend/assembler.h" +#include "frontend/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" #include "nnet/u2_nnet.h" diff --git a/runtime/engine/asr/nnet/u2_nnet_thread_main.cc b/runtime/engine/asr/nnet/u2_nnet_thread_main.cc new file mode 100644 index 00000000..008dbb1e --- /dev/null +++ b/runtime/engine/asr/nnet/u2_nnet_thread_main.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef USE_ONNX + #include "nnet/u2_nnet.h" +#else + #include "nnet/u2_onnx_nnet.h" +#endif +#include "base/common.h" +#include "decoder/param.h" +#include "frontend/feature_pipeline.h" +#include "frontend/wave-reader.h" +#include "kaldi/util/table-types.h" +#include "nnet/decodable.h" +#include "nnet/nnet_producer.h" +#include "nnet/u2_nnet.h" + +DEFINE_string(wav_rspecifier, "", "test wav rspecifier"); +DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); + +using kaldi::BaseFloat; +using kaldi::Matrix; +using std::vector; + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int32 num_done = 0, num_err = 0; + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + + CHECK_GT(FLAGS_wav_rspecifier.size(), 0); + CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0); + CHECK_GT(FLAGS_model_path.size(), 0); + LOG(INFO) << "input rspecifier: " << FLAGS_wav_rspecifier; + LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier; + LOG(INFO) << "model path: " << FLAGS_model_path; + + kaldi::SequentialTableReader wav_reader( + FLAGS_wav_rspecifier); + kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier); + + ppspeech::ModelOptions 
model_opts = ppspeech::ModelOptions::InitFromFlags(); + ppspeech::FeaturePipelineOptions feature_opts = + ppspeech::FeaturePipelineOptions::InitFromFlags(); + feature_opts.assembler_opts.fill_zero = false; + +#ifndef USE_ONNX + std::shared_ptr nnet(new ppspeech::U2Nnet(model_opts)); +#else + std::shared_ptr nnet(new ppspeech::U2OnnxNnet(model_opts)); +#endif + std::shared_ptr feature_pipeline( + new ppspeech::FeaturePipeline(feature_opts)); + std::shared_ptr nnet_producer( + new ppspeech::NnetProducer(nnet, feature_pipeline)); + kaldi::Timer timer; + float tot_wav_duration = 0; + + for (; !wav_reader.Done(); wav_reader.Next()) { + std::string utt = wav_reader.Key(); + const kaldi::WaveData& wave_data = wav_reader.Value(); + LOG(INFO) << "utt: " << utt; + LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; + double dur = wave_data.Duration(); + tot_wav_duration += dur; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + kaldi::Timer timer; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + std::vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk[i] = waveform(sample_offset + i); + } + + nnet_producer->Accept(wav_chunk); + if (cur_chunk_size < chunk_sample_size) { + nnet_producer->SetInputFinished(); + } + + // no overlap + sample_offset += cur_chunk_size; + } + CHECK(sample_offset == tot_samples); + + std::vector> prob_vec; + while (1) { + std::vector logprobs; + bool isok = nnet_producer->Read(&logprobs); + if (nnet_producer->IsFinished()) break; + if (isok == false) continue; + prob_vec.push_back(logprobs); + } + { + // writer nnet output + kaldi::MatrixIndexT nrow = prob_vec.size(); + kaldi::MatrixIndexT ncol = prob_vec[0].size(); + LOG(INFO) << "nnet out shape: " << nrow << ", " << 
ncol; + kaldi::Matrix nnet_out(nrow, ncol); + for (int32 row_idx = 0; row_idx < nrow; ++row_idx) { + for (int32 col_idx = 0; col_idx < ncol; ++col_idx) { + nnet_out(row_idx, col_idx) = prob_vec[row_idx][col_idx]; + } + } + nnet_out_writer.Write(utt, nnet_out); + } + nnet_producer->Reset(); + } + + nnet_producer->Wait(); + double elapsed = timer.Elapsed(); + LOG(INFO) << "Program cost:" << elapsed << " sec"; + + LOG(INFO) << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 0 : 1); +} diff --git a/runtime/engine/asr/nnet/u2_onnx_nnet.cc b/runtime/engine/asr/nnet/u2_onnx_nnet.cc new file mode 100644 index 00000000..d5e2fdb6 --- /dev/null +++ b/runtime/engine/asr/nnet/u2_onnx_nnet.cc @@ -0,0 +1,414 @@ +// Copyright 2022 Horizon Robotics. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// modified from +// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.cc + +#include "nnet/u2_onnx_nnet.h" +#include "common/base/config.h" + +namespace ppspeech { + +void U2OnnxNnet::LoadModel(const std::string& model_dir) { + std::string encoder_onnx_path = model_dir + "/encoder.onnx"; + std::string rescore_onnx_path = model_dir + "/decoder.onnx"; + std::string ctc_onnx_path = model_dir + "/ctc.onnx"; + std::string param_path = model_dir + "/param.onnx"; + // 1. 
Load sessions + try { + encoder_ = std::make_shared(); + ctc_ = std::make_shared(); + rescore_ = std::make_shared(); + fastdeploy::RuntimeOption runtime_option; + runtime_option.UseOrtBackend(); + runtime_option.UseCpu(); + runtime_option.SetCpuThreadNum(1); + runtime_option.SetModelPath(encoder_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX); + assert(encoder_->Init(runtime_option)); + runtime_option.SetModelPath(rescore_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX); + assert(rescore_->Init(runtime_option)); + runtime_option.SetModelPath(ctc_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX); + assert(ctc_->Init(runtime_option)); + } catch (std::exception const& e) { + LOG(ERROR) << "error when load onnx model: " << e.what(); + exit(0); + } + + Config conf(param_path); + encoder_output_size_ = conf.Read("output_size", encoder_output_size_); + num_blocks_ = conf.Read("num_blocks", num_blocks_); + head_ = conf.Read("head", head_); + cnn_module_kernel_ = conf.Read("cnn_module_kernel", cnn_module_kernel_); + subsampling_rate_ = conf.Read("subsampling_rate", subsampling_rate_); + right_context_ = conf.Read("right_context", right_context_); + sos_= conf.Read("sos_symbol", sos_); + eos_= conf.Read("eos_symbol", eos_); + is_bidecoder_= conf.Read("is_bidirectional_decoder", is_bidecoder_); + chunk_size_= conf.Read("chunk_size", chunk_size_); + num_left_chunks_ = conf.Read("left_chunks", num_left_chunks_); + + LOG(INFO) << "Onnx Model Info:"; + LOG(INFO) << "\tencoder_output_size " << encoder_output_size_; + LOG(INFO) << "\tnum_blocks " << num_blocks_; + LOG(INFO) << "\thead " << head_; + LOG(INFO) << "\tcnn_module_kernel " << cnn_module_kernel_; + LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_; + LOG(INFO) << "\tright_context " << right_context_; + LOG(INFO) << "\tsos " << sos_; + LOG(INFO) << "\teos " << eos_; + LOG(INFO) << "\tis bidirectional decoder " << is_bidecoder_; + LOG(INFO) << "\tchunk_size " << chunk_size_; + LOG(INFO) << 
"\tnum_left_chunks " << num_left_chunks_; + + // 3. Read model nodes + LOG(INFO) << "Onnx Encoder:"; + GetInputOutputInfo(encoder_, &encoder_in_names_, &encoder_out_names_); + LOG(INFO) << "Onnx CTC:"; + GetInputOutputInfo(ctc_, &ctc_in_names_, &ctc_out_names_); + LOG(INFO) << "Onnx Rescore:"; + GetInputOutputInfo(rescore_, &rescore_in_names_, &rescore_out_names_); +} + +U2OnnxNnet::U2OnnxNnet(const ModelOptions& opts) : opts_(opts) { + LoadModel(opts_.model_path); +} + +// shallow copy +U2OnnxNnet::U2OnnxNnet(const U2OnnxNnet& other) { + // metadatas + encoder_output_size_ = other.encoder_output_size_; + num_blocks_ = other.num_blocks_; + head_ = other.head_; + cnn_module_kernel_ = other.cnn_module_kernel_; + right_context_ = other.right_context_; + subsampling_rate_ = other.subsampling_rate_; + sos_ = other.sos_; + eos_ = other.eos_; + is_bidecoder_ = other.is_bidecoder_; + chunk_size_ = other.chunk_size_; + num_left_chunks_ = other.num_left_chunks_; + offset_ = other.offset_; + + // session + encoder_ = other.encoder_; + ctc_ = other.ctc_; + rescore_ = other.rescore_; + + // node names + encoder_in_names_ = other.encoder_in_names_; + encoder_out_names_ = other.encoder_out_names_; + ctc_in_names_ = other.ctc_in_names_; + ctc_out_names_ = other.ctc_out_names_; + rescore_in_names_ = other.rescore_in_names_; + rescore_out_names_ = other.rescore_out_names_; +} + +void U2OnnxNnet::GetInputOutputInfo(const std::shared_ptr& runtime, + std::vector* in_names, std::vector* out_names) { + std::vector inputs_info = runtime->GetInputInfos(); + (*in_names).resize(inputs_info.size()); + for (int i = 0; i < inputs_info.size(); ++i){ + fastdeploy::TensorInfo info = inputs_info[i]; + + std::stringstream shape; + for(int j = 0; j < info.shape.size(); ++j){ + shape << info.shape[j]; + shape << " "; + } + LOG(INFO) << "\tInput " << i << " : name=" << info.name << " type=" << info.dtype + << " dims=" << shape.str(); + (*in_names)[i] = info.name; + } + std::vector outputs_info = 
runtime->GetOutputInfos(); + (*out_names).resize(outputs_info.size()); + for (int i = 0; i < outputs_info.size(); ++i){ + fastdeploy::TensorInfo info = outputs_info[i]; + + std::stringstream shape; + for(int j = 0; j < info.shape.size(); ++j){ + shape << info.shape[j]; + shape << " "; + } + LOG(INFO) << "\tOutput " << i << " : name=" << info.name << " type=" << info.dtype + << " dims=" << shape.str(); + (*out_names)[i] = info.name; + } +} + +std::shared_ptr U2OnnxNnet::Clone() const { + auto asr_model = std::make_shared(*this); + // reset inner state for new decoding + asr_model->Reset(); + return asr_model; +} + +void U2OnnxNnet::Reset() { + offset_ = 0; + encoder_outs_.clear(); + cached_feats_.clear(); + // Reset att_cache + if (num_left_chunks_ > 0) { + int required_cache_size = chunk_size_ * num_left_chunks_; + offset_ = required_cache_size; + att_cache_.resize(num_blocks_ * head_ * required_cache_size * + encoder_output_size_ / head_ * 2, + 0.0); + const std::vector att_cache_shape = {num_blocks_, head_, required_cache_size, + encoder_output_size_ / head_ * 2}; + att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data()); + } else { + att_cache_.resize(0, 0.0); + const std::vector att_cache_shape = {num_blocks_, head_, 0, + encoder_output_size_ / head_ * 2}; + att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data()); + } + + // Reset cnn_cache + cnn_cache_.resize( + num_blocks_ * encoder_output_size_ * (cnn_module_kernel_ - 1), 0.0); + const std::vector cnn_cache_shape = {num_blocks_, 1, encoder_output_size_, + cnn_module_kernel_ - 1}; + cnn_cache_ort_.SetExternalData(cnn_cache_shape, fastdeploy::FDDataType::FP32, cnn_cache_.data()); +} + +void U2OnnxNnet::FeedForward(const std::vector& features, + const int32& feature_dim, + NnetOut* out) { + kaldi::Timer timer; + + std::vector ctc_probs; + ForwardEncoderChunkImpl( + features, feature_dim, &out->logprobs, &out->vocab_dim); + VLOG(1) 
<< "FeedForward cost: " << timer.Elapsed() << " sec. " + << features.size() / feature_dim << " frames."; +} + +void U2OnnxNnet::ForwardEncoderChunkImpl( + const std::vector& chunk_feats, + const int32& feat_dim, + std::vector* out_prob, + int32* vocab_dim) { + + // 1. Prepare onnx required data, splice cached_feature_ and chunk_feats + // chunk + int num_frames = chunk_feats.size() / feat_dim; + VLOG(3) << "num_frames: " << num_frames; + VLOG(3) << "feat_dim: " << feat_dim; + const int feature_dim = feat_dim; + std::vector feats; + feats.insert(feats.end(), chunk_feats.begin(), chunk_feats.end()); + fastdeploy::FDTensor feats_ort; + const std::vector feats_shape = {1, num_frames, feature_dim}; + feats_ort.SetExternalData(feats_shape, fastdeploy::FDDataType::FP32, feats.data()); + + // offset + int64_t offset_int64 = static_cast(offset_); + fastdeploy::FDTensor offset_ort; + offset_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &offset_int64); + + // required_cache_size + int64_t required_cache_size = chunk_size_ * num_left_chunks_; + fastdeploy::FDTensor required_cache_size_ort(""); + required_cache_size_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &required_cache_size); + + // att_mask + fastdeploy::FDTensor att_mask_ort; + std::vector att_mask(required_cache_size + chunk_size_, 1); + if (num_left_chunks_ > 0) { + int chunk_idx = offset_ / chunk_size_ - num_left_chunks_; + if (chunk_idx < num_left_chunks_) { + for (int i = 0; i < (num_left_chunks_ - chunk_idx) * chunk_size_; ++i) { + att_mask[i] = 0; + } + } + const std::vector att_mask_shape = {1, 1, required_cache_size + chunk_size_}; + att_mask_ort.SetExternalData(att_mask_shape, fastdeploy::FDDataType::BOOL, reinterpret_cast(att_mask.data())); + } + + // 2. 
Encoder chunk forward + std::vector inputs(encoder_in_names_.size()); + for (int i = 0; i < encoder_in_names_.size(); ++i) { + std::string name = encoder_in_names_[i]; + if (!strcmp(name.data(), "chunk")) { + inputs[i] = std::move(feats_ort); + inputs[i].name = "chunk"; + } else if (!strcmp(name.data(), "offset")) { + inputs[i] = std::move(offset_ort); + inputs[i].name = "offset"; + } else if (!strcmp(name.data(), "required_cache_size")) { + inputs[i] = std::move(required_cache_size_ort); + inputs[i].name = "required_cache_size"; + } else if (!strcmp(name.data(), "att_cache")) { + inputs[i] = std::move(att_cache_ort_); + inputs[i].name = "att_cache"; + } else if (!strcmp(name.data(), "cnn_cache")) { + inputs[i] = std::move(cnn_cache_ort_); + inputs[i].name = "cnn_cache"; + } else if (!strcmp(name.data(), "att_mask")) { + inputs[i] = std::move(att_mask_ort); + inputs[i].name = "att_mask"; + } + } + + std::vector ort_outputs; + assert(encoder_->Infer(inputs, &ort_outputs)); + + offset_ += static_cast(ort_outputs[0].shape[1]); + att_cache_ort_ = std::move(ort_outputs[1]); + cnn_cache_ort_ = std::move(ort_outputs[2]); + + std::vector ctc_inputs; + ctc_inputs.emplace_back(std::move(ort_outputs[0])); + // ctc_inputs[0] = std::move(ort_outputs[0]); + ctc_inputs[0].name = ctc_in_names_[0]; + + std::vector ctc_ort_outputs; + assert(ctc_->Infer(ctc_inputs, &ctc_ort_outputs)); + encoder_outs_.emplace_back(std::move(ctc_inputs[0])); // ***** + + float* logp_data = reinterpret_cast(ctc_ort_outputs[0].Data()); + + // Copy to output, (B=1,T,D) + std::vector ctc_log_probs_shape = ctc_ort_outputs[0].shape; + CHECK_EQ(ctc_log_probs_shape.size(), 3); + int B = ctc_log_probs_shape[0]; + CHECK_EQ(B, 1); + int T = ctc_log_probs_shape[1]; + int D = ctc_log_probs_shape[2]; + *vocab_dim = D; + + out_prob->resize(T * D); + std::memcpy( + out_prob->data(), logp_data, T * D * sizeof(kaldi::BaseFloat)); + return; +} + +float U2OnnxNnet::ComputeAttentionScore(const float* prob, + const 
std::vector& hyp, int eos, + int decode_out_len) { + float score = 0.0f; + for (size_t j = 0; j < hyp.size(); ++j) { + score += *(prob + j * decode_out_len + hyp[j]); + } + score += *(prob + hyp.size() * decode_out_len + eos); + return score; +} + +void U2OnnxNnet::AttentionRescoring(const std::vector>& hyps, + float reverse_weight, + std::vector* rescoring_score) { + CHECK(rescoring_score != nullptr); + int num_hyps = hyps.size(); + rescoring_score->resize(num_hyps, 0.0f); + + if (num_hyps == 0) { + return; + } + // No encoder output + if (encoder_outs_.size() == 0) { + return; + } + + std::vector hyps_lens; + int max_hyps_len = 0; + for (size_t i = 0; i < num_hyps; ++i) { + int length = hyps[i].size() + 1; + max_hyps_len = std::max(length, max_hyps_len); + hyps_lens.emplace_back(static_cast(length)); + } + + std::vector rescore_input; + int encoder_len = 0; + for (int i = 0; i < encoder_outs_.size(); i++) { + float* encoder_outs_data = reinterpret_cast(encoder_outs_[i].Data()); + for (int j = 0; j < encoder_outs_[i].Numel(); j++) { + rescore_input.emplace_back(encoder_outs_data[j]); + } + encoder_len += encoder_outs_[i].shape[1]; + } + + std::vector hyps_pad; + + for (size_t i = 0; i < num_hyps; ++i) { + const std::vector& hyp = hyps[i]; + hyps_pad.emplace_back(sos_); + size_t j = 0; + for (; j < hyp.size(); ++j) { + hyps_pad.emplace_back(hyp[j]); + } + if (j == max_hyps_len - 1) { + continue; + } + for (; j < max_hyps_len - 1; ++j) { + hyps_pad.emplace_back(0); + } + } + + const std::vector hyps_pad_shape = {num_hyps, max_hyps_len}; + const std::vector hyps_lens_shape = {num_hyps}; + const std::vector decode_input_shape = {1, encoder_len, encoder_output_size_}; + + fastdeploy::FDTensor hyps_pad_tensor_; + hyps_pad_tensor_.SetExternalData(hyps_pad_shape, fastdeploy::FDDataType::INT64, hyps_pad.data()); + fastdeploy::FDTensor hyps_lens_tensor_; + hyps_lens_tensor_.SetExternalData(hyps_lens_shape, fastdeploy::FDDataType::INT64, hyps_lens.data()); + 
fastdeploy::FDTensor decode_input_tensor_; + decode_input_tensor_.SetExternalData(decode_input_shape, fastdeploy::FDDataType::FP32, rescore_input.data()); + + std::vector rescore_inputs(3); + + rescore_inputs[0] = std::move(hyps_pad_tensor_); + rescore_inputs[0].name = rescore_in_names_[0]; + rescore_inputs[1] = std::move(hyps_lens_tensor_); + rescore_inputs[1].name = rescore_in_names_[1]; + rescore_inputs[2] = std::move(decode_input_tensor_); + rescore_inputs[2].name = rescore_in_names_[2]; + + std::vector rescore_outputs; + assert(rescore_->Infer(rescore_inputs, &rescore_outputs)); + + float* decoder_outs_data = reinterpret_cast(rescore_outputs[0].Data()); + float* r_decoder_outs_data = reinterpret_cast(rescore_outputs[1].Data()); + + int decode_out_len = rescore_outputs[0].shape[2]; + + for (size_t i = 0; i < num_hyps; ++i) { + const std::vector& hyp = hyps[i]; + float score = 0.0f; + // left to right decoder score + score = ComputeAttentionScore( + decoder_outs_data + max_hyps_len * decode_out_len * i, hyp, eos_, + decode_out_len); + // Optional: Used for right to left score + float r_score = 0.0f; + if (is_bidecoder_ && reverse_weight > 0) { + std::vector r_hyp(hyp.size()); + std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin()); + // right to left decoder score + r_score = ComputeAttentionScore( + r_decoder_outs_data + max_hyps_len * decode_out_len * i, r_hyp, eos_, + decode_out_len); + } + // combined left-to-right and right-to-left score + (*rescoring_score)[i] = + score * (1 - reverse_weight) + r_score * reverse_weight; + } +} + +void U2OnnxNnet::EncoderOuts( + std::vector>* encoder_out) const { +} + +} //namepace ppspeech \ No newline at end of file diff --git a/runtime/engine/asr/nnet/u2_onnx_nnet.h b/runtime/engine/asr/nnet/u2_onnx_nnet.h new file mode 100644 index 00000000..6e9126b0 --- /dev/null +++ b/runtime/engine/asr/nnet/u2_onnx_nnet.h @@ -0,0 +1,97 @@ +// Copyright 2022 Horizon Robotics. All Rights Reserved. 
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// modified from +// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.h + +#pragma once + +#include "base/common.h" +#include "matrix/kaldi-matrix.h" +#include "nnet/nnet_itf.h" +#include "nnet/u2_nnet.h" + +#include "fastdeploy/runtime.h" + +namespace ppspeech { + +class U2OnnxNnet : public U2NnetBase { + + public: + explicit U2OnnxNnet(const ModelOptions& opts); + U2OnnxNnet(const U2OnnxNnet& other); + + void FeedForward(const std::vector& features, + const int32& feature_dim, + NnetOut* out) override; + + void Reset() override; + + bool IsLogProb() override { return true; } + + void Dim(); + + void LoadModel(const std::string& model_dir); + + std::shared_ptr Clone() const override; + + void ForwardEncoderChunkImpl( + const std::vector& chunk_feats, + const int32& feat_dim, + std::vector* ctc_probs, + int32* vocab_dim) override; + + float ComputeAttentionScore(const float* prob, const std::vector& hyp, + int eos, int decode_out_len); + + void AttentionRescoring(const std::vector>& hyps, + float reverse_weight, + std::vector* rescoring_score) override; + + void EncoderOuts( + std::vector>* encoder_out) const; + + void GetInputOutputInfo(const std::shared_ptr& runtime, + std::vector* in_names, + std::vector* out_names); + private: + ModelOptions opts_; + + int encoder_output_size_ = 0; + int num_blocks_ = 0; + 
int cnn_module_kernel_ = 0; + int head_ = 0; + + // sessions + std::shared_ptr encoder_ = nullptr; + std::shared_ptr rescore_ = nullptr; + std::shared_ptr ctc_ = nullptr; + + + // node names + std::vector encoder_in_names_, encoder_out_names_; + std::vector ctc_in_names_, ctc_out_names_; + std::vector rescore_in_names_, rescore_out_names_; + + // caches + fastdeploy::FDTensor att_cache_ort_; + fastdeploy::FDTensor cnn_cache_ort_; + std::vector encoder_outs_; + + std::vector att_cache_; + std::vector cnn_cache_; +}; + +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/asr/recognizer/CMakeLists.txt b/runtime/engine/asr/recognizer/CMakeLists.txt new file mode 100644 index 00000000..e8c86505 --- /dev/null +++ b/runtime/engine/asr/recognizer/CMakeLists.txt @@ -0,0 +1,26 @@ +set(srcs) + +list(APPEND srcs + recognizer_controller.cc + recognizer_controller_impl.cc + recognizer_instance.cc + recognizer.cc +) + +add_library(recognizer STATIC ${srcs}) +target_link_libraries(recognizer PUBLIC decoder) + +set(TEST_BINS + recognizer_batch_main + recognizer_batch_main2 + recognizer_main +) + +foreach(bin_name IN LISTS TEST_BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) + target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl) +endforeach() diff --git a/runtime/engine/asr/recognizer/recognizer.cc b/runtime/engine/asr/recognizer/recognizer.cc new file mode 100644 index 00000000..3a95bcc8 --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer.cc @@ -0,0 +1,46 @@ +// Copyright (c) 2023 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "recognizer/recognizer.h" +#include "recognizer/recognizer_instance.h" + +bool InitRecognizer(const std::string& model_file, + const std::string& word_symbol_table_file, + const std::string& fst_file, + int num_instance) { + return ppspeech::RecognizerInstance::GetInstance().Init(model_file, + word_symbol_table_file, + fst_file, + num_instance); +} + +int GetRecognizerInstanceId() { + return ppspeech::RecognizerInstance::GetInstance().GetRecognizerInstanceId(); +} + +void InitDecoder(int instance_id) { + return ppspeech::RecognizerInstance::GetInstance().InitDecoder(instance_id); +} + +void AcceptData(const std::vector& waves, int instance_id) { + return ppspeech::RecognizerInstance::GetInstance().Accept(waves, instance_id); +} + +void SetInputFinished(int instance_id) { + return ppspeech::RecognizerInstance::GetInstance().SetInputFinished(instance_id); +} + +std::string GetFinalResult(int instance_id) { + return ppspeech::RecognizerInstance::GetInstance().GetResult(instance_id); +} \ No newline at end of file diff --git a/runtime/engine/asr/recognizer/recognizer.h b/runtime/engine/asr/recognizer/recognizer.h new file mode 100644 index 00000000..bd7fb129 --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer.h @@ -0,0 +1,28 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +bool InitRecognizer(const std::string& model_file, + const std::string& word_symbol_table_file, + const std::string& fst_file, + int num_instance); +int GetRecognizerInstanceId(); +void InitDecoder(int instance_id); +void AcceptData(const std::vector& waves, int instance_id); +void SetInputFinished(int instance_id); +std::string GetFinalResult(int instance_id); \ No newline at end of file diff --git a/runtime/engine/asr/recognizer/recognizer_batch_main.cc b/runtime/engine/asr/recognizer/recognizer_batch_main.cc new file mode 100644 index 00000000..0cc34f26 --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_batch_main.cc @@ -0,0 +1,172 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "common/base/thread_pool.h" +#include "common/utils/file_utils.h" +#include "common/utils/strings.h" +#include "decoder/param.h" +#include "frontend/wave-reader.h" +#include "kaldi/util/table-types.h" +#include "nnet/u2_nnet.h" +#include "recognizer/recognizer_controller.h" + +DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); +DEFINE_int32(njob, 3, "njob"); + +using std::string; +using std::vector; + +void SplitUtt(string wavlist_file, + vector>* uttlists, + vector>* wavlists, + int njob) { + vector wavlist; + wavlists->resize(njob); + uttlists->resize(njob); + ppspeech::ReadFileToVector(wavlist_file, &wavlist); + for (size_t idx = 0; idx < wavlist.size(); ++idx) { + string utt_str = wavlist[idx]; + vector utt_wav = ppspeech::StrSplit(utt_str, " \t"); + LOG(INFO) << utt_wav[0]; + CHECK_EQ(utt_wav.size(), size_t(2)); + uttlists->at(idx % njob).push_back(utt_wav[0]); + wavlists->at(idx % njob).push_back(utt_wav[1]); + } +} + +void recognizer_func(ppspeech::RecognizerController* recognizer_controller, + std::vector wavlist, + std::vector uttlist, + std::vector* results) { + int32 num_done = 0, num_err = 0; + double tot_wav_duration = 0.0; + double tot_attention_rescore_time = 0.0; + double tot_decode_time = 0.0; + int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate; + if (wavlist.empty()) return; + + results->reserve(wavlist.size()); + for (size_t idx = 0; idx < wavlist.size(); ++idx) { + std::string utt = uttlist[idx]; + std::string wav_file = wavlist[idx]; + std::ifstream infile; + infile.open(wav_file, std::ifstream::in); + kaldi::WaveData wave_data; + wave_data.Read(infile); + int32 recog_id = -1; + while (recog_id == -1) { + recog_id = recognizer_controller->GetRecognizerInstanceId(); + } + recognizer_controller->InitDecoder(recog_id); + 
LOG(INFO) << "utt: " << utt; + LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; + double dur = wave_data.Duration(); + tot_wav_duration += dur; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + kaldi::Timer local_timer; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + std::vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk[i] = waveform(sample_offset + i); + } + + recognizer_controller->Accept(wav_chunk, recog_id); + // no overlap + sample_offset += cur_chunk_size; + } + recognizer_controller->SetInputFinished(recog_id); + CHECK(sample_offset == tot_samples); + std::string result = recognizer_controller->GetFinalResult(recog_id); + if (result.empty()) { + // the TokenWriter can not write empty string. + ++num_err; + LOG(INFO) << " the result of " << utt << " is empty"; + result = " "; + } + + tot_decode_time += local_timer.Elapsed(); + LOG(INFO) << utt << " " << result; + LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur + << " cost: " << local_timer.Elapsed(); + + results->push_back(result); + ++num_done; + } + LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); + LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; + LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; +} + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + 
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + int njob = FLAGS_njob; + LOG(INFO) << "sr: " << sample_rate; + LOG(INFO) << "chunk size (s): " << streaming_chunk; + LOG(INFO) << "chunk size (sample): " << chunk_sample_size; + + ppspeech::RecognizerResource resource = + ppspeech::RecognizerResource::InitFromFlags(); + ppspeech::RecognizerController recognizer_controller(njob, resource); + ThreadPool threadpool(njob); + vector> wavlist; + vector> uttlist; + vector> resultlist(njob); + vector> futurelist; + SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob); + for (size_t i = 0; i < njob; ++i) { + std::future f = threadpool.enqueue(recognizer_func, + &recognizer_controller, + wavlist[i], + uttlist[i], + &resultlist[i]); + futurelist.push_back(std::move(f)); + } + + for (size_t i = 0; i < njob; ++i) { + futurelist[i].get(); + } + + for (size_t idx = 0; idx < njob; ++idx) { + for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) { + string utt = uttlist[idx][utt_idx]; + string result = resultlist[idx][utt_idx]; + result_writer.Write(utt, result); + } + } + return 0; +} diff --git a/runtime/engine/asr/recognizer/recognizer_batch_main2.cc b/runtime/engine/asr/recognizer/recognizer_batch_main2.cc new file mode 100644 index 00000000..fc99bf0b --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_batch_main2.cc @@ -0,0 +1,168 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/base/thread_pool.h" +#include "common/utils/file_utils.h" +#include "common/utils/strings.h" +#include "decoder/param.h" +#include "frontend/wave-reader.h" +#include "kaldi/util/table-types.h" +#include "nnet/u2_nnet.h" +#include "recognizer/recognizer.h" + +DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); +DEFINE_int32(njob, 3, "njob"); + +using std::string; +using std::vector; + +void SplitUtt(string wavlist_file, + vector>* uttlists, + vector>* wavlists, + int njob) { + vector wavlist; + wavlists->resize(njob); + uttlists->resize(njob); + ppspeech::ReadFileToVector(wavlist_file, &wavlist); + for (size_t idx = 0; idx < wavlist.size(); ++idx) { + string utt_str = wavlist[idx]; + vector utt_wav = ppspeech::StrSplit(utt_str, " \t"); + LOG(INFO) << utt_wav[0]; + CHECK_EQ(utt_wav.size(), size_t(2)); + uttlists->at(idx % njob).push_back(utt_wav[0]); + wavlists->at(idx % njob).push_back(utt_wav[1]); + } +} + +void recognizer_func(std::vector wavlist, + std::vector uttlist, + std::vector* results) { + int32 num_done = 0, num_err = 0; + double tot_wav_duration = 0.0; + double tot_attention_rescore_time = 0.0; + double tot_decode_time = 0.0; + int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate; + if (wavlist.empty()) return; + + results->reserve(wavlist.size()); + for (size_t idx = 0; idx < wavlist.size(); ++idx) { + std::string utt = uttlist[idx]; + std::string wav_file = wavlist[idx]; + std::ifstream infile; + infile.open(wav_file, std::ifstream::in); + kaldi::WaveData wave_data; + wave_data.Read(infile); + int32 recog_id = -1; + while (recog_id == -1) { + recog_id = GetRecognizerInstanceId(); + } + InitDecoder(recog_id); + LOG(INFO) << 
"utt: " << utt; + LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; + double dur = wave_data.Duration(); + tot_wav_duration += dur; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + kaldi::Timer local_timer; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + std::vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk[i] = waveform(sample_offset + i); + } + + AcceptData(wav_chunk, recog_id); + // no overlap + sample_offset += cur_chunk_size; + } + SetInputFinished(recog_id); + CHECK(sample_offset == tot_samples); + std::string result = GetFinalResult(recog_id); + if (result.empty()) { + // the TokenWriter can not write empty string. + ++num_err; + LOG(INFO) << " the result of " << utt << " is empty"; + result = " "; + } + + tot_decode_time += local_timer.Elapsed(); + LOG(INFO) << utt << " " << result; + LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur + << " cost: " << local_timer.Elapsed(); + + results->push_back(result); + ++num_done; + } + LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); + LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; + LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; +} + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + int njob = 
FLAGS_njob; + LOG(INFO) << "sr: " << sample_rate; + LOG(INFO) << "chunk size (s): " << streaming_chunk; + LOG(INFO) << "chunk size (sample): " << chunk_sample_size; + + InitRecognizer(FLAGS_model_path, FLAGS_word_symbol_table, FLAGS_graph_path, njob); + ThreadPool threadpool(njob); + vector> wavlist; + vector> uttlist; + vector> resultlist(njob); + vector> futurelist; + SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob); + for (size_t i = 0; i < njob; ++i) { + std::future f = threadpool.enqueue(recognizer_func, + wavlist[i], + uttlist[i], + &resultlist[i]); + futurelist.push_back(std::move(f)); + } + + for (size_t i = 0; i < njob; ++i) { + futurelist[i].get(); + } + + for (size_t idx = 0; idx < njob; ++idx) { + for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) { + string utt = uttlist[idx][utt_idx]; + string result = resultlist[idx][utt_idx]; + result_writer.Write(utt, result); + } + } + return 0; +} diff --git a/runtime/engine/asr/recognizer/recognizer_controller.cc b/runtime/engine/asr/recognizer/recognizer_controller.cc new file mode 100644 index 00000000..ef549263 --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_controller.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "recognizer/recognizer_controller.h" +#include "nnet/u2_nnet.h" + +namespace ppspeech { + +RecognizerController::RecognizerController(int num_worker, RecognizerResource resource) { + recognizer_workers.resize(num_worker); + for (size_t i = 0; i < num_worker; ++i) { + recognizer_workers[i].reset(new ppspeech::RecognizerControllerImpl(resource)); + waiting_workers.push(i); + } +} + +int RecognizerController::GetRecognizerInstanceId() { + if (waiting_workers.empty()) { + return -1; + } + int idx = -1; + { + std::unique_lock lock(mutex_); + idx = waiting_workers.front(); + waiting_workers.pop(); + } + return idx; +} + +RecognizerController::~RecognizerController() { + for (size_t i = 0; i < recognizer_workers.size(); ++i) { + recognizer_workers[i]->WaitFinished(); + } +} + +void RecognizerController::InitDecoder(int idx) { + recognizer_workers[idx]->InitDecoder(); +} + +std::string RecognizerController::GetFinalResult(int idx) { + recognizer_workers[idx]->WaitDecoderFinished(); + recognizer_workers[idx]->AttentionRescoring(); + std::string result = recognizer_workers[idx]->GetFinalResult(); + { + std::unique_lock lock(mutex_); + waiting_workers.push(idx); + } + return result; +} + +void RecognizerController::Accept(std::vector data, int idx) { + recognizer_workers[idx]->Accept(data); +} + +void RecognizerController::SetInputFinished(int idx) { + recognizer_workers[idx]->SetInputFinished(); +} + +} diff --git a/runtime/engine/asr/recognizer/recognizer_controller.h b/runtime/engine/asr/recognizer/recognizer_controller.h new file mode 100644 index 00000000..16a8dd13 --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_controller.h @@ -0,0 +1,42 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "recognizer/recognizer_controller_impl.h" + +namespace ppspeech { + +class RecognizerController { + public: + explicit RecognizerController(int num_worker, RecognizerResource resource); + ~RecognizerController(); + int GetRecognizerInstanceId(); + void InitDecoder(int idx); + void Accept(std::vector data, int idx); + void SetInputFinished(int idx); + std::string GetFinalResult(int idx); + + private: + std::queue waiting_workers; + std::mutex mutex_; + std::vector> recognizer_workers; + + DISALLOW_COPY_AND_ASSIGN(RecognizerController); +}; + +} \ No newline at end of file diff --git a/speechx/speechx/recognizer/u2_recognizer.cc b/runtime/engine/asr/recognizer/recognizer_controller_impl.cc similarity index 57% rename from speechx/speechx/recognizer/u2_recognizer.cc rename to runtime/engine/asr/recognizer/recognizer_controller_impl.cc index d1d308eb..cc4d3c78 100644 --- a/speechx/speechx/recognizer/u2_recognizer.cc +++ b/runtime/engine/asr/recognizer/recognizer_controller_impl.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,86 +12,180 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "recognizer/u2_recognizer.h" - -#include "nnet/u2_nnet.h" +#include "recognizer/recognizer_controller_impl.h" +#include "decoder/ctc_prefix_beam_search_decoder.h" +#include "common/utils/strings.h" namespace ppspeech { -using kaldi::BaseFloat; -using kaldi::SubVector; -using kaldi::Vector; -using kaldi::VectorBase; -using std::unique_ptr; -using std::vector; - -U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) - : opts_(resource) { +RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& resource) +: opts_(resource) { + BaseFloat am_scale = resource.acoustic_scale; + BaseFloat blank_threshold = resource.blank_threshold; const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; - feature_pipeline_.reset(new FeaturePipeline(feature_opts)); + std::shared_ptr feature_pipeline( + new FeaturePipeline(feature_opts)); + std::shared_ptr nnet; +#ifndef USE_ONNX + nnet = resource.nnet->Clone(); +#else + if (resource.model_opts.with_onnx_model){ + nnet.reset(new U2OnnxNnet(resource.model_opts)); + } else { + nnet = resource.nnet->Clone(); + } +#endif + nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline, blank_threshold)); + nnet_thread_ = std::thread(RunNnetEvaluation, this); + + decodable_.reset(new Decodable(nnet_producer_, am_scale)); + if (resource.decoder_opts.tlg_decoder_opts.fst_path.empty()) { + LOG(INFO) << "Init PrefixBeamSearch Decoder"; + decoder_ = std::make_unique( + resource.decoder_opts.ctc_prefix_search_opts); + } else { + LOG(INFO) << "Init TLGDecoder"; + decoder_ = std::make_unique( + resource.decoder_opts.tlg_decoder_opts); + } - std::shared_ptr nnet(new U2Nnet(resource.model_opts)); + symbol_table_ = decoder_->WordSymbolTable(); + global_frame_offset_ = 0; + input_finished_ = false; + num_frames_ = 0; + result_.clear(); +} - BaseFloat am_scale = resource.acoustic_scale; - decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale)); 
+RecognizerControllerImpl::~RecognizerControllerImpl() { + WaitFinished(); +} - CHECK_NE(resource.vocab_path, ""); - decoder_.reset(new CTCPrefixBeamSearch( - resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); +void RecognizerControllerImpl::Reset() { + nnet_producer_->Reset(); +} - unit_table_ = decoder_->VocabTable(); - symbol_table_ = unit_table_; +void RecognizerControllerImpl::RunDecoder(RecognizerControllerImpl* me) { + me->RunDecoderInternal(); +} - input_finished_ = false; +void RecognizerControllerImpl::RunDecoderInternal() { + LOG(INFO) << "DecoderInternal begin"; + while (!nnet_producer_->IsFinished()) { + nnet_condition_.notify_one(); + decoder_->AdvanceDecode(decodable_); + } + decoder_->AdvanceDecode(decodable_); + UpdateResult(false); + LOG(INFO) << "DecoderInternal exit"; +} - Reset(); +void RecognizerControllerImpl::WaitDecoderFinished() { + if (decoder_thread_.joinable()) decoder_thread_.join(); } -void U2Recognizer::Reset() { - global_frame_offset_ = 0; - num_frames_ = 0; - result_.clear(); +void RecognizerControllerImpl::RunNnetEvaluation(RecognizerControllerImpl* me) { + me->RunNnetEvaluationInternal(); +} - decodable_->Reset(); - decoder_->Reset(); +void RecognizerControllerImpl::SetInputFinished() { + nnet_producer_->SetInputFinished(); + nnet_condition_.notify_one(); + LOG(INFO) << "Set Input Finished"; } -void U2Recognizer::ResetContinuousDecoding() { - global_frame_offset_ = num_frames_; +void RecognizerControllerImpl::WaitFinished() { + abort_ = true; + LOG(INFO) << "nnet wait finished"; + nnet_condition_.notify_one(); + if (nnet_thread_.joinable()) { + nnet_thread_.join(); + } +} + +void RecognizerControllerImpl::RunNnetEvaluationInternal() { + bool result = false; + LOG(INFO) << "NnetEvaluationInteral begin"; + while (!abort_) { + std::unique_lock lock(nnet_mutex_); + nnet_condition_.wait(lock); + do { + result = nnet_producer_->Compute(); + decoder_condition_.notify_one(); + } while (result); + } + LOG(INFO) << 
"NnetEvaluationInteral exit"; +} + +void RecognizerControllerImpl::Accept(std::vector data) { + nnet_producer_->Accept(data); + nnet_condition_.notify_one(); +} + +void RecognizerControllerImpl::InitDecoder() { + global_frame_offset_ = 0; + input_finished_ = false; num_frames_ = 0; result_.clear(); decodable_->Reset(); decoder_->Reset(); + decoder_thread_ = std::thread(RunDecoder, this); } +void RecognizerControllerImpl::AttentionRescoring() { + decoder_->FinalizeSearch(); + UpdateResult(false); -void U2Recognizer::Accept(const VectorBase& waves) { - kaldi::Timer timer; - feature_pipeline_->Accept(waves); - VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " << waves.Dim() - << " samples."; -} + // No need to do rescoring + if (0.0 == opts_.decoder_opts.rescoring_weight) { + LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!"; + return; + } + LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!"; + // Inputs() returns N-best input ids, which is the basic unit for rescoring + // In CtcPrefixBeamSearch, inputs are the same to outputs + const auto& hypotheses = decoder_->Inputs(); + int num_hyps = hypotheses.size(); + if (num_hyps <= 0) { + return; + } -void U2Recognizer::Decode() { - decoder_->AdvanceDecode(decodable_); - UpdateResult(false); -} + std::vector rescoring_score; + decodable_->AttentionRescoring( + hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score); -void U2Recognizer::Rescoring() { - // Do attention Rescoring - AttentionRescoring(); + // combine ctc score and rescoring score + for (size_t i = 0; i < num_hyps; i++) { + VLOG(3) << "hyp " << i << " rescoring_score: " << rescoring_score[i] + << " ctc_score: " << result_[i].score + << " rescoring_weight: " << opts_.decoder_opts.rescoring_weight + << " ctc_weight: " << opts_.decoder_opts.ctc_weight; + result_[i].score = + opts_.decoder_opts.rescoring_weight * rescoring_score[i] + + opts_.decoder_opts.ctc_weight * result_[i].score; + + VLOG(3) << "hyp: " << result_[0].sentence + << " 
score: " << result_[0].score; + } + + std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc); + VLOG(3) << "result: " << result_[0].sentence + << " score: " << result_[0].score; } -void U2Recognizer::UpdateResult(bool finish) { +std::string RecognizerControllerImpl::GetFinalResult() { return result_[0].sentence; } + +std::string RecognizerControllerImpl::GetPartialResult() { return result_[0].sentence; } + +void RecognizerControllerImpl::UpdateResult(bool finish) { const auto& hypotheses = decoder_->Outputs(); const auto& inputs = decoder_->Inputs(); const auto& likelihood = decoder_->Likelihood(); const auto& times = decoder_->Times(); result_.clear(); - CHECK_EQ(hypotheses.size(), likelihood.size()); + CHECK_EQ(inputs.size(), likelihood.size()); for (size_t i = 0; i < hypotheses.size(); i++) { const std::vector& hypothesis = hypotheses[i]; @@ -99,21 +193,16 @@ void U2Recognizer::UpdateResult(bool finish) { path.score = likelihood[i]; for (size_t j = 0; j < hypothesis.size(); j++) { std::string word = symbol_table_->Find(hypothesis[j]); - // A detailed explanation of this if-else branch can be found in - // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058 - if (decoder_->Type() == kWfstBeamSearch) { - path.sentence += (" " + word); - } else { - path.sentence += (word); - } + path.sentence += (" " + word); } + path.sentence = DelBlank(path.sentence); // TimeStamp is only supported in final result // TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to // various FST operations when building the decoding graph. So here we // use time stamp of the input(e2e model unit), which is more accurate, // and it requires the symbol table of the e2e model used in training. 
- if (unit_table_ != nullptr && finish) { + if (symbol_table_ != nullptr && finish) { int offset = global_frame_offset_ * FrameShiftInMs(); const std::vector& input = inputs[i]; @@ -121,7 +210,7 @@ void U2Recognizer::UpdateResult(bool finish) { CHECK_EQ(input.size(), time_stamp.size()); for (size_t j = 0; j < input.size(); j++) { - std::string word = unit_table_->Find(input[j]); + std::string word = symbol_table_->Find(input[j]); int start = time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0 @@ -163,56 +252,4 @@ void U2Recognizer::UpdateResult(bool finish) { } } -void U2Recognizer::AttentionRescoring() { - decoder_->FinalizeSearch(); - UpdateResult(true); - - // No need to do rescoring - if (0.0 == opts_.decoder_opts.rescoring_weight) { - LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!"; - return; - } - LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!"; - - // Inputs() returns N-best input ids, which is the basic unit for rescoring - // In CtcPrefixBeamSearch, inputs are the same to outputs - const auto& hypotheses = decoder_->Inputs(); - int num_hyps = hypotheses.size(); - if (num_hyps <= 0) { - return; - } - - std::vector rescoring_score; - decodable_->AttentionRescoring( - hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score); - - // combine ctc score and rescoring score - for (size_t i = 0; i < num_hyps; i++) { - VLOG(3) << "hyp " << i << " rescoring_score: " << rescoring_score[i] - << " ctc_score: " << result_[i].score - << " rescoring_weight: " << opts_.decoder_opts.rescoring_weight - << " ctc_weight: " << opts_.decoder_opts.ctc_weight; - result_[i].score = - opts_.decoder_opts.rescoring_weight * rescoring_score[i] + - opts_.decoder_opts.ctc_weight * result_[i].score; - - VLOG(3) << "hyp: " << result_[0].sentence - << " score: " << result_[0].score; - } - - std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc); - VLOG(3) << "result: " << result_[0].sentence - << " score: " << result_[0].score; -} - -std::string 
U2Recognizer::GetFinalResult() { return result_[0].sentence; } - -std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; } - -void U2Recognizer::SetFinished() { - feature_pipeline_->SetFinished(); - input_finished_ = true; -} - - } // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/asr/recognizer/recognizer_controller_impl.h b/runtime/engine/asr/recognizer/recognizer_controller_impl.h new file mode 100644 index 00000000..3ff6faa6 --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_controller_impl.h @@ -0,0 +1,89 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "decoder/common.h" +#include "fst/fstlib.h" +#include "fst/symbol-table.h" +#include "nnet/u2_nnet.h" +#include "nnet/nnet_producer.h" +#ifdef USE_ONNX +#include "nnet/u2_onnx_nnet.h" +#endif +#include "nnet/decodable.h" +#include "recognizer/recognizer_resource.h" + +#include + +namespace ppspeech { + +class RecognizerControllerImpl { + public: + explicit RecognizerControllerImpl(const RecognizerResource& resource); + ~RecognizerControllerImpl(); + void Accept(std::vector data); + void InitDecoder(); + void SetInputFinished(); + std::string GetFinalResult(); + std::string GetPartialResult(); + void Rescoring(); + void Reset(); + void WaitDecoderFinished(); + void WaitFinished(); + void AttentionRescoring(); + bool DecodedSomething() const { + return !result_.empty() && !result_[0].sentence.empty(); + } + int FrameShiftInMs() const { + return 1; //todo + } + + private: + + static void RunNnetEvaluation(RecognizerControllerImpl* me); + void RunNnetEvaluationInternal(); + static void RunDecoder(RecognizerControllerImpl* me); + void RunDecoderInternal(); + void UpdateResult(bool finish = false); + + std::shared_ptr decodable_; + std::unique_ptr decoder_; + std::shared_ptr nnet_producer_; + + // e2e unit symbol table + std::shared_ptr symbol_table_ = nullptr; + std::vector result_; + + RecognizerResource opts_; + bool abort_ = false; + // global decoded frame offset + int global_frame_offset_; + // cur decoded frame num + int num_frames_; + // timestamp gap between words in a sentence + const int time_stamp_gap_ = 100; + bool input_finished_; + + std::mutex nnet_mutex_; + std::mutex decoder_mutex_; + std::condition_variable nnet_condition_; + std::condition_variable decoder_condition_; + std::thread nnet_thread_; + std::thread decoder_thread_; + + DISALLOW_COPY_AND_ASSIGN(RecognizerControllerImpl); +}; + +} diff --git a/runtime/engine/asr/recognizer/recognizer_instance.cc b/runtime/engine/asr/recognizer/recognizer_instance.cc new file mode 
100644 index 00000000..b9019ec4 --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_instance.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "recognizer/recognizer_instance.h" + + +namespace ppspeech { + +RecognizerInstance& RecognizerInstance::GetInstance() { + static RecognizerInstance instance; + return instance; +} + +bool RecognizerInstance::Init(const std::string& model_file, + const std::string& word_symbol_table_file, + const std::string& fst_file, + int num_instance) { + RecognizerResource resource = RecognizerResource::InitFromFlags(); + resource.model_opts.model_path = model_file; + //resource.vocab_path = word_symbol_table_file; + if (!fst_file.empty()) { + resource.decoder_opts.tlg_decoder_opts.fst_path = fst_file; + resource.decoder_opts.tlg_decoder_opts.fst_path = word_symbol_table_file; + } else { + resource.decoder_opts.ctc_prefix_search_opts.word_symbol_table = + word_symbol_table_file; + } + recognizer_controller_ = std::make_unique(num_instance, resource); + return true; +} + +void RecognizerInstance::InitDecoder(int idx) { + recognizer_controller_->InitDecoder(idx); + return; +} + +int RecognizerInstance::GetRecognizerInstanceId() { + return recognizer_controller_->GetRecognizerInstanceId(); +} + +void RecognizerInstance::Accept(const std::vector& waves, int idx) const { + recognizer_controller_->Accept(waves, idx); + return; +} + 
+void RecognizerInstance::SetInputFinished(int idx) const { + recognizer_controller_->SetInputFinished(idx); + return; +} + +std::string RecognizerInstance::GetResult(int idx) const { + return recognizer_controller_->GetFinalResult(idx); +} + +} \ No newline at end of file diff --git a/runtime/engine/asr/recognizer/recognizer_instance.h b/runtime/engine/asr/recognizer/recognizer_instance.h new file mode 100644 index 00000000..ef8f524d --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_instance.h @@ -0,0 +1,42 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "base/common.h" +#include "recognizer/recognizer_controller.h" + +namespace ppspeech { + +class RecognizerInstance { + public: + static RecognizerInstance& GetInstance(); + RecognizerInstance() {} + ~RecognizerInstance() {} + bool Init(const std::string& model_file, + const std::string& word_symbol_table_file, + const std::string& fst_file, + int num_instance); + int GetRecognizerInstanceId(); + void InitDecoder(int idx); + void Accept(const std::vector& waves, int idx) const; + void SetInputFinished(int idx) const; + std::string GetResult(int idx) const; + + private: + std::unique_ptr recognizer_controller_; +}; + + +} // namespace ppspeech diff --git a/speechx/speechx/recognizer/u2_recognizer_main.cc b/runtime/engine/asr/recognizer/recognizer_main.cc similarity index 75% rename from speechx/speechx/recognizer/u2_recognizer_main.cc rename to runtime/engine/asr/recognizer/recognizer_main.cc index d7c58407..99b7b4dd 100644 --- a/speechx/speechx/recognizer/u2_recognizer_main.cc +++ b/runtime/engine/asr/recognizer/recognizer_main.cc @@ -13,9 +13,9 @@ // limitations under the License. 
#include "decoder/param.h" -#include "kaldi/feat/wave-reader.h" +#include "frontend/wave-reader.h" #include "kaldi/util/table-types.h" -#include "recognizer/u2_recognizer.h" +#include "recognizer/recognizer_controller.h" DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); @@ -31,6 +31,7 @@ int main(int argc, char* argv[]) { int32 num_done = 0, num_err = 0; double tot_wav_duration = 0.0; + double tot_attention_rescore_time = 0.0; double tot_decode_time = 0.0; kaldi::SequentialTableReader wav_reader( @@ -44,11 +45,13 @@ int main(int argc, char* argv[]) { LOG(INFO) << "chunk size (s): " << streaming_chunk; LOG(INFO) << "chunk size (sample): " << chunk_sample_size; - ppspeech::U2RecognizerResource resource = - ppspeech::U2RecognizerResource::InitFromFlags(); - ppspeech::U2Recognizer recognizer(resource); + ppspeech::RecognizerResource resource = + ppspeech::RecognizerResource::InitFromFlags(); + std::shared_ptr recognizer_ptr( + new ppspeech::RecognizerControllerImpl(resource)); for (; !wav_reader.Done(); wav_reader.Next()) { + recognizer_ptr->InitDecoder(); std::string utt = wav_reader.Key(); const kaldi::WaveData& wave_data = wav_reader.Value(); LOG(INFO) << "utt: " << utt; @@ -63,45 +66,32 @@ int main(int argc, char* argv[]) { LOG(INFO) << "wav len (sample): " << tot_samples; int sample_offset = 0; - int cnt = 0; - kaldi::Timer timer; kaldi::Timer local_timer; while (sample_offset < tot_samples) { int cur_chunk_size = std::min(chunk_sample_size, tot_samples - sample_offset); - kaldi::Vector wav_chunk(cur_chunk_size); + std::vector wav_chunk(cur_chunk_size); for (int i = 0; i < cur_chunk_size; ++i) { - wav_chunk(i) = waveform(sample_offset + i); + wav_chunk[i] = waveform(sample_offset + i); } - // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); - recognizer.Accept(wav_chunk); - if (cur_chunk_size < chunk_sample_size) { - recognizer.SetFinished(); - } - recognizer.Decode(); - if 
(recognizer.DecodedSomething()) { - LOG(INFO) << "Pratial result: " << cnt << " " - << recognizer.GetPartialResult(); - } + recognizer_ptr->Accept(wav_chunk); // no overlap sample_offset += cur_chunk_size; - cnt++; } CHECK(sample_offset == tot_samples); + recognizer_ptr->SetInputFinished(); + recognizer_ptr->WaitDecoderFinished(); - // second pass decoding - recognizer.Rescoring(); - - tot_decode_time += timer.Elapsed(); - - std::string result = recognizer.GetFinalResult(); - - recognizer.Reset(); + kaldi::Timer timer; + recognizer_ptr->AttentionRescoring(); + float rescore_time = timer.Elapsed(); + tot_attention_rescore_time += rescore_time; + std::string result = recognizer_ptr->GetFinalResult(); if (result.empty()) { // the TokenWriter can not write empty string. ++num_err; @@ -109,17 +99,20 @@ int main(int argc, char* argv[]) { continue; } + tot_decode_time += local_timer.Elapsed(); LOG(INFO) << utt << " " << result; LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur - << " cost: " << local_timer.Elapsed(); + << " cost: " << local_timer.Elapsed() << " rescore:" << rescore_time; result_writer.Write(utt, result); ++num_done; } + recognizer_ptr->WaitFinished(); LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec"; LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; } diff --git a/speechx/speechx/recognizer/u2_recognizer.h b/runtime/engine/asr/recognizer/recognizer_resource.h similarity index 54% rename from speechx/speechx/recognizer/u2_recognizer.h rename to runtime/engine/asr/recognizer/recognizer_resource.h index 25850863..064a5b5b 100644 --- a/speechx/speechx/recognizer/u2_recognizer.h +++ b/runtime/engine/asr/recognizer/recognizer_resource.h @@ -1,27 +1,8 @@ -// Copyright (c) 2022 PaddlePaddle Authors. 
All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - #pragma once -#include "decoder/common.h" #include "decoder/ctc_beam_search_opt.h" -#include "decoder/ctc_prefix_beam_search_decoder.h" -#include "decoder/decoder_itf.h" -#include "frontend/audio/feature_pipeline.h" -#include "fst/fstlib.h" -#include "fst/symbol-table.h" -#include "nnet/decodable.h" +#include "decoder/ctc_tlg_decoder.h" +#include "frontend/feature_pipeline.h" DECLARE_int32(nnet_decoder_chunk); DECLARE_int32(num_left_chunks); @@ -30,9 +11,9 @@ DECLARE_double(rescoring_weight); DECLARE_double(reverse_weight); DECLARE_int32(nbest); DECLARE_int32(blank); - DECLARE_double(acoustic_scale); -DECLARE_string(vocab_path); +DECLARE_double(blank_threshold); +DECLARE_string(word_symbol_table); namespace ppspeech { @@ -59,6 +40,7 @@ struct DecodeOptions { // CtcEndpointConfig ctc_endpoint_opts; CTCBeamSearchOptions ctc_prefix_search_opts{}; + TLGDecoderOptions tlg_decoder_opts{}; static DecodeOptions InitFromFlags() { DecodeOptions decoder_opts; @@ -70,6 +52,11 @@ struct DecodeOptions { decoder_opts.ctc_prefix_search_opts.blank = FLAGS_blank; decoder_opts.ctc_prefix_search_opts.first_beam_size = FLAGS_nbest; decoder_opts.ctc_prefix_search_opts.second_beam_size = FLAGS_nbest; + decoder_opts.ctc_prefix_search_opts.word_symbol_table = + FLAGS_word_symbol_table; + decoder_opts.tlg_decoder_opts = + ppspeech::TLGDecoderOptions::InitFromFlags(); + LOG(INFO) << "chunk_size: " << 
decoder_opts.chunk_size; LOG(INFO) << "num_left_chunks: " << decoder_opts.num_left_chunks; LOG(INFO) << "ctc_weight: " << decoder_opts.ctc_weight; @@ -82,19 +69,20 @@ struct DecodeOptions { } }; -struct U2RecognizerResource { +struct RecognizerResource { + // decodable opt kaldi::BaseFloat acoustic_scale{1.0}; - std::string vocab_path{}; + kaldi::BaseFloat blank_threshold{0.98}; FeaturePipelineOptions feature_pipeline_opts{}; ModelOptions model_opts{}; DecodeOptions decoder_opts{}; + std::shared_ptr nnet; - static U2RecognizerResource InitFromFlags() { - U2RecognizerResource resource; - resource.vocab_path = FLAGS_vocab_path; + static RecognizerResource InitFromFlags() { + RecognizerResource resource; resource.acoustic_scale = FLAGS_acoustic_scale; - LOG(INFO) << "vocab path: " << resource.vocab_path; + resource.blank_threshold = FLAGS_blank_threshold; LOG(INFO) << "acoustic_scale: " << resource.acoustic_scale; resource.feature_pipeline_opts = @@ -104,69 +92,17 @@ struct U2RecognizerResource { << resource.feature_pipeline_opts.assembler_opts.fill_zero; resource.model_opts = ppspeech::ModelOptions::InitFromFlags(); resource.decoder_opts = ppspeech::DecodeOptions::InitFromFlags(); + #ifndef USE_ONNX + resource.nnet.reset(new U2Nnet(resource.model_opts)); + #else + if (resource.model_opts.with_onnx_model){ + resource.nnet.reset(new U2OnnxNnet(resource.model_opts)); + } else { + resource.nnet.reset(new U2Nnet(resource.model_opts)); + } + #endif return resource; } }; - -class U2Recognizer { - public: - explicit U2Recognizer(const U2RecognizerResource& resouce); - void Reset(); - void ResetContinuousDecoding(); - - void Accept(const kaldi::VectorBase& waves); - void Decode(); - void Rescoring(); - - - std::string GetFinalResult(); - std::string GetPartialResult(); - - void SetFinished(); - bool IsFinished() { return input_finished_; } - - bool DecodedSomething() const { - return !result_.empty() && !result_[0].sentence.empty(); - } - - - int FrameShiftInMs() const { - // 
one decoder frame length in ms - return decodable_->Nnet()->SubsamplingRate() * - feature_pipeline_->FrameShift(); - } - - - const std::vector& Result() const { return result_; } - - private: - void AttentionRescoring(); - void UpdateResult(bool finish = false); - - private: - U2RecognizerResource opts_; - - // std::shared_ptr resource_; - // U2RecognizerResource resource_; - std::shared_ptr feature_pipeline_; - std::shared_ptr decodable_; - std::unique_ptr decoder_; - - // e2e unit symbol table - std::shared_ptr unit_table_ = nullptr; - std::shared_ptr symbol_table_ = nullptr; - - std::vector result_; - - // global decoded frame offset - int global_frame_offset_; - // cur decoded frame num - int num_frames_; - // timestamp gap between words in a sentence - const int time_stamp_gap_ = 100; - - bool input_finished_; -}; - -} // namespace ppspeech \ No newline at end of file +} //namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/asr/server/CMakeLists.txt b/runtime/engine/asr/server/CMakeLists.txt new file mode 100644 index 00000000..566b42ee --- /dev/null +++ b/runtime/engine/asr/server/CMakeLists.txt @@ -0,0 +1 @@ +#add_subdirectory(websocket) diff --git a/speechx/speechx/protocol/websocket/CMakeLists.txt b/runtime/engine/asr/server/websocket/CMakeLists.txt similarity index 98% rename from speechx/speechx/protocol/websocket/CMakeLists.txt rename to runtime/engine/asr/server/websocket/CMakeLists.txt index cafbbec7..9991e47b 100644 --- a/speechx/speechx/protocol/websocket/CMakeLists.txt +++ b/runtime/engine/asr/server/websocket/CMakeLists.txt @@ -10,4 +10,4 @@ target_link_libraries(websocket_server_main PUBLIC fst websocket ${DEPS}) add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc) target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(websocket_client_main PUBLIC fst websocket ${DEPS}) +target_link_libraries(websocket_client_main PUBLIC 
fst websocket ${DEPS}) \ No newline at end of file diff --git a/speechx/speechx/protocol/websocket/websocket_client.cc b/runtime/engine/asr/server/websocket/websocket_client.cc similarity index 100% rename from speechx/speechx/protocol/websocket/websocket_client.cc rename to runtime/engine/asr/server/websocket/websocket_client.cc diff --git a/speechx/speechx/protocol/websocket/websocket_client.h b/runtime/engine/asr/server/websocket/websocket_client.h similarity index 100% rename from speechx/speechx/protocol/websocket/websocket_client.h rename to runtime/engine/asr/server/websocket/websocket_client.h diff --git a/speechx/speechx/protocol/websocket/websocket_client_main.cc b/runtime/engine/asr/server/websocket/websocket_client_main.cc similarity index 100% rename from speechx/speechx/protocol/websocket/websocket_client_main.cc rename to runtime/engine/asr/server/websocket/websocket_client_main.cc diff --git a/speechx/speechx/protocol/websocket/websocket_server.cc b/runtime/engine/asr/server/websocket/websocket_server.cc similarity index 100% rename from speechx/speechx/protocol/websocket/websocket_server.cc rename to runtime/engine/asr/server/websocket/websocket_server.cc diff --git a/speechx/speechx/protocol/websocket/websocket_server.h b/runtime/engine/asr/server/websocket/websocket_server.h similarity index 100% rename from speechx/speechx/protocol/websocket/websocket_server.h rename to runtime/engine/asr/server/websocket/websocket_server.h diff --git a/speechx/speechx/protocol/websocket/websocket_server_main.cc b/runtime/engine/asr/server/websocket/websocket_server_main.cc similarity index 100% rename from speechx/speechx/protocol/websocket/websocket_server_main.cc rename to runtime/engine/asr/server/websocket/websocket_server_main.cc diff --git a/runtime/engine/audio_classification/CMakeLists.txt b/runtime/engine/audio_classification/CMakeLists.txt new file mode 100644 index 00000000..52f1efef --- /dev/null +++ 
b/runtime/engine/audio_classification/CMakeLists.txt @@ -0,0 +1,3 @@ +# add_definitions("-DUSE_PADDLE_INFERENCE_BACKEND") +add_definitions("-DUSE_ORT_BACKEND") +add_subdirectory(nnet) \ No newline at end of file diff --git a/runtime/engine/audio_classification/nnet/CMakeLists.txt b/runtime/engine/audio_classification/nnet/CMakeLists.txt new file mode 100644 index 00000000..bb7f8eec --- /dev/null +++ b/runtime/engine/audio_classification/nnet/CMakeLists.txt @@ -0,0 +1,11 @@ +set(srcs + panns_nnet.cc + panns_interface.cc +) + +add_library(cls SHARED ${srcs}) +target_link_libraries(cls PRIVATE ${FASTDEPLOY_LIBS} kaldi-matrix kaldi-base frontend utils ) + +set(bin_name panns_nnet_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_link_libraries(${bin_name} gflags glog cls) diff --git a/runtime/engine/audio_classification/nnet/panns_interface.cc b/runtime/engine/audio_classification/nnet/panns_interface.cc new file mode 100644 index 00000000..d8b6a8b6 --- /dev/null +++ b/runtime/engine/audio_classification/nnet/panns_interface.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "audio_classification/nnet/panns_interface.h" + +#include "audio_classification/nnet/panns_nnet.h" +#include "common/base/config.h" + +namespace ppspeech { + +void* ClsCreateInstance(const char* conf_path) { + Config conf(conf_path); + // cls init + ppspeech::ClsNnetConf cls_nnet_conf; + cls_nnet_conf.wav_normal_ = conf.Read("wav_normal", true); + cls_nnet_conf.wav_normal_type_ = + conf.Read("wav_normal_type", std::string("linear")); + cls_nnet_conf.wav_norm_mul_factor_ = conf.Read("wav_norm_mul_factor", 1.0); + cls_nnet_conf.model_file_path_ = conf.Read("model_path", std::string("")); + cls_nnet_conf.param_file_path_ = conf.Read("param_path", std::string("")); + cls_nnet_conf.dict_file_path_ = conf.Read("dict_path", std::string("")); + cls_nnet_conf.num_cpu_thread_ = conf.Read("num_cpu_thread", 12); + cls_nnet_conf.samp_freq = conf.Read("samp_freq", 32000); + cls_nnet_conf.frame_length_ms = conf.Read("frame_length_ms", 32); + cls_nnet_conf.frame_shift_ms = conf.Read("frame_shift_ms", 10); + cls_nnet_conf.num_bins = conf.Read("num_bins", 64); + cls_nnet_conf.low_freq = conf.Read("low_freq", 50); + cls_nnet_conf.high_freq = conf.Read("high_freq", 14000); + cls_nnet_conf.dither = conf.Read("dither", 0.0); + + ppspeech::ClsNnet* cls_model = new ppspeech::ClsNnet(); + int ret = cls_model->Init(cls_nnet_conf); + return static_cast(cls_model); +} + +int ClsDestroyInstance(void* instance) { + ppspeech::ClsNnet* cls_model = static_cast(instance); + if (cls_model != NULL) { + delete cls_model; + cls_model = NULL; + } + return 0; +} + +int ClsFeedForward(void* instance, + const char* wav_path, + int topk, + char* result, + int result_max_len) { + ppspeech::ClsNnet* cls_model = static_cast(instance); + if (cls_model == NULL) { + printf("instance is null\n"); + return -1; + } + int ret = cls_model->Forward(wav_path, topk, result, result_max_len); + return 0; +} + +int ClsReset(void* instance) { + ppspeech::ClsNnet* cls_model = static_cast(instance); + if (cls_model 
== NULL) { + printf("instance is null\n"); + return -1; + } + cls_model->Reset(); + return 0; +} +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/audio_classification/nnet/panns_interface.h b/runtime/engine/audio_classification/nnet/panns_interface.h new file mode 100644 index 00000000..0d1ce95f --- /dev/null +++ b/runtime/engine/audio_classification/nnet/panns_interface.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace ppspeech { + +void* ClsCreateInstance(const char* conf_path); +int ClsDestroyInstance(void* instance); +int ClsFeedForward(void* instance, + const char* wav_path, + int topk, + char* result, + int result_max_len); +int ClsReset(void* instance); +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/audio_classification/nnet/panns_nnet.cc b/runtime/engine/audio_classification/nnet/panns_nnet.cc new file mode 100644 index 00000000..37ba74f9 --- /dev/null +++ b/runtime/engine/audio_classification/nnet/panns_nnet.cc @@ -0,0 +1,227 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "audio_classification/nnet/panns_nnet.h" +#ifdef WITH_PROFILING +#include "kaldi/base/timer.h" +#endif + +namespace ppspeech { + +ClsNnet::ClsNnet() { + // wav_reader_ = NULL; + runtime_ = NULL; +} + +void ClsNnet::Reset() { + // wav_reader_->Clear(); + ss_.str(""); +} + +int ClsNnet::Init(const ClsNnetConf& conf) { + conf_ = conf; + // init fbank opts + fbank_opts_.frame_opts.samp_freq = conf.samp_freq; + fbank_opts_.frame_opts.frame_length_ms = conf.frame_length_ms; + fbank_opts_.frame_opts.frame_shift_ms = conf.frame_shift_ms; + fbank_opts_.mel_opts.num_bins = conf.num_bins; + fbank_opts_.mel_opts.low_freq = conf.low_freq; + fbank_opts_.mel_opts.high_freq = conf.high_freq; + fbank_opts_.frame_opts.dither = conf.dither; + fbank_opts_.use_log_fbank = false; + + // init dict + if (conf.dict_file_path_ != "") { + ReadFileToVector(conf.dict_file_path_, &dict_); + } + + // init model + fastdeploy::RuntimeOption runtime_option; + +#ifdef USE_PADDLE_INFERENCE_BACKEND + runtime_option.SetModelPath(conf.model_file_path_, + conf.param_file_path_, + fastdeploy::ModelFormat::PADDLE); + runtime_option.UsePaddleInferBackend(); +#elif defined(USE_ORT_BACKEND) + runtime_option.SetModelPath( + conf.model_file_path_, "", fastdeploy::ModelFormat::ONNX); // onnx + runtime_option.UseOrtBackend(); // onnx +#elif defined(USE_PADDLE_LITE_BACKEND) + runtime_option.SetModelPath(conf.model_file_path_, + conf.param_file_path_, + fastdeploy::ModelFormat::PADDLE); + runtime_option.UseLiteBackend(); +#endif + + runtime_option.SetCpuThreadNum(conf.num_cpu_thread_); 
+ // runtime_option.DeletePaddleBackendPass("simplify_with_basic_ops_pass"); + runtime_ = std::unique_ptr(new fastdeploy::Runtime()); + if (!runtime_->Init(runtime_option)) { + std::cerr << "--- Init FastDeploy Runitme Failed! " + << "\n--- Model: " << conf.model_file_path_ << std::endl; + return -1; + } else { + std::cout << "--- Init FastDeploy Runitme Done! " + << "\n--- Model: " << conf.model_file_path_ << std::endl; + } + + Reset(); + return 0; +} + +int ClsNnet::Forward(const char* wav_path, + int topk, + char* result, + int result_max_len) { +#ifdef WITH_PROFILING + kaldi::Timer timer; + timer.Reset(); +#endif + // read wav + std::ifstream infile(wav_path, std::ifstream::in); + kaldi::WaveData wave_data; + wave_data.Read(infile); + int32 this_channel = 0; + kaldi::Matrix wavform_kaldi = wave_data.Data(); + // only get channel 0 + int wavform_len = wavform_kaldi.NumCols(); + std::vector wavform(wavform_kaldi.Data(), + wavform_kaldi.Data() + wavform_len); + WaveformFloatNormal(&wavform); + WaveformNormal(&wavform, + conf_.wav_normal_, + conf_.wav_normal_type_, + conf_.wav_norm_mul_factor_); +#ifdef PPS_DEBUG + { + std::ofstream fp("cls.wavform", std::ios::out); + for (int i = 0; i < wavform.size(); ++i) { + fp << std::setprecision(18) << wavform[i] << " "; + } + fp << "\n"; + } +#endif +#ifdef WITH_PROFILING + printf("wav read consume: %fs\n", timer.Elapsed()); +#endif + +#ifdef WITH_PROFILING + timer.Reset(); +#endif + + std::vector feats; + std::unique_ptr data_source( + new ppspeech::DataCache()); + ppspeech::Fbank fbank(fbank_opts_, std::move(data_source)); + fbank.Accept(wavform); + fbank.SetFinished(); + fbank.Read(&feats); + + int feat_dim = fbank_opts_.mel_opts.num_bins; + int num_frames = feats.size() / feat_dim; + + for (int i = 0; i < num_frames; ++i) { + for (int j = 0; j < feat_dim; ++j) { + feats[i * feat_dim + j] = PowerTodb(feats[i * feat_dim + j]); + } + } +#ifdef PPS_DEBUG + { + std::ofstream fp("cls.feat", std::ios::out); + for (int i = 0; i 
< num_frames; ++i) { + for (int j = 0; j < feat_dim; ++j) { + fp << std::setprecision(18) << feats[i * feat_dim + j] << " "; + } + fp << "\n"; + } + } +#endif +#ifdef WITH_PROFILING + printf("extract fbank consume: %fs\n", timer.Elapsed()); +#endif + + // infer + std::vector model_out; +#ifdef WITH_PROFILING + timer.Reset(); +#endif + ModelForward(feats.data(), num_frames, feat_dim, &model_out); +#ifdef WITH_PROFILING + printf("fast deploy infer consume: %fs\n", timer.Elapsed()); +#endif +#ifdef PPS_DEBUG + { + std::ofstream fp("cls.logits", std::ios::out); + for (int i = 0; i < model_out.size(); ++i) { + fp << std::setprecision(18) << model_out[i] << "\n"; + } + } +#endif + + // construct result str + ss_ << "{"; + GetTopkResult(topk, model_out); + ss_ << "}"; + + if (result_max_len <= ss_.str().size()) { + printf("result_max_len is short than result len\n"); + } + snprintf(result, result_max_len, "%s", ss_.str().c_str()); + return 0; +} + +int ClsNnet::ModelForward(float* features, + const int num_frames, + const int feat_dim, + std::vector* model_out) { + // init input tensor shape + fastdeploy::TensorInfo info = runtime_->GetInputInfo(0); + info.shape = {1, num_frames, feat_dim}; + + std::vector input_tensors(1); + std::vector output_tensors(1); + + input_tensors[0].SetExternalData({1, num_frames, feat_dim}, + fastdeploy::FDDataType::FP32, + static_cast(features)); + + // get input name + input_tensors[0].name = info.name; + + runtime_->Infer(input_tensors, &output_tensors); + + // output_tensors[0].PrintInfo(); + std::vector output_shape = output_tensors[0].Shape(); + model_out->resize(output_shape[0] * output_shape[1]); + memcpy(static_cast(model_out->data()), + output_tensors[0].Data(), + output_shape[0] * output_shape[1] * sizeof(float)); + return 0; +} + +int ClsNnet::GetTopkResult(int k, const std::vector& model_out) { + std::vector values; + std::vector indics; + TopK(model_out, k, &values, &indics); + for (int i = 0; i < k; ++i) { + if (i != 0) { + ss_ 
<< ","; + } + ss_ << "\"" << dict_[indics[i]] << "\":\"" << values[i] << "\""; + } + return 0; +} + +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/audio_classification/nnet/panns_nnet.h b/runtime/engine/audio_classification/nnet/panns_nnet.h new file mode 100644 index 00000000..3a4a5718 --- /dev/null +++ b/runtime/engine/audio_classification/nnet/panns_nnet.h @@ -0,0 +1,74 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "common/frontend/data_cache.h" +#include "common/frontend/fbank.h" +#include "common/frontend/feature-fbank.h" +#include "common/frontend/frontend_itf.h" +#include "common/frontend/wave-reader.h" +#include "common/utils/audio_process.h" +#include "common/utils/file_utils.h" +#include "fastdeploy/runtime.h" +#include "kaldi/util/kaldi-io.h" +#include "kaldi/util/table-types.h" + +namespace ppspeech { +struct ClsNnetConf { + // wav + bool wav_normal_; + std::string wav_normal_type_; + float wav_norm_mul_factor_; + // model + std::string model_file_path_; + std::string param_file_path_; + std::string dict_file_path_; + int num_cpu_thread_; + // fbank + float samp_freq; + float frame_length_ms; + float frame_shift_ms; + int num_bins; + float low_freq; + float high_freq; + float dither; +}; + +class ClsNnet { + public: + ClsNnet(); + int Init(const ClsNnetConf& conf); + int Forward(const char* wav_path, + int topk, + char* result, + int result_max_len); + void Reset(); + + private: + int ModelForward(float* features, + const int num_frames, + const int feat_dim, + std::vector* model_out); + int ModelForwardStream(std::vector* feats); + int GetTopkResult(int k, const std::vector& model_out); + + ClsNnetConf conf_; + knf::FbankOptions fbank_opts_; + std::unique_ptr runtime_; + std::vector dict_; + std::stringstream ss_; +}; + +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/audio_classification/nnet/panns_nnet_main.cc b/runtime/engine/audio_classification/nnet/panns_nnet_main.cc new file mode 100644 index 00000000..b47753f0 --- /dev/null +++ b/runtime/engine/audio_classification/nnet/panns_nnet_main.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gflags/gflags.h" +#include "glog/logging.h" +#include "audio_classification/nnet/panns_interface.h" + +DEFINE_string(conf_path, "", "config path"); +DEFINE_string(scp_path, "", "wav scp path"); +DEFINE_string(topk, "", "print topk results"); + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + CHECK_GT(FLAGS_conf_path.size(), 0); + CHECK_GT(FLAGS_scp_path.size(), 0); + CHECK_GT(FLAGS_topk.size(), 0); + void* instance = ppspeech::ClsCreateInstance(FLAGS_conf_path.c_str()); + int ret = 0; + // read wav + std::ifstream ifs(FLAGS_scp_path); + std::string line = ""; + int topk = std::atoi(FLAGS_topk.c_str()); + while (getline(ifs, line)) { + // read wav + char result[1024] = {0}; + ret = ppspeech::ClsFeedForward( + instance, line.c_str(), topk, result, 1024); + printf("%s %s\n", line.c_str(), result); + ret = ppspeech::ClsReset(instance); + } + ret = ppspeech::ClsDestroyInstance(instance); + return 0; +} diff --git a/runtime/engine/codelab/CMakeLists.txt b/runtime/engine/codelab/CMakeLists.txt new file mode 100644 index 00000000..13aa5efb --- /dev/null +++ b/runtime/engine/codelab/CMakeLists.txt @@ -0,0 +1,6 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +if(ANDROID) +else() #Unix + add_subdirectory(glog) +endif() \ No newline at end of file diff --git a/speechx/speechx/codelab/README.md b/runtime/engine/codelab/README.md similarity 
index 100% rename from speechx/speechx/codelab/README.md rename to runtime/engine/codelab/README.md diff --git a/speechx/speechx/codelab/glog/CMakeLists.txt b/runtime/engine/codelab/glog/CMakeLists.txt similarity index 67% rename from speechx/speechx/codelab/glog/CMakeLists.txt rename to runtime/engine/codelab/glog/CMakeLists.txt index 08a98641..492e33c6 100644 --- a/speechx/speechx/codelab/glog/CMakeLists.txt +++ b/runtime/engine/codelab/glog/CMakeLists.txt @@ -1,8 +1,8 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) add_executable(glog_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_main.cc) -target_link_libraries(glog_main glog) +target_link_libraries(glog_main extern_glog) add_executable(glog_logtostderr_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_main.cc) -target_link_libraries(glog_logtostderr_main glog) +target_link_libraries(glog_logtostderr_main extern_glog) diff --git a/speechx/speechx/codelab/glog/README.md b/runtime/engine/codelab/glog/README.md similarity index 100% rename from speechx/speechx/codelab/glog/README.md rename to runtime/engine/codelab/glog/README.md diff --git a/speechx/speechx/codelab/glog/glog_logtostderr_main.cc b/runtime/engine/codelab/glog/glog_logtostderr_main.cc similarity index 100% rename from speechx/speechx/codelab/glog/glog_logtostderr_main.cc rename to runtime/engine/codelab/glog/glog_logtostderr_main.cc diff --git a/speechx/speechx/codelab/glog/glog_main.cc b/runtime/engine/codelab/glog/glog_main.cc similarity index 100% rename from speechx/speechx/codelab/glog/glog_main.cc rename to runtime/engine/codelab/glog/glog_main.cc diff --git a/runtime/engine/common/CMakeLists.txt b/runtime/engine/common/CMakeLists.txt new file mode 100644 index 00000000..405479ae --- /dev/null +++ b/runtime/engine/common/CMakeLists.txt @@ -0,0 +1,19 @@ +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +${CMAKE_CURRENT_SOURCE_DIR}/../ +) +add_subdirectory(base) +add_subdirectory(utils) +add_subdirectory(matrix) + +include_directories( 
+${CMAKE_CURRENT_SOURCE_DIR}/frontend +) +add_subdirectory(frontend) + +add_library(common INTERFACE) +target_link_libraries(common INTERFACE base utils kaldi-matrix frontend) +install(TARGETS base DESTINATION lib) +install(TARGETS utils DESTINATION lib) +install(TARGETS kaldi-matrix DESTINATION lib) +install(TARGETS frontend DESTINATION lib) \ No newline at end of file diff --git a/runtime/engine/common/base/CMakeLists.txt b/runtime/engine/common/base/CMakeLists.txt new file mode 100644 index 00000000..b17131b5 --- /dev/null +++ b/runtime/engine/common/base/CMakeLists.txt @@ -0,0 +1,43 @@ + + +if(WITH_ASR) + add_compile_options(-DWITH_ASR) + set(PPS_FLAGS_LIB "fst/flags.h") +else() + set(PPS_FLAGS_LIB "gflags/gflags.h") +endif() + +if(ANDROID) + set(PPS_GLOG_LIB "base/log_impl.h") +else() #UNIX + if(WITH_ASR) + set(PPS_GLOG_LIB "fst/log.h") + else() + set(PPS_GLOG_LIB "glog/logging.h") + endif() +endif() + +configure_file( + ${CMAKE_CURRENT_SOURCE_DIR}/flags.h.in + ${CMAKE_CURRENT_SOURCE_DIR}/flags.h @ONLY + ) +message(STATUS "Generated ${CMAKE_CURRENT_SOURCE_DIR}/flags.h") + +configure_file( + ${CMAKE_CURRENT_SOURCE_DIR}/log.h.in + ${CMAKE_CURRENT_SOURCE_DIR}/log.h @ONLY + ) +message(STATUS "Generated ${CMAKE_CURRENT_SOURCE_DIR}/log.h") + + +if(ANDROID) + set(csrc + log_impl.cc + glog_utils.cc + ) + add_library(base ${csrc}) + target_link_libraries(base gflags) +else() # UNIX + set(csrc) + add_library(base INTERFACE) +endif() \ No newline at end of file diff --git a/speechx/speechx/base/basic_types.h b/runtime/engine/common/base/basic_types.h similarity index 100% rename from speechx/speechx/base/basic_types.h rename to runtime/engine/common/base/basic_types.h diff --git a/speechx/speechx/base/common.h b/runtime/engine/common/base/common.h similarity index 93% rename from speechx/speechx/base/common.h rename to runtime/engine/common/base/common.h index 97bff966..b31fc53e 100644 --- a/speechx/speechx/base/common.h +++ b/runtime/engine/common/base/common.h @@ -21,6 
+21,8 @@ #include #include #include +#include +#include #include #include #include @@ -48,4 +50,5 @@ #include "base/log.h" #include "base/macros.h" #include "utils/file_utils.h" -#include "utils/math.h" \ No newline at end of file +#include "utils/math.h" +#include "utils/timer.h" \ No newline at end of file diff --git a/runtime/engine/common/base/config.h b/runtime/engine/common/base/config.h new file mode 100644 index 00000000..c8eae5e2 --- /dev/null +++ b/runtime/engine/common/base/config.h @@ -0,0 +1,343 @@ +// Copyright (c) code is from +// https://blog.csdn.net/huixingshao/article/details/45969887. + +#include +#include +#include +#include +#include +using namespace std; + +#pragma once + +#ifdef _MSC_VER +#pragma region ParseIniFile +#endif + +/* + * \brief Generic configuration Class + * + */ +class Config { + // Data + protected: + std::string m_Delimiter; //!< separator between key and value + std::string m_Comment; //!< separator between value and comments + std::map + m_Contents; //!< extracted keys and values + + typedef std::map::iterator mapi; + typedef std::map::const_iterator mapci; + // Methods + public: + Config(std::string filename, + std::string delimiter = "=", + std::string comment = "#"); + Config(); + template + T Read(const std::string& in_key) const; //!< Search for key and read value + //! or optional default value, call + //! 
as read + template + T Read(const std::string& in_key, const T& in_value) const; + template + bool ReadInto(T* out_var, const std::string& in_key) const; + template + bool ReadInto(T* out_var, + const std::string& in_key, + const T& in_value) const; + bool FileExist(std::string filename); + void ReadFile(std::string filename, + std::string delimiter = "=", + std::string comment = "#"); + + // Check whether key exists in configuration + bool KeyExists(const std::string& in_key) const; + + // Modify keys and values + template + void Add(const std::string& in_key, const T& in_value); + void Remove(const std::string& in_key); + + // Check or change configuration syntax + std::string GetDelimiter() const { return m_Delimiter; } + std::string GetComment() const { return m_Comment; } + std::string SetDelimiter(const std::string& in_s) { + std::string old = m_Delimiter; + m_Delimiter = in_s; + return old; + } + std::string SetComment(const std::string& in_s) { + std::string old = m_Comment; + m_Comment = in_s; + return old; + } + + // Write or read configuration + friend std::ostream& operator<<(std::ostream& os, const Config& cf); + friend std::istream& operator>>(std::istream& is, Config& cf); + + protected: + template + static std::string T_as_string(const T& t); + template + static T string_as_T(const std::string& s); + static void Trim(std::string* inout_s); + + + // Exception types + public: + struct File_not_found { + std::string filename; + explicit File_not_found(const std::string& filename_ = std::string()) + : filename(filename_) {} + }; + struct Key_not_found { // thrown only by T read(key) variant of read() + std::string key; + explicit Key_not_found(const std::string& key_ = std::string()) + : key(key_) {} + }; +}; + +/* static */ +template +std::string Config::T_as_string(const T& t) { + // Convert from a T to a string + // Type T must support << operator + std::ostringstream ost; + ost << t; + return ost.str(); +} + + +/* static */ +template +T 
Config::string_as_T(const std::string& s) { + // Convert from a string to a T + // Type T must support >> operator + T t; + std::istringstream ist(s); + ist >> t; + return t; +} + + +/* static */ +template <> +inline std::string Config::string_as_T(const std::string& s) { + // Convert from a string to a string + // In other words, do nothing + return s; +} + + +/* static */ +template <> +inline bool Config::string_as_T(const std::string& s) { + // Convert from a string to a bool + // Interpret "false", "F", "no", "n", "0" as false + // Interpret "true", "T", "yes", "y", "1", "-1", or anything else as true + bool b = true; + std::string sup = s; + for (std::string::iterator p = sup.begin(); p != sup.end(); ++p) + *p = toupper(*p); // make string all caps + if (sup == std::string("FALSE") || sup == std::string("F") || + sup == std::string("NO") || sup == std::string("N") || + sup == std::string("0") || sup == std::string("NONE")) + b = false; + return b; +} + + +template +T Config::Read(const std::string& key) const { + // Read the value corresponding to key + mapci p = m_Contents.find(key); + if (p == m_Contents.end()) throw Key_not_found(key); + return string_as_T(p->second); +} + + +template +T Config::Read(const std::string& key, const T& value) const { + // Return the value corresponding to key or given default value + // if key is not found + mapci p = m_Contents.find(key); + if (p == m_Contents.end()) { + printf("%s = %s(default)\n", key.c_str(), T_as_string(value).c_str()); + return value; + } else { + printf("%s = %s\n", key.c_str(), T_as_string(p->second).c_str()); + return string_as_T(p->second); + } +} + + +template +bool Config::ReadInto(T* var, const std::string& key) const { + // Get the value corresponding to key and store in var + // Return true if key is found + // Otherwise leave var untouched + mapci p = m_Contents.find(key); + bool found = (p != m_Contents.end()); + if (found) *var = string_as_T(p->second); + return found; +} + + +template +bool 
Config::ReadInto(T* var, const std::string& key, const T& value) const { + // Get the value corresponding to key and store in var + // Return true if key is found + // Otherwise set var to given default + mapci p = m_Contents.find(key); + bool found = (p != m_Contents.end()); + if (found) + *var = string_as_T(p->second); + else + var = value; + return found; +} + + +template +void Config::Add(const std::string& in_key, const T& value) { + // Add a key with given value + std::string v = T_as_string(value); + std::string key = in_key; + Trim(&key); + Trim(&v); + m_Contents[key] = v; + return; +} + +Config::Config(string filename, string delimiter, string comment) + : m_Delimiter(delimiter), m_Comment(comment) { + // Construct a Config, getting keys and values from given file + + std::ifstream in(filename.c_str()); + + if (!in) throw File_not_found(filename); + + in >> (*this); +} + + +Config::Config() : m_Delimiter(string(1, '=')), m_Comment(string(1, '#')) { + // Construct a Config without a file; empty +} + + +bool Config::KeyExists(const string& key) const { + // Indicate whether key is found + mapci p = m_Contents.find(key); + return (p != m_Contents.end()); +} + + +/* static */ +void Config::Trim(string* inout_s) { + // Remove leading and trailing whitespace + static const char whitespace[] = " \n\t\v\r\f"; + inout_s->erase(0, inout_s->find_first_not_of(whitespace)); + inout_s->erase(inout_s->find_last_not_of(whitespace) + 1U); +} + + +std::ostream& operator<<(std::ostream& os, const Config& cf) { + // Save a Config to os + for (Config::mapci p = cf.m_Contents.begin(); p != cf.m_Contents.end(); + ++p) { + os << p->first << " " << cf.m_Delimiter << " "; + os << p->second << std::endl; + } + return os; +} + +void Config::Remove(const string& key) { + // Remove key and its value + m_Contents.erase(m_Contents.find(key)); + return; +} + +std::istream& operator>>(std::istream& is, Config& cf) { + // Load a Config from is + // Read in keys and values, keeping internal 
whitespace + typedef string::size_type pos; + const string& delim = cf.m_Delimiter; // separator + const string& comm = cf.m_Comment; // comment + const pos skip = delim.length(); // length of separator + + string nextline = ""; // might need to read ahead to see where value ends + + while (is || nextline.length() > 0) { + // Read an entire line at a time + string line; + if (nextline.length() > 0) { + line = nextline; // we read ahead; use it now + nextline = ""; + } else { + std::getline(is, line); + } + + // Ignore comments + line = line.substr(0, line.find(comm)); + + // Parse the line if it contains a delimiter + pos delimPos = line.find(delim); + if (delimPos < string::npos) { + // Extract the key + string key = line.substr(0, delimPos); + line.replace(0, delimPos + skip, ""); + + // See if value continues on the next line + // Stop at blank line, next line with a key, end of stream, + // or end of file sentry + bool terminate = false; + while (!terminate && is) { + std::getline(is, nextline); + terminate = true; + + string nlcopy = nextline; + Config::Trim(&nlcopy); + if (nlcopy == "") continue; + + nextline = nextline.substr(0, nextline.find(comm)); + if (nextline.find(delim) != string::npos) continue; + + nlcopy = nextline; + Config::Trim(&nlcopy); + if (nlcopy != "") line += "\n"; + line += nextline; + terminate = false; + } + + // Store key and value + Config::Trim(&key); + Config::Trim(&line); + cf.m_Contents[key] = line; // overwrites if key is repeated + } + } + + return is; +} +bool Config::FileExist(std::string filename) { + bool exist = false; + std::ifstream in(filename.c_str()); + if (in) exist = true; + return exist; +} + +void Config::ReadFile(string filename, string delimiter, string comment) { + m_Delimiter = delimiter; + m_Comment = comment; + std::ifstream in(filename.c_str()); + + if (!in) throw File_not_found(filename); + + in >> (*this); +} + +#ifdef _MSC_VER +#pragma endregion ParseIniFIle +#endif diff --git a/speechx/speechx/base/log.h 
b/runtime/engine/common/base/flags.h.in similarity index 95% rename from speechx/speechx/base/log.h rename to runtime/engine/common/base/flags.h.in index c613b98c..fd265abc 100644 --- a/speechx/speechx/base/log.h +++ b/runtime/engine/common/base/flags.h.in @@ -14,4 +14,4 @@ #pragma once -#include "fst/log.h" +#include "@PPS_FLAGS_LIB@" diff --git a/runtime/engine/common/base/glog_utils.cc b/runtime/engine/common/base/glog_utils.cc new file mode 100644 index 00000000..4ab3c251 --- /dev/null +++ b/runtime/engine/common/base/glog_utils.cc @@ -0,0 +1,12 @@ + +#include "base/glog_utils.h" + +namespace google { +void InitGoogleLogging(const char* name) { + LOG(INFO) << "dummpy InitGoogleLogging."; +} + +void InstallFailureSignalHandler() { + LOG(INFO) << "dummpy InstallFailureSignalHandler."; +} +} // namespace google diff --git a/runtime/engine/common/base/glog_utils.h b/runtime/engine/common/base/glog_utils.h new file mode 100644 index 00000000..9cffcafb --- /dev/null +++ b/runtime/engine/common/base/glog_utils.h @@ -0,0 +1,9 @@ +#pragma once + +#include "base/common.h" + +namespace google { +void InitGoogleLogging(const char* name); + +void InstallFailureSignalHandler(); +} // namespace google \ No newline at end of file diff --git a/speechx/speechx/base/flags.h b/runtime/engine/common/base/log.h.in similarity index 96% rename from speechx/speechx/base/flags.h rename to runtime/engine/common/base/log.h.in index 41df0d45..5d121add 100644 --- a/speechx/speechx/base/flags.h +++ b/runtime/engine/common/base/log.h.in @@ -14,4 +14,4 @@ #pragma once -#include "fst/flags.h" +#include "@PPS_GLOG_LIB@" diff --git a/runtime/engine/common/base/log_impl.cc b/runtime/engine/common/base/log_impl.cc new file mode 100644 index 00000000..d8295590 --- /dev/null +++ b/runtime/engine/common/base/log_impl.cc @@ -0,0 +1,105 @@ +#include "base/log.h" + +DEFINE_int32(logtostderr, 0, "logging to stderr"); + +namespace ppspeech { + +static char __progname[] = "paddlespeech"; + +namespace log { 
+ +std::mutex LogMessage::lock_; +std::string LogMessage::s_debug_logfile_(""); +std::string LogMessage::s_info_logfile_(""); +std::string LogMessage::s_warning_logfile_(""); +std::string LogMessage::s_error_logfile_(""); +std::string LogMessage::s_fatal_logfile_(""); + +void LogMessage::get_curr_proc_info(std::string* pid, std::string* proc_name) { + std::stringstream ss; + ss << getpid(); + ss >> *pid; + *proc_name = ::ppspeech::__progname; +} + +LogMessage::LogMessage(const char* file, + int line, + Severity level, + bool verbose, + bool out_to_file /* = false */) + : level_(level), verbose_(verbose), out_to_file_(out_to_file) { + if (FLAGS_logtostderr == 0) { + stream_ = static_cast(&std::cout); + } else if (FLAGS_logtostderr == 1) { + stream_ = static_cast(&std::cerr); + } else if (out_to_file_) { + // logfile + lock_.lock(); + init(file, line); + } +} + +LogMessage::~LogMessage() { + stream() << std::endl; + + if (out_to_file_) { + lock_.unlock(); + } + + if (verbose_ && level_ == FATAL) { + std::abort(); + } +} + +std::ostream* LogMessage::nullstream() { + thread_local static std::ofstream os; + thread_local static bool flag_set = false; + if (!flag_set) { + os.setstate(std::ios_base::badbit); + flag_set = true; + } + return &os; +} + +void LogMessage::init(const char* file, int line) { + time_t t = time(0); + char tmp[100]; + strftime(tmp, sizeof(tmp), "%Y%m%d-%H%M%S", localtime(&t)); + + if (s_info_logfile_.empty()) { + std::string pid; + std::string proc_name; + get_curr_proc_info(&pid, &proc_name); + + s_debug_logfile_ = + std::string("log." + proc_name + ".log.DEBUG." + tmp + "." + pid); + s_info_logfile_ = + std::string("log." + proc_name + ".log.INFO." + tmp + "." + pid); + s_warning_logfile_ = + std::string("log." + proc_name + ".log.WARNING." + tmp + "." + pid); + s_error_logfile_ = + std::string("log." + proc_name + ".log.ERROR." + tmp + "." + pid); + s_fatal_logfile_ = + std::string("log." + proc_name + ".log.FATAL." + tmp + "." 
+ pid); + } + + thread_local static std::ofstream ofs; + if (level_ == DEBUG) { + ofs.open(s_debug_logfile_.c_str(), std::ios::out | std::ios::app); + } else if (level_ == INFO) { + ofs.open(s_info_logfile_.c_str(), std::ios::out | std::ios::app); + } else if (level_ == WARNING) { + ofs.open(s_warning_logfile_.c_str(), std::ios::out | std::ios::app); + } else if (level_ == ERROR) { + ofs.open(s_error_logfile_.c_str(), std::ios::out | std::ios::app); + } else { + ofs.open(s_fatal_logfile_.c_str(), std::ios::out | std::ios::app); + } + + stream_ = &ofs; + + stream() << tmp << " " << file << " line " << line << "; "; + stream() << std::flush; +} +} // namespace log +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/common/base/log_impl.h b/runtime/engine/common/base/log_impl.h new file mode 100644 index 00000000..fd6cce19 --- /dev/null +++ b/runtime/engine/common/base/log_impl.h @@ -0,0 +1,173 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// modified from https://github.com/Dounm/dlog +// modified form +// https://android.googlesource.com/platform/art/+/806defa/src/logging.h + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "base/common.h" +#include "base/macros.h" +#ifndef WITH_GLOG +#include "base/glog_utils.h" +#endif + +DECLARE_int32(logtostderr); + +namespace ppspeech { + +namespace log { + +enum Severity { + DEBUG, + INFO, + WARNING, + ERROR, + FATAL, + NUM_SEVERITIES, +}; + +class LogMessage { + public: + static void get_curr_proc_info(std::string* pid, std::string* proc_name); + + LogMessage(const char* file, + int line, + Severity level, + bool verbose, + bool out_to_file = false); + + ~LogMessage(); + + std::ostream& stream() { return verbose_ ? *stream_ : *nullstream(); } + + private: + void init(const char* file, int line); + std::ostream* nullstream(); + + private: + std::ostream* stream_; + std::ostream* null_stream_; + Severity level_; + bool verbose_; + bool out_to_file_; + + static std::mutex lock_; // stream write lock + static std::string s_debug_logfile_; + static std::string s_info_logfile_; + static std::string s_warning_logfile_; + static std::string s_error_logfile_; + static std::string s_fatal_logfile_; + + DISALLOW_COPY_AND_ASSIGN(LogMessage); +}; + + +} // namespace log + +} // namespace ppspeech + + +#ifndef PPS_DEBUG +#define DLOG_INFO \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, false) +#define DLOG_WARNING \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, false) +#define DLOG_ERROR \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, false) +#define DLOG_FATAL \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, false) +#else +#define DLOG_INFO \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, true) +#define DLOG_WARNING \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, 
true) +#define DLOG_ERROR \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, true) +#define DLOG_FATAL \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, true) +#endif + + +#define LOG_INFO \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, true) +#define LOG_WARNING \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, true) +#define LOG_ERROR \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, true) +#define LOG_FATAL \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, true) + + +#define LOG_0 LOG_DEBUG +#define LOG_1 LOG_INFO +#define LOG_2 LOG_WARNING +#define LOG_3 LOG_ERROR +#define LOG_4 LOG_FATAL + +#define LOG(level) LOG_##level.stream() + +#define DLOG(level) DLOG_##level.stream() + +#define VLOG(verboselevel) LOG(verboselevel) + +#define CHECK(exp) \ + ppspeech::log::LogMessage( \ + __FILE__, __LINE__, ppspeech::log::FATAL, !(exp)) \ + .stream() \ + << "Check Failed: " #exp + +#define CHECK_EQ(x, y) CHECK((x) == (y)) +#define CHECK_NE(x, y) CHECK((x) != (y)) +#define CHECK_LE(x, y) CHECK((x) <= (y)) +#define CHECK_LT(x, y) CHECK((x) < (y)) +#define CHECK_GE(x, y) CHECK((x) >= (y)) +#define CHECK_GT(x, y) CHECK((x) > (y)) +#ifdef PPS_DEBUG +#define DCHECK(x) CHECK(x) +#define DCHECK_EQ(x, y) CHECK_EQ(x, y) +#define DCHECK_NE(x, y) CHECK_NE(x, y) +#define DCHECK_LE(x, y) CHECK_LE(x, y) +#define DCHECK_LT(x, y) CHECK_LT(x, y) +#define DCHECK_GE(x, y) CHECK_GE(x, y) +#define DCHECK_GT(x, y) CHECK_GT(x, y) +#else +#define DCHECK(condition) \ + while (false) CHECK(condition) +#define DCHECK_EQ(val1, val2) \ + while (false) CHECK_EQ(val1, val2) +#define DCHECK_NE(val1, val2) \ + while (false) CHECK_NE(val1, val2) +#define DCHECK_LE(val1, val2) \ + while (false) CHECK_LE(val1, val2) +#define DCHECK_LT(val1, val2) \ + while (false) CHECK_LT(val1, val2) +#define DCHECK_GE(val1, val2) \ + while (false) CHECK_GE(val1, val2) +#define 
DCHECK_GT(val1, val2) \ + while (false) CHECK_GT(val1, val2) +#define DCHECK_STREQ(str1, str2) \ + while (false) CHECK_STREQ(str1, str2) +#endif \ No newline at end of file diff --git a/speechx/speechx/base/macros.h b/runtime/engine/common/base/macros.h similarity index 100% rename from speechx/speechx/base/macros.h rename to runtime/engine/common/base/macros.h index db989812..e60baf55 100644 --- a/speechx/speechx/base/macros.h +++ b/runtime/engine/common/base/macros.h @@ -17,14 +17,14 @@ #include #include -namespace ppspeech { - #ifndef DISALLOW_COPY_AND_ASSIGN #define DISALLOW_COPY_AND_ASSIGN(TypeName) \ TypeName(const TypeName&) = delete; \ void operator=(const TypeName&) = delete #endif +namespace ppspeech { + // kSpaceSymbol in UTF-8 is: ▁ const char kSpaceSymbo[] = "\xe2\x96\x81"; diff --git a/runtime/engine/common/base/safe_queue.h b/runtime/engine/common/base/safe_queue.h new file mode 100644 index 00000000..25a012af --- /dev/null +++ b/runtime/engine/common/base/safe_queue.h @@ -0,0 +1,71 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/common.h" + +namespace ppspeech { + +template +class SafeQueue { + public: + explicit SafeQueue(size_t capacity = 0); + void push_back(const T& in); + bool pop(T* out); + bool empty() const { return buffer_.empty(); } + size_t size() const { return buffer_.size(); } + void clear(); + + + private: + std::mutex mutex_; + std::condition_variable condition_; + std::deque buffer_; + size_t capacity_; +}; + +template +SafeQueue::SafeQueue(size_t capacity) : capacity_(capacity) {} + +template +void SafeQueue::push_back(const T& in) { + std::unique_lock lock(mutex_); + if (capacity_ > 0 && buffer_.size() == capacity_) { + condition_.wait(lock, [this] { return capacity_ >= buffer_.size(); }); + } + + buffer_.push_back(in); + condition_.notify_one(); +} + +template +bool SafeQueue::pop(T* out) { + if (buffer_.empty()) { + return false; + } + + std::unique_lock lock(mutex_); + condition_.wait(lock, [this] { return buffer_.size() > 0; }); + *out = std::move(buffer_.front()); + buffer_.pop_front(); + condition_.notify_one(); + return true; +} + +template +void SafeQueue::clear() { + std::unique_lock lock(mutex_); + buffer_.clear(); + condition_.notify_one(); +} +} // namespace ppspeech diff --git a/speechx/speechx/frontend/text/CMakeLists.txt b/runtime/engine/common/base/safe_queue_inl.h similarity index 100% rename from speechx/speechx/frontend/text/CMakeLists.txt rename to runtime/engine/common/base/safe_queue_inl.h diff --git a/speechx/speechx/base/thread_pool.h b/runtime/engine/common/base/thread_pool.h similarity index 100% rename from speechx/speechx/base/thread_pool.h rename to runtime/engine/common/base/thread_pool.h diff --git a/runtime/engine/common/frontend/CMakeLists.txt b/runtime/engine/common/frontend/CMakeLists.txt new file mode 100644 index 00000000..0b95b650 --- /dev/null +++ b/runtime/engine/common/frontend/CMakeLists.txt @@ -0,0 +1,31 @@ +add_library(kaldi-native-fbank-core + feature-fbank.cc + feature-functions.cc + feature-window.cc + 
fftsg.c + mel-computations.cc + rfft.cc +) +target_link_libraries(kaldi-native-fbank-core PUBLIC utils base) +target_compile_options(kaldi-native-fbank-core PUBLIC "-fPIC") + +add_library(frontend STATIC + cmvn.cc + audio_cache.cc + feature_cache.cc + feature_pipeline.cc + assembler.cc + wave-reader.cc +) +target_link_libraries(frontend PUBLIC kaldi-native-fbank-core utils base) + +set(BINS + compute_fbank_main +) + +foreach(bin_name IN LISTS BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + # https://github.com/Kitware/CMake/blob/v3.1.0/Modules/FindThreads.cmake#L207 + target_link_libraries(${bin_name} PUBLIC frontend base utils kaldi-util libgflags_nothreads.so Threads::Threads extern_glog) +endforeach() diff --git a/speechx/speechx/frontend/audio/assembler.cc b/runtime/engine/common/frontend/assembler.cc similarity index 75% rename from speechx/speechx/frontend/audio/assembler.cc rename to runtime/engine/common/frontend/assembler.cc index 9d5fc403..ba46e1ca 100644 --- a/speechx/speechx/frontend/audio/assembler.cc +++ b/runtime/engine/common/frontend/assembler.cc @@ -12,14 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "frontend/audio/assembler.h" +#include "frontend/assembler.h" namespace ppspeech { using kaldi::BaseFloat; -using kaldi::Vector; -using kaldi::VectorBase; using std::unique_ptr; +using std::vector; Assembler::Assembler(AssemblerOptions opts, unique_ptr base_extractor) { @@ -33,13 +32,13 @@ Assembler::Assembler(AssemblerOptions opts, dim_ = base_extractor_->Dim(); } -void Assembler::Accept(const kaldi::VectorBase& inputs) { +void Assembler::Accept(const std::vector& inputs) { // read inputs base_extractor_->Accept(inputs); } // pop feature chunk -bool Assembler::Read(kaldi::Vector* feats) { +bool Assembler::Read(std::vector* feats) { kaldi::Timer timer; bool result = Compute(feats); VLOG(1) << "Assembler::Read cost: " << timer.Elapsed() << " sec."; @@ -47,40 +46,37 @@ bool Assembler::Read(kaldi::Vector* feats) { } // read frame by frame from base_feature_extractor_ into cache_ -bool Assembler::Compute(Vector* feats) { +bool Assembler::Compute(vector* feats) { // compute and feed frame by frame while (feature_cache_.size() < frame_chunk_size_) { - Vector feature; + vector feature; bool result = base_extractor_->Read(&feature); - if (result == false || feature.Dim() == 0) { - VLOG(3) << "result: " << result - << " feature dim: " << feature.Dim(); + if (result == false || feature.size() == 0) { + VLOG(1) << "result: " << result + << " feature dim: " << feature.size(); if (IsFinished() == false) { - VLOG(3) << "finished reading feature. cache size: " + VLOG(1) << "finished reading feature. cache size: " << feature_cache_.size(); return false; } else { - VLOG(3) << "break"; + VLOG(1) << "break"; break; } } - - CHECK(feature.Dim() == dim_); feature_cache_.push(feature); - nframes_ += 1; - VLOG(3) << "nframes: " << nframes_; + VLOG(1) << "nframes: " << nframes_; } if (feature_cache_.size() < receptive_filed_length_) { - VLOG(3) << "feature_cache less than receptive_filed_lenght. " + VLOG(3) << "feature_cache less than receptive_filed_length. 
" << feature_cache_.size() << ": " << receptive_filed_length_; return false; } if (fill_zero_) { while (feature_cache_.size() < frame_chunk_size_) { - Vector feature(dim_, kaldi::kSetZero); + vector feature(dim_, kaldi::kSetZero); nframes_ += 1; feature_cache_.push(feature); } @@ -88,16 +84,17 @@ bool Assembler::Compute(Vector* feats) { int32 this_chunk_size = std::min(static_cast(feature_cache_.size()), frame_chunk_size_); - feats->Resize(dim_ * this_chunk_size); + feats->resize(dim_ * this_chunk_size); VLOG(3) << "read " << this_chunk_size << " feat."; int32 counter = 0; while (counter < this_chunk_size) { - Vector& val = feature_cache_.front(); - CHECK(val.Dim() == dim_) << val.Dim(); + vector& val = feature_cache_.front(); + CHECK(val.size() == dim_) << val.size(); int32 start = counter * dim_; - feats->Range(start, dim_).CopyFromVec(val); + std::memcpy( + feats->data() + start, val.data(), val.size() * sizeof(BaseFloat)); if (this_chunk_size - counter <= cache_size_) { feature_cache_.push(val); @@ -115,7 +112,7 @@ bool Assembler::Compute(Vector* feats) { void Assembler::Reset() { - std::queue> empty; + std::queue> empty; std::swap(feature_cache_, empty); nframes_ = 0; base_extractor_->Reset(); diff --git a/speechx/speechx/frontend/audio/assembler.h b/runtime/engine/common/frontend/assembler.h similarity index 86% rename from speechx/speechx/frontend/audio/assembler.h rename to runtime/engine/common/frontend/assembler.h index 72e6f635..9ec28053 100644 --- a/speechx/speechx/frontend/audio/assembler.h +++ b/runtime/engine/common/frontend/assembler.h @@ -15,7 +15,7 @@ #pragma once #include "base/common.h" -#include "frontend/audio/frontend_itf.h" +#include "frontend/frontend_itf.h" namespace ppspeech { @@ -36,10 +36,10 @@ class Assembler : public FrontendInterface { std::unique_ptr base_extractor = NULL); // Feed feats or waves - void Accept(const kaldi::VectorBase& inputs) override; + void Accept(const std::vector& inputs) override; // feats size = num_frames * 
feat_dim - bool Read(kaldi::Vector* feats) override; + bool Read(std::vector* feats) override; // feat dim size_t Dim() const override { return dim_; } @@ -51,7 +51,7 @@ class Assembler : public FrontendInterface { void Reset() override; private: - bool Compute(kaldi::Vector* feats); + bool Compute(std::vector* feats); bool fill_zero_{false}; @@ -60,7 +60,7 @@ class Assembler : public FrontendInterface { int32 frame_chunk_stride_; // stride int32 cache_size_; // window - stride int32 receptive_filed_length_; - std::queue> feature_cache_; + std::queue> feature_cache_; std::unique_ptr base_extractor_; int32 nframes_; // num frame computed diff --git a/speechx/speechx/frontend/audio/audio_cache.cc b/runtime/engine/common/frontend/audio_cache.cc similarity index 63% rename from speechx/speechx/frontend/audio/audio_cache.cc rename to runtime/engine/common/frontend/audio_cache.cc index c6a91f4b..7ff1c4c4 100644 --- a/speechx/speechx/frontend/audio/audio_cache.cc +++ b/runtime/engine/common/frontend/audio_cache.cc @@ -12,15 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "frontend/audio/audio_cache.h" +#include "frontend/audio_cache.h" #include "kaldi/base/timer.h" namespace ppspeech { using kaldi::BaseFloat; -using kaldi::Vector; -using kaldi::VectorBase; +using std::vector; AudioCache::AudioCache(int buffer_size, bool to_float32) : finished_(false), @@ -37,53 +36,39 @@ BaseFloat AudioCache::Convert2PCM32(BaseFloat val) { return val * (1. 
/ std::pow(2.0, 15)); } -void AudioCache::Accept(const VectorBase& waves) { +void AudioCache::Accept(const vector& waves) { kaldi::Timer timer; std::unique_lock lock(mutex_); - while (size_ + waves.Dim() > ring_buffer_.size()) { + while (size_ + waves.size() > ring_buffer_.size()) { ready_feed_condition_.wait(lock); } - for (size_t idx = 0; idx < waves.Dim(); ++idx) { + for (size_t idx = 0; idx < waves.size(); ++idx) { int32 buffer_idx = (idx + offset_ + size_) % ring_buffer_.size(); - ring_buffer_[buffer_idx] = waves(idx); - if (to_float32_) ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx)); + ring_buffer_[buffer_idx] = waves[idx]; + if (to_float32_) ring_buffer_[buffer_idx] = Convert2PCM32(waves[idx]); } - size_ += waves.Dim(); + size_ += waves.size(); VLOG(1) << "AudioCache::Accept cost: " << timer.Elapsed() << " sec. " - << waves.Dim() << " samples."; + << waves.size() << " samples."; } -bool AudioCache::Read(Vector* waves) { +bool AudioCache::Read(vector* waves) { kaldi::Timer timer; - size_t chunk_size = waves->Dim(); + size_t chunk_size = waves->size(); std::unique_lock lock(mutex_); - while (chunk_size > size_) { - // when audio is empty and no more data feed - // ready_read_condition will block in dead lock, - // so replace with timeout_ - // ready_read_condition_.wait(lock); - int32 elapsed = static_cast(timer.Elapsed() * 1000); - if (elapsed > timeout_) { - if (finished_ == true) { - // read last chunk data - break; - } - if (chunk_size > size_) { - return false; - } - } - usleep(100); // sleep 0.1 ms - } - - // read last chunk data if (chunk_size > size_) { - chunk_size = size_; - waves->Resize(chunk_size); + if (finished_ == false) { + return false; + } else { + // read last chunk data + chunk_size = size_; + waves->resize(chunk_size); + } } for (size_t idx = 0; idx < chunk_size; ++idx) { int buff_idx = (offset_ + idx) % ring_buffer_.size(); - waves->Data()[idx] = ring_buffer_[buff_idx]; + waves->at(idx) = ring_buffer_[buff_idx]; } size_ -= 
chunk_size; offset_ = (offset_ + chunk_size) % ring_buffer_.size(); diff --git a/speechx/speechx/frontend/audio/audio_cache.h b/runtime/engine/common/frontend/audio_cache.h similarity index 89% rename from speechx/speechx/frontend/audio/audio_cache.h rename to runtime/engine/common/frontend/audio_cache.h index 4708a6e0..fdc4fdf4 100644 --- a/speechx/speechx/frontend/audio/audio_cache.h +++ b/runtime/engine/common/frontend/audio_cache.h @@ -16,7 +16,7 @@ #pragma once #include "base/common.h" -#include "frontend/audio/frontend_itf.h" +#include "frontend/frontend_itf.h" namespace ppspeech { @@ -26,9 +26,9 @@ class AudioCache : public FrontendInterface { explicit AudioCache(int buffer_size = 1000 * kint16max, bool to_float32 = false); - virtual void Accept(const kaldi::VectorBase& waves); + virtual void Accept(const std::vector& waves); - virtual bool Read(kaldi::Vector* waves); + virtual bool Read(std::vector* waves); // the audio dim is 1, one sample, which is useless, // so we return size_(cache samples) instead. @@ -39,7 +39,7 @@ class AudioCache : public FrontendInterface { finished_ = true; } - virtual bool IsFinished() const { return finished_; } + virtual bool IsFinished() const { return finished_ && (size_ == 0); } void Reset() override { offset_ = 0; diff --git a/runtime/engine/common/frontend/cmvn.cc b/runtime/engine/common/frontend/cmvn.cc new file mode 100644 index 00000000..0f110820 --- /dev/null +++ b/runtime/engine/common/frontend/cmvn.cc @@ -0,0 +1,159 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "frontend/cmvn.h" + +#include "utils/file_utils.h" +#include "utils/picojson.h" + +namespace ppspeech { + +using kaldi::BaseFloat; +using std::unique_ptr; +using std::vector; + + +CMVN::CMVN(std::string cmvn_file, unique_ptr base_extractor) + : var_norm_(true) { + CHECK_NE(cmvn_file, ""); + base_extractor_ = std::move(base_extractor); + ReadCMVNFromJson(cmvn_file); + dim_ = mean_stats_.size() - 1; +} + +void CMVN::ReadCMVNFromJson(std::string cmvn_file) { + std::string json_str = ppspeech::ReadFile2String(cmvn_file); + picojson::value value; + std::string err; + const char* json_end = picojson::parse( + value, json_str.c_str(), json_str.c_str() + json_str.size(), &err); + if (!value.is()) { + LOG(ERROR) << "Input json file format error."; + } + const picojson::value::array& mean_stat = + value.get("mean_stat").get(); + for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) { + mean_stats_.push_back((*it).get()); + } + + const picojson::value::array& var_stat = + value.get("var_stat").get(); + for (auto it = var_stat.begin(); it != var_stat.end(); it++) { + var_stats_.push_back((*it).get()); + } + + kaldi::int32 frame_num = value.get("frame_num").get(); + LOG(INFO) << "nframe: " << frame_num; + mean_stats_.push_back(frame_num); + var_stats_.push_back(0); +} + +void CMVN::Accept(const std::vector& inputs) { + // feed waves/feats to compute feature + base_extractor_->Accept(inputs); + return; +} + +bool CMVN::Read(std::vector* feats) { + // compute feature + if (base_extractor_->Read(feats) == false || feats->size() == 0) { + 
return false; + } + + // appply cmvn + kaldi::Timer timer; + Compute(feats); + VLOG(1) << "CMVN::Read cost: " << timer.Elapsed() << " sec."; + return true; +} + +// feats contain num_frames feature. +void CMVN::Compute(vector* feats) const { + KALDI_ASSERT(feats != NULL); + + if (feats->size() % dim_ != 0) { + LOG(ERROR) << "Dim mismatch: cmvn " << mean_stats_.size() << ',' + << var_stats_.size() - 1 << ", feats " << feats->size() + << 'x'; + } + if (var_stats_.size() == 0 && var_norm_) { + LOG(ERROR) + << "You requested variance normalization but no variance stats_ " + << "are supplied."; + } + + double count = mean_stats_[dim_]; + // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when + // computing an offset and representing it as stats_, we use a count of one. + if (count < 1.0) + LOG(ERROR) << "Insufficient stats_ for cepstral mean and variance " + "normalization: " + << "count = " << count; + + if (!var_norm_) { + vector offset(feats->size()); + vector mean_stats(mean_stats_); + for (size_t i = 0; i < mean_stats.size(); ++i) { + mean_stats[i] /= count; + } + vector mean_stats_apply(feats->size()); + // fill the datat of mean_stats in mean_stats_appy whose dim_ is equal + // with the dim_ of feature. + // the dim_ of feats = dim_ * num_frames; + for (int32 idx = 0; idx < feats->size() / dim_; ++idx) { + std::memcpy(mean_stats_apply.data() + dim_ * idx, + mean_stats.data(), + dim_ * sizeof(double)); + } + for (size_t idx = 0; idx < feats->size(); ++idx) { + feats->at(idx) += offset[idx]; + } + return; + } + // norm(0, d) = mean offset; + // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d). 
+ vector norm0(feats->size()); + vector norm1(feats->size()); + for (int32 d = 0; d < dim_; d++) { + double mean, offset, scale; + mean = mean_stats_[d] / count; + double var = (var_stats_[d] / count) - mean * mean, floor = 1.0e-20; + if (var < floor) { + LOG(WARNING) << "Flooring cepstral variance from " << var << " to " + << floor; + var = floor; + } + scale = 1.0 / sqrt(var); + if (scale != scale || 1 / scale == 0.0) + LOG(ERROR) + << "NaN or infinity in cepstral mean/variance computation"; + offset = -(mean * scale); + for (int32 d_skip = d; d_skip < feats->size();) { + norm0[d_skip] = offset; + norm1[d_skip] = scale; + d_skip = d_skip + dim_; + } + } + // Apply the normalization. + for (size_t idx = 0; idx < feats->size(); ++idx) { + feats->at(idx) *= norm1[idx]; + } + + for (size_t idx = 0; idx < feats->size(); ++idx) { + feats->at(idx) += norm0[idx]; + } +} + +} // namespace ppspeech diff --git a/speechx/speechx/frontend/audio/cmvn.h b/runtime/engine/common/frontend/cmvn.h similarity index 77% rename from speechx/speechx/frontend/audio/cmvn.h rename to runtime/engine/common/frontend/cmvn.h index 50ef5649..c515b6ae 100644 --- a/speechx/speechx/frontend/audio/cmvn.h +++ b/runtime/engine/common/frontend/cmvn.h @@ -15,8 +15,7 @@ #pragma once #include "base/common.h" -#include "frontend/audio/frontend_itf.h" -#include "kaldi/matrix/kaldi-matrix.h" +#include "frontend/frontend_itf.h" #include "kaldi/util/options-itf.h" namespace ppspeech { @@ -25,11 +24,11 @@ class CMVN : public FrontendInterface { public: explicit CMVN(std::string cmvn_file, std::unique_ptr base_extractor); - virtual void Accept(const kaldi::VectorBase& inputs); + virtual void Accept(const std::vector& inputs); // the length of feats = feature_row * feature_dim, // the Matrix is squashed into Vector - virtual bool Read(kaldi::Vector* feats); + virtual bool Read(std::vector* feats); // the dim_ is the feautre dim. 
virtual size_t Dim() const { return dim_; } virtual void SetFinished() { base_extractor_->SetFinished(); } @@ -37,9 +36,10 @@ class CMVN : public FrontendInterface { virtual void Reset() { base_extractor_->Reset(); } private: - void Compute(kaldi::VectorBase* feats) const; - void ApplyCMVN(kaldi::MatrixBase* feats); - kaldi::Matrix stats_; + void ReadCMVNFromJson(std::string cmvn_file); + void Compute(std::vector* feats) const; + std::vector mean_stats_; + std::vector var_stats_; std::unique_ptr base_extractor_; size_t dim_; bool var_norm_; diff --git a/speechx/speechx/frontend/audio/compute_fbank_main.cc b/runtime/engine/common/frontend/compute_fbank_main.cc similarity index 89% rename from speechx/speechx/frontend/audio/compute_fbank_main.cc rename to runtime/engine/common/frontend/compute_fbank_main.cc index e2b54a8a..e022207d 100644 --- a/speechx/speechx/frontend/audio/compute_fbank_main.cc +++ b/runtime/engine/common/frontend/compute_fbank_main.cc @@ -16,13 +16,13 @@ #include "base/flags.h" #include "base/log.h" -#include "frontend/audio/audio_cache.h" -#include "frontend/audio/data_cache.h" -#include "frontend/audio/fbank.h" -#include "frontend/audio/feature_cache.h" -#include "frontend/audio/frontend_itf.h" -#include "frontend/audio/normalizer.h" -#include "kaldi/feat/wave-reader.h" +#include "frontend/audio_cache.h" +#include "frontend/data_cache.h" +#include "frontend/fbank.h" +#include "frontend/feature_cache.h" +#include "frontend/frontend_itf.h" +#include "frontend/normalizer.h" +#include "frontend/wave-reader.h" #include "kaldi/util/kaldi-io.h" #include "kaldi/util/table-types.h" @@ -56,7 +56,7 @@ int main(int argc, char* argv[]) { std::unique_ptr data_source( new ppspeech::AudioCache(3600 * 1600, false)); - kaldi::FbankOptions opt; + knf::FbankOptions opt; opt.frame_opts.frame_length_ms = 25; opt.frame_opts.frame_shift_ms = 10; opt.mel_opts.num_bins = FLAGS_num_bins; @@ -73,8 +73,7 @@ int main(int argc, char* argv[]) { new 
ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank))); // the feature cache output feature chunk by chunk. - ppspeech::FeatureCacheOptions feat_cache_opts; - ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); + ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn)); LOG(INFO) << "fbank: " << true; LOG(INFO) << "feat dim: " << feature_cache.Dim(); @@ -117,9 +116,9 @@ int main(int argc, char* argv[]) { std::min(chunk_sample_size, tot_samples - sample_offset); // get chunk wav - kaldi::Vector wav_chunk(cur_chunk_size); + std::vector wav_chunk(cur_chunk_size); for (int i = 0; i < cur_chunk_size; ++i) { - wav_chunk(i) = waveform(sample_offset + i); + wav_chunk[i] = waveform(sample_offset + i); } // compute feat @@ -131,10 +130,14 @@ int main(int argc, char* argv[]) { } // read feat - kaldi::Vector features; + kaldi::Vector features(feature_cache.Dim()); bool flag = true; do { - flag = feature_cache.Read(&features); + std::vector tmp; + flag = feature_cache.Read(&tmp); + std::memcpy(features.Data(), + tmp.data(), + tmp.size() * sizeof(BaseFloat)); if (flag && features.Dim() != 0) { feats.push_back(features); feature_rows += features.Dim() / feature_cache.Dim(); diff --git a/speechx/speechx/frontend/audio/compute_linear_spectrogram_main.cc b/runtime/engine/common/frontend/compute_linear_spectrogram_main.cc similarity index 100% rename from speechx/speechx/frontend/audio/compute_linear_spectrogram_main.cc rename to runtime/engine/common/frontend/compute_linear_spectrogram_main.cc diff --git a/speechx/speechx/frontend/audio/data_cache.h b/runtime/engine/common/frontend/data_cache.h similarity index 79% rename from speechx/speechx/frontend/audio/data_cache.h rename to runtime/engine/common/frontend/data_cache.h index 5fe5e4fe..7a37adf4 100644 --- a/speechx/speechx/frontend/audio/data_cache.h +++ b/runtime/engine/common/frontend/data_cache.h @@ -15,10 +15,10 @@ #pragma once - #include "base/common.h" -#include "frontend/audio/frontend_itf.h" 
+#include "frontend/frontend_itf.h" +using std::vector; namespace ppspeech { @@ -30,16 +30,16 @@ class DataCache : public FrontendInterface { DataCache() : finished_{false}, dim_{0} {} // accept waves/feats - void Accept(const kaldi::VectorBase& inputs) override { - data_ = inputs; + void Accept(const std::vector& inputs) override { + data_ = std::move(inputs); } - bool Read(kaldi::Vector* feats) override { - if (data_.Dim() == 0) { + bool Read(vector* feats) override { + if (data_.size() == 0) { return false; } - (*feats) = data_; - data_.Resize(0); + (*feats) = std::move(data_); + data_.resize(0); return true; } @@ -53,7 +53,7 @@ class DataCache : public FrontendInterface { } private: - kaldi::Vector data_; + std::vector data_; bool finished_; int32 dim_; diff --git a/speechx/speechx/frontend/audio/db_norm.cc b/runtime/engine/common/frontend/db_norm.cc similarity index 100% rename from speechx/speechx/frontend/audio/db_norm.cc rename to runtime/engine/common/frontend/db_norm.cc diff --git a/speechx/speechx/frontend/audio/db_norm.h b/runtime/engine/common/frontend/db_norm.h similarity index 100% rename from speechx/speechx/frontend/audio/db_norm.h rename to runtime/engine/common/frontend/db_norm.h diff --git a/runtime/engine/common/frontend/fbank.h b/runtime/engine/common/frontend/fbank.h new file mode 100644 index 00000000..4398e72f --- /dev/null +++ b/runtime/engine/common/frontend/fbank.h @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "base/common.h" +#include "frontend/feature-fbank.h" +#include "frontend/feature_common.h" + +namespace ppspeech { + +typedef StreamingFeatureTpl Fbank; + +} // namespace ppspeech diff --git a/runtime/engine/common/frontend/feature-fbank.cc b/runtime/engine/common/frontend/feature-fbank.cc new file mode 100644 index 00000000..2393e153 --- /dev/null +++ b/runtime/engine/common/frontend/feature-fbank.cc @@ -0,0 +1,123 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is copied/modified from kaldi/src/feat/feature-fbank.cc +// +#include "frontend/feature-fbank.h" + +#include + +#include "frontend/feature-functions.h" + +namespace knf { + +static void Sqrt(float *in_out, int32_t n) { + for (int32_t i = 0; i != n; ++i) { + in_out[i] = std::sqrt(in_out[i]); + } +} + +std::ostream &operator<<(std::ostream &os, const FbankOptions &opts) { + os << opts.ToString(); + return os; +} + +FbankComputer::FbankComputer(const FbankOptions &opts) + : opts_(opts), rfft_(opts.frame_opts.PaddedWindowSize()) { + if (opts.energy_floor > 0.0f) { + log_energy_floor_ = logf(opts.energy_floor); + } + + // We'll definitely need the filterbanks info for VTLN warping factor 1.0. 
+ // [note: this call caches it.] + GetMelBanks(1.0f); +} + +FbankComputer::~FbankComputer() { + for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter) + delete iter->second; +} + +const MelBanks *FbankComputer::GetMelBanks(float vtln_warp) { + MelBanks *this_mel_banks = nullptr; + + // std::map::iterator iter = mel_banks_.find(vtln_warp); + auto iter = mel_banks_.find(vtln_warp); + if (iter == mel_banks_.end()) { + this_mel_banks = + new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp); + mel_banks_[vtln_warp] = this_mel_banks; + } else { + this_mel_banks = iter->second; + } + return this_mel_banks; +} + +void FbankComputer::Compute(float signal_raw_log_energy, + float vtln_warp, + std::vector *signal_frame, + float *feature) { + const MelBanks &mel_banks = *(GetMelBanks(vtln_warp)); + + CHECK_EQ(signal_frame->size(), opts_.frame_opts.PaddedWindowSize()); + + // Compute energy after window function (not the raw one). + if (opts_.use_energy && !opts_.raw_energy) { + signal_raw_log_energy = + std::log(std::max(InnerProduct(signal_frame->data(), + signal_frame->data(), + signal_frame->size()), + std::numeric_limits::epsilon())); + } + rfft_.Compute(signal_frame->data()); // signal_frame is modified in-place + ComputePowerSpectrum(signal_frame); + + // Use magnitude instead of power if requested. + if (!opts_.use_power) { + Sqrt(signal_frame->data(), signal_frame->size() / 2 + 1); + } + + int32_t mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0); + + // Its length is opts_.mel_opts.num_bins + float *mel_energies = feature + mel_offset; + + // Sum with mel filter banks over the power spectrum + mel_banks.Compute(signal_frame->data(), mel_energies); + + if (opts_.use_log_fbank) { + // Avoid log of zero (which should be prevented anyway by dithering). 
+ for (int32_t i = 0; i != opts_.mel_opts.num_bins; ++i) { + auto t = std::max(mel_energies[i], + std::numeric_limits::epsilon()); + mel_energies[i] = std::log(t); + } + } + + // Copy energy as first value (or the last, if htk_compat == true). + if (opts_.use_energy) { + if (opts_.energy_floor > 0.0 && + signal_raw_log_energy < log_energy_floor_) { + signal_raw_log_energy = log_energy_floor_; + } + int32_t energy_index = opts_.htk_compat ? opts_.mel_opts.num_bins : 0; + feature[energy_index] = signal_raw_log_energy; + } +} + +} // namespace knf diff --git a/runtime/engine/common/frontend/feature-fbank.h b/runtime/engine/common/frontend/feature-fbank.h new file mode 100644 index 00000000..3dab793f --- /dev/null +++ b/runtime/engine/common/frontend/feature-fbank.h @@ -0,0 +1,138 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// This file is copied/modified from kaldi/src/feat/feature-fbank.h + +#ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_ +#define KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_ + +#include +#include + +#include "frontend/feature-window.h" +#include "frontend/mel-computations.h" +#include "frontend/rfft.h" + +namespace knf { + +struct FbankOptions { + FrameExtractionOptions frame_opts; + MelBanksOptions mel_opts; + // append an extra dimension with energy to the filter banks + bool use_energy = false; + float energy_floor = 0.0f; // active iff use_energy==true + + // If true, compute log_energy before preemphasis and windowing + // If false, compute log_energy after preemphasis ans windowing + bool raw_energy = true; // active iff use_energy==true + + // If true, put energy last (if using energy) + // If false, put energy first + bool htk_compat = false; // active iff use_energy==true + + // if true (default), produce log-filterbank, else linear + bool use_log_fbank = true; + + // if true (default), use power in filterbank + // analysis, else magnitude. + bool use_power = true; + + FbankOptions() { mel_opts.num_bins = 23; } + + std::string ToString() const { + std::ostringstream os; + os << "frame_opts: \n"; + os << frame_opts << "\n"; + os << "\n"; + + os << "mel_opts: \n"; + os << mel_opts << "\n"; + + os << "use_energy: " << use_energy << "\n"; + os << "energy_floor: " << energy_floor << "\n"; + os << "raw_energy: " << raw_energy << "\n"; + os << "htk_compat: " << htk_compat << "\n"; + os << "use_log_fbank: " << use_log_fbank << "\n"; + os << "use_power: " << use_power << "\n"; + return os.str(); + } +}; + +std::ostream &operator<<(std::ostream &os, const FbankOptions &opts); + +class FbankComputer { + public: + using Options = FbankOptions; + + explicit FbankComputer(const FbankOptions &opts); + ~FbankComputer(); + + int32_t Dim() const { + return opts_.mel_opts.num_bins + (opts_.use_energy ? 
1 : 0); + } + + // if true, compute log_energy_pre_window but after dithering and dc removal + bool NeedRawLogEnergy() const { + return opts_.use_energy && opts_.raw_energy; + } + + const FrameExtractionOptions &GetFrameOptions() const { + return opts_.frame_opts; + } + + const FbankOptions &GetOptions() const { return opts_; } + + /** + Function that computes one frame of features from + one frame of signal. + + @param [in] signal_raw_log_energy The log-energy of the frame of the + signal + prior to windowing and pre-emphasis, or + log(numeric_limits::min()), whichever is greater. Must be + ignored by this function if this class returns false from + this->NeedsRawLogEnergy(). + @param [in] vtln_warp The VTLN warping factor that the user wants + to be applied when computing features for this utterance. Will + normally be 1.0, meaning no warping is to be done. The value will + be ignored for feature types that don't support VLTN, such as + spectrogram features. + @param [in] signal_frame One frame of the signal, + as extracted using the function ExtractWindow() using the options + returned by this->GetFrameOptions(). The function will use the + vector as a workspace, which is why it's a non-const pointer. + @param [out] feature Pointer to a vector of size this->Dim(), to which + the computed feature will be written. It should be pre-allocated. + */ + void Compute(float signal_raw_log_energy, + float vtln_warp, + std::vector *signal_frame, + float *feature); + + private: + const MelBanks *GetMelBanks(float vtln_warp); + + FbankOptions opts_; + float log_energy_floor_; + std::map mel_banks_; // float is VTLN coefficient. 
+ Rfft rfft_; +}; + +} // namespace knf + +#endif // KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_ diff --git a/runtime/engine/common/frontend/feature-functions.cc b/runtime/engine/common/frontend/feature-functions.cc new file mode 100644 index 00000000..178c711b --- /dev/null +++ b/runtime/engine/common/frontend/feature-functions.cc @@ -0,0 +1,49 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is copied/modified from kaldi/src/feat/feature-functions.cc + +#include "frontend/feature-functions.h" + +#include +#include + +namespace knf { + +void ComputePowerSpectrum(std::vector *complex_fft) { + int32_t dim = complex_fft->size(); + + // now we have in complex_fft, first half of complex spectrum + // it's stored as [real0, realN/2, real1, im1, real2, im2, ...] + + float *p = complex_fft->data(); + int32_t half_dim = dim / 2; + float first_energy = p[0] * p[0]; + float last_energy = p[1] * p[1]; // handle this special case + + for (int32_t i = 1; i < half_dim; ++i) { + float real = p[i * 2]; + float im = p[i * 2 + 1]; + p[i] = real * real + im * im; + } + p[0] = first_energy; + p[half_dim] = last_energy; // Will actually never be used, and anyway + // if the signal has been bandlimited sensibly this should be zero. 
+} + +} // namespace knf diff --git a/runtime/engine/common/frontend/feature-functions.h b/runtime/engine/common/frontend/feature-functions.h new file mode 100644 index 00000000..852d0612 --- /dev/null +++ b/runtime/engine/common/frontend/feature-functions.h @@ -0,0 +1,38 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is copied/modified from kaldi/src/feat/feature-functions.h +#ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H +#define KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H + +#include +namespace knf { + +// ComputePowerSpectrum converts a complex FFT (as produced by the FFT +// functions in csrc/rfft.h), and converts it into +// a power spectrum. If the complex FFT is a vector of size n (representing +// half of the complex FFT of a real signal of size n, as described there), +// this function computes in the first (n/2) + 1 elements of it, the +// energies of the fft bins from zero to the Nyquist frequency. Contents of the +// remaining (n/2) - 1 elements are undefined at output. 
+ +void ComputePowerSpectrum(std::vector *complex_fft); + +} // namespace knf + +#endif // KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H diff --git a/runtime/engine/common/frontend/feature-window.cc b/runtime/engine/common/frontend/feature-window.cc new file mode 100644 index 00000000..43c736e0 --- /dev/null +++ b/runtime/engine/common/frontend/feature-window.cc @@ -0,0 +1,248 @@ +// kaldi-native-fbank/csrc/feature-window.cc +// +// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + +// This file is copied/modified from kaldi/src/feat/feature-window.cc + +#include "frontend/feature-window.h" + +#include +#include +#include + +#ifndef M_2PI +#define M_2PI 6.283185307179586476925286766559005 +#endif + +namespace knf { + +std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) { + os << opts.ToString(); + return os; +} + +FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts) + : window_(opts.WindowSize()) { + int32_t frame_length = opts.WindowSize(); + CHECK_GT(frame_length, 0); + + float *window_data = window_.data(); + + double a = M_2PI / (frame_length - 1); + for (int32_t i = 0; i < frame_length; i++) { + double i_fl = static_cast(i); + if (opts.window_type == "hanning") { + window_data[i] = 0.5 - 0.5 * cos(a * i_fl); + } else if (opts.window_type == "sine") { + // when you are checking ws wikipedia, please + // note that 0.5 * a = M_PI/(frame_length-1) + window_data[i] = sin(0.5 * a * i_fl); + } else if (opts.window_type == "hamming") { + window_data[i] = 0.54 - 0.46 * cos(a * i_fl); + } else if (opts.window_type == + "povey") { // like hamming but goes to zero at edges. 
+ window_data[i] = pow(0.5 - 0.5 * cos(a * i_fl), 0.85); + } else if (opts.window_type == "rectangular") { + window_data[i] = 1.0; + } else if (opts.window_type == "blackman") { + window_data[i] = opts.blackman_coeff - 0.5 * cos(a * i_fl) + + (0.5 - opts.blackman_coeff) * cos(2 * a * i_fl); + } else { + LOG(FATAL) << "Invalid window type " << opts.window_type; + } + } +} + +void FeatureWindowFunction::Apply(float *wave) const { + int32_t window_size = window_.size(); + const float *p = window_.data(); + for (int32_t k = 0; k != window_size; ++k) { + wave[k] *= p[k]; + } +} + +int64_t FirstSampleOfFrame(int32_t frame, const FrameExtractionOptions &opts) { + int64_t frame_shift = opts.WindowShift(); + if (opts.snip_edges) { + return frame * frame_shift; + } else { + int64_t midpoint_of_frame = frame_shift * frame + frame_shift / 2, + beginning_of_frame = midpoint_of_frame - opts.WindowSize() / 2; + return beginning_of_frame; + } +} + +int32_t NumFrames(int64_t num_samples, + const FrameExtractionOptions &opts, + bool flush /*= true*/) { + int64_t frame_shift = opts.WindowShift(); + int64_t frame_length = opts.WindowSize(); + if (opts.snip_edges) { + // with --snip-edges=true (the default), we use a HTK-like approach to + // determining the number of frames-- all frames have to fit completely + // into + // the waveform, and the first frame begins at sample zero. + if (num_samples < frame_length) + return 0; + else + return (1 + ((num_samples - frame_length) / frame_shift)); + // You can understand the expression above as follows: 'num_samples - + // frame_length' is how much room we have to shift the frame within the + // waveform; 'frame_shift' is how much we shift it each time; and the + // ratio + // is how many times we can shift it (integer arithmetic rounds down). + } else { + // if --snip-edges=false, the number of frames is determined by rounding + // the + // (file-length / frame-shift) to the nearest integer. 
The point of + // this + // formula is to make the number of frames an obvious and predictable + // function of the frame shift and signal length, which makes many + // segmentation-related questions simpler. + // + // Because integer division in C++ rounds toward zero, we add (half the + // frame-shift minus epsilon) before dividing, to have the effect of + // rounding towards the closest integer. + int32_t num_frames = (num_samples + (frame_shift / 2)) / frame_shift; + + if (flush) return num_frames; + + // note: 'end' always means the last plus one, i.e. one past the last. + int64_t end_sample_of_last_frame = + FirstSampleOfFrame(num_frames - 1, opts) + frame_length; + + // the following code is optimized more for clarity than efficiency. + // If flush == false, we can't output frames that extend past the end + // of the signal. + while (num_frames > 0 && end_sample_of_last_frame > num_samples) { + num_frames--; + end_sample_of_last_frame -= frame_shift; + } + return num_frames; + } +} + +void ExtractWindow(int64_t sample_offset, + const std::vector &wave, + int32_t f, + const FrameExtractionOptions &opts, + const FeatureWindowFunction &window_function, + std::vector *window, + float *log_energy_pre_window /*= nullptr*/) { + CHECK(sample_offset >= 0 && wave.size() != 0); + + int32_t frame_length = opts.WindowSize(); + int32_t frame_length_padded = opts.PaddedWindowSize(); + + int64_t num_samples = sample_offset + wave.size(); + int64_t start_sample = FirstSampleOfFrame(f, opts); + int64_t end_sample = start_sample + frame_length; + + if (opts.snip_edges) { + CHECK(start_sample >= sample_offset && end_sample <= num_samples); + } else { + CHECK(sample_offset == 0 || start_sample >= sample_offset); + } + + if (window->size() != frame_length_padded) { + window->resize(frame_length_padded); + } + + // wave_start and wave_end are start and end indexes into 'wave', for the + // piece of wave that we're trying to extract. 
+ int32_t wave_start = int32_t(start_sample - sample_offset); + int32_t wave_end = wave_start + frame_length; + + if (wave_start >= 0 && wave_end <= wave.size()) { + // the normal case-- no edge effects to consider. + std::copy(wave.begin() + wave_start, + wave.begin() + wave_start + frame_length, + window->data()); + } else { + // Deal with any end effects by reflection, if needed. This code will + // only + // be reached for about two frames per utterance, so we don't concern + // ourselves excessively with efficiency. + int32_t wave_dim = wave.size(); + for (int32_t s = 0; s < frame_length; ++s) { + int32_t s_in_wave = s + wave_start; + while (s_in_wave < 0 || s_in_wave >= wave_dim) { + // reflect around the beginning or end of the wave. + // e.g. -1 -> 0, -2 -> 1. + // dim -> dim - 1, dim + 1 -> dim - 2. + // the code supports repeated reflections, although this + // would only be needed in pathological cases. + if (s_in_wave < 0) + s_in_wave = -s_in_wave - 1; + else + s_in_wave = 2 * wave_dim - 1 - s_in_wave; + } + (*window)[s] = wave[s_in_wave]; + } + } + + ProcessWindow(opts, window_function, window->data(), log_energy_pre_window); +} + +static void RemoveDcOffset(float *d, int32_t n) { + float sum = 0; + for (int32_t i = 0; i != n; ++i) { + sum += d[i]; + } + + float mean = sum / n; + + for (int32_t i = 0; i != n; ++i) { + d[i] -= mean; + } +} + +float InnerProduct(const float *a, const float *b, int32_t n) { + float sum = 0; + for (int32_t i = 0; i != n; ++i) { + sum += a[i] * b[i]; + } + return sum; +} + +static void Preemphasize(float *d, int32_t n, float preemph_coeff) { + if (preemph_coeff == 0.0) { + return; + } + + CHECK(preemph_coeff >= 0.0 && preemph_coeff <= 1.0); + + for (int32_t i = n - 1; i > 0; --i) { + d[i] -= preemph_coeff * d[i - 1]; + } + d[0] -= preemph_coeff * d[0]; +} + +void ProcessWindow(const FrameExtractionOptions &opts, + const FeatureWindowFunction &window_function, + float *window, + float *log_energy_pre_window /*= nullptr*/) { 
+ int32_t frame_length = opts.WindowSize(); + + // TODO(fangjun): Remove dither + CHECK_EQ(opts.dither, 0); + + if (opts.remove_dc_offset) { + RemoveDcOffset(window, frame_length); + } + + if (log_energy_pre_window != NULL) { + float energy = + std::max(InnerProduct(window, window, frame_length), + std::numeric_limits::epsilon()); + *log_energy_pre_window = std::log(energy); + } + + if (opts.preemph_coeff != 0.0) { + Preemphasize(window, frame_length, opts.preemph_coeff); + } + + window_function.Apply(window); +} + +} // namespace knf diff --git a/runtime/engine/common/frontend/feature-window.h b/runtime/engine/common/frontend/feature-window.h new file mode 100644 index 00000000..8c86bf05 --- /dev/null +++ b/runtime/engine/common/frontend/feature-window.h @@ -0,0 +1,183 @@ +// kaldi-native-fbank/csrc/feature-window.h +// +// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + +// This file is copied/modified from kaldi/src/feat/feature-window.h + +#ifndef KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_ +#define KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_ + +#include +#include +#include + +#include "base/log.h" + +namespace knf { + +inline int32_t RoundUpToNearestPowerOfTwo(int32_t n) { + // copied from kaldi/src/base/kaldi-math.cc + CHECK_GT(n, 0); + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + return n + 1; +} + +struct FrameExtractionOptions { + float samp_freq = 16000; + float frame_shift_ms = 10.0f; // in milliseconds. + float frame_length_ms = 25.0f; // in milliseconds. + float dither = 1.0f; // Amount of dithering, 0.0 means no dither. + float preemph_coeff = 0.97f; // Preemphasis coefficient. + bool remove_dc_offset = true; // Subtract mean of wave before FFT. + std::string window_type = "povey"; // e.g. 
Hamming window + // May be "hamming", "rectangular", "povey", "hanning", "sine", "blackman" + // "povey" is a window I made to be similar to Hamming but to go to zero at + // the edges, it's pow((0.5 - 0.5*cos(n/N*2*pi)), 0.85) I just don't think + // the + // Hamming window makes sense as a windowing function. + bool round_to_power_of_two = true; + float blackman_coeff = 0.42f; + bool snip_edges = true; + // bool allow_downsample = false; + // bool allow_upsample = false; + + // Used for streaming feature extraction. It indicates the number + // of feature frames to keep in the recycling vector. -1 means to + // keep all feature frames. + int32_t max_feature_vectors = -1; + + int32_t WindowShift() const { + return static_cast(samp_freq * 0.001f * frame_shift_ms); + } + int32_t WindowSize() const { + return static_cast(samp_freq * 0.001f * frame_length_ms); + } + int32_t PaddedWindowSize() const { + return (round_to_power_of_two ? RoundUpToNearestPowerOfTwo(WindowSize()) + : WindowSize()); + } + std::string ToString() const { + std::ostringstream os; +#define KNF_PRINT(x) os << #x << ": " << x << "\n" + KNF_PRINT(samp_freq); + KNF_PRINT(frame_shift_ms); + KNF_PRINT(frame_length_ms); + KNF_PRINT(dither); + KNF_PRINT(preemph_coeff); + KNF_PRINT(remove_dc_offset); + KNF_PRINT(window_type); + KNF_PRINT(round_to_power_of_two); + KNF_PRINT(blackman_coeff); + KNF_PRINT(snip_edges); + // KNF_PRINT(allow_downsample); + // KNF_PRINT(allow_upsample); + KNF_PRINT(max_feature_vectors); +#undef KNF_PRINT + return os.str(); + } +}; + +std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts); + +class FeatureWindowFunction { + public: + FeatureWindowFunction() = default; + explicit FeatureWindowFunction(const FrameExtractionOptions &opts); + /** + * @param wave Pointer to a 1-D array of shape [window_size]. + * It is modified in-place: wave[i] = wave[i] * window_[i]. 
+ * @param + */ + void Apply(float *wave) const; + + private: + std::vector window_; // of size opts.WindowSize() +}; + +int64_t FirstSampleOfFrame(int32_t frame, const FrameExtractionOptions &opts); + +/** + This function returns the number of frames that we can extract from a wave + file with the given number of samples in it (assumed to have the same + sampling rate as specified in 'opts'). + + @param [in] num_samples The number of samples in the wave file. + @param [in] opts The frame-extraction options class + + @param [in] flush True if we are asserting that this number of samples + is 'all there is', false if we expecting more data to possibly come in. This + only makes a difference to the answer + if opts.snips_edges== false. For offline feature extraction you always want + flush == true. In an online-decoding context, once you know (or decide) that + no more data is coming in, you'd call it with flush == true at the end to + flush out any remaining data. +*/ +int32_t NumFrames(int64_t num_samples, + const FrameExtractionOptions &opts, + bool flush = true); + +/* + ExtractWindow() extracts a windowed frame of waveform (possibly with a + power-of-two, padded size, depending on the config), including all the + processing done by ProcessWindow(). + + @param [in] sample_offset If 'wave' is not the entire waveform, but + part of it to the left has been discarded, then the + number of samples prior to 'wave' that we have + already discarded. Set this to zero if you are + processing the entire waveform in one piece, or + if you get 'no matching function' compilation + errors when updating the code. + @param [in] wave The waveform + @param [in] f The frame index to be extracted, with + 0 <= f < NumFrames(sample_offset + wave.Dim(), opts, true) + @param [in] opts The options class to be used + @param [in] window_function The windowing function, as derived from the + options class. + @param [out] window The windowed, possibly-padded waveform to be + extracted. 
Will be resized as needed. + @param [out] log_energy_pre_window If non-NULL, the log-energy of + the signal prior to pre-emphasis and multiplying by + the windowing function will be written to here. +*/ +void ExtractWindow(int64_t sample_offset, + const std::vector &wave, + int32_t f, + const FrameExtractionOptions &opts, + const FeatureWindowFunction &window_function, + std::vector *window, + float *log_energy_pre_window = nullptr); + +/** + This function does all the windowing steps after actually + extracting the windowed signal: depending on the + configuration, it does dithering, dc offset removal, + preemphasis, and multiplication by the windowing function. + @param [in] opts The options class to be used + @param [in] window_function The windowing function-- should have + been initialized using 'opts'. + @param [in,out] window A vector of size opts.WindowSize(). Note: + it will typically be a sub-vector of a larger vector of size + opts.PaddedWindowSize(), with the remaining samples zero, + as the FFT code is more efficient if it operates on data with + power-of-two size. + @param [out] log_energy_pre_window If non-NULL, then after dithering and + DC offset removal, this function will write to this pointer the log of + the total energy (i.e. sum-squared) of the frame. 
+ */ +void ProcessWindow(const FrameExtractionOptions &opts, + const FeatureWindowFunction &window_function, + float *window, + float *log_energy_pre_window = nullptr); + +// Compute the inner product of two vectors +float InnerProduct(const float *a, const float *b, int32_t n); + +} // namespace knf + +#endif // KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_ diff --git a/speechx/speechx/frontend/audio/feature_cache.cc b/runtime/engine/common/frontend/feature_cache.cc similarity index 50% rename from speechx/speechx/frontend/audio/feature_cache.cc rename to runtime/engine/common/frontend/feature_cache.cc index 5110d704..650c84cc 100644 --- a/speechx/speechx/frontend/audio/feature_cache.cc +++ b/runtime/engine/common/frontend/feature_cache.cc @@ -12,94 +12,72 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "frontend/audio/feature_cache.h" +#include "frontend/feature_cache.h" namespace ppspeech { using kaldi::BaseFloat; -using kaldi::SubVector; -using kaldi::Vector; -using kaldi::VectorBase; using std::unique_ptr; using std::vector; -FeatureCache::FeatureCache(FeatureCacheOptions opts, +FeatureCache::FeatureCache(size_t max_size, unique_ptr base_extractor) { - max_size_ = opts.max_size; - timeout_ = opts.timeout; // ms + max_size_ = max_size; base_extractor_ = std::move(base_extractor); dim_ = base_extractor_->Dim(); } -void FeatureCache::Accept(const kaldi::VectorBase& inputs) { +void FeatureCache::Accept(const std::vector& inputs) { // read inputs base_extractor_->Accept(inputs); - - // feed current data - bool result = false; - do { - result = Compute(); - } while (result); } // pop feature chunk -bool FeatureCache::Read(kaldi::Vector* feats) { +bool FeatureCache::Read(std::vector* feats) { kaldi::Timer timer; - std::unique_lock lock(mutex_); - while (cache_.empty() && base_extractor_->IsFinished() == false) { - // todo refactor: wait - // ready_read_condition_.wait(lock); - int32 elapsed = 
static_cast(timer.Elapsed() * 1000); // ms - if (elapsed > timeout_) { - return false; - } - usleep(100); // sleep 0.1 ms + // feed current data + if (cache_.empty()) { + bool result = false; + do { + result = Compute(); + } while (result); } + if (cache_.empty()) return false; // read from cache - feats->Resize(cache_.front().Dim()); - feats->CopyFromVec(cache_.front()); + *feats = cache_.front(); cache_.pop(); - ready_feed_condition_.notify_one(); - VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec."; + VLOG(2) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec."; + VLOG(1) << "FeatureCache::size : " << cache_.size(); return true; } // read all data from base_feature_extractor_ into cache_ bool FeatureCache::Compute() { // compute and feed - Vector feature; + vector feature; bool result = base_extractor_->Read(&feature); - if (result == false || feature.Dim() == 0) return false; + if (result == false || feature.size() == 0) return false; kaldi::Timer timer; - int32 num_chunk = feature.Dim() / dim_; - nframe_ += num_chunk; + int32 num_chunk = feature.size() / dim_; VLOG(3) << "nframe computed: " << nframe_; for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { int32 start = chunk_idx * dim_; - Vector feature_chunk(dim_); - SubVector tmp(feature.Data() + start, dim_); - feature_chunk.CopyFromVec(tmp); - - std::unique_lock lock(mutex_); - while (cache_.size() >= max_size_) { - // cache full, wait - ready_feed_condition_.wait(lock); - } - + vector feature_chunk(feature.data() + start, + feature.data() + start + dim_); // feed cache cache_.push(feature_chunk); - ready_read_condition_.notify_one(); + ++nframe_; } - VLOG(1) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. " + VLOG(2) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. 
" << num_chunk << " feats."; return true; } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/frontend/audio/feature_cache.h b/runtime/engine/common/frontend/feature_cache.h similarity index 54% rename from speechx/speechx/frontend/audio/feature_cache.h rename to runtime/engine/common/frontend/feature_cache.h index a4ebd604..549a5724 100644 --- a/speechx/speechx/frontend/audio/feature_cache.h +++ b/runtime/engine/common/frontend/feature_cache.h @@ -15,67 +15,51 @@ #pragma once #include "base/common.h" -#include "frontend/audio/frontend_itf.h" +#include "frontend/frontend_itf.h" namespace ppspeech { -struct FeatureCacheOptions { - int32 max_size; - int32 timeout; // ms - FeatureCacheOptions() : max_size(kint16max), timeout(1) {} -}; - class FeatureCache : public FrontendInterface { public: explicit FeatureCache( - FeatureCacheOptions opts, + size_t max_size = kint16max, std::unique_ptr base_extractor = NULL); // Feed feats or waves - virtual void Accept(const kaldi::VectorBase& inputs); + virtual void Accept(const std::vector& inputs); // feats size = num_frames * feat_dim - virtual bool Read(kaldi::Vector* feats); + virtual bool Read(std::vector* feats); // feat dim virtual size_t Dim() const { return dim_; } virtual void SetFinished() { - LOG(INFO) << "set finished"; - // std::unique_lock lock(mutex_); + std::unique_lock lock(mutex_); base_extractor_->SetFinished(); - - // read the last chunk data - Compute(); - // ready_feed_condition_.notify_one(); - LOG(INFO) << "compute last feats done."; } - virtual bool IsFinished() const { return base_extractor_->IsFinished(); } + virtual bool IsFinished() const { + return base_extractor_->IsFinished() && cache_.empty(); + } void Reset() override { - std::queue> empty; + std::queue> empty; + VLOG(1) << "feature cache size: " << cache_.size(); std::swap(cache_, empty); nframe_ = 0; base_extractor_->Reset(); - VLOG(3) << "feature cache reset: cache size: " << 
cache_.size(); } private: bool Compute(); int32 dim_; - size_t max_size_; // cache capacity - int32 frame_chunk_size_; // window - int32 frame_chunk_stride_; // stride + size_t max_size_; // cache capacity std::unique_ptr base_extractor_; - kaldi::int32 timeout_; // ms - kaldi::Vector remained_feature_; - std::queue> cache_; // feature cache + std::queue> cache_; // feature cache std::mutex mutex_; - std::condition_variable ready_feed_condition_; - std::condition_variable ready_read_condition_; int32 nframe_; // num of feature computed DISALLOW_COPY_AND_ASSIGN(FeatureCache); diff --git a/speechx/speechx/frontend/audio/feature_common.h b/runtime/engine/common/frontend/feature_common.h similarity index 74% rename from speechx/speechx/frontend/audio/feature_common.h rename to runtime/engine/common/frontend/feature_common.h index bad705c9..fcc9100c 100644 --- a/speechx/speechx/frontend/audio/feature_common.h +++ b/runtime/engine/common/frontend/feature_common.h @@ -14,8 +14,8 @@ #pragma once +#include "frontend/feature-window.h" #include "frontend_itf.h" -#include "kaldi/feat/feature-window.h" namespace ppspeech { @@ -25,8 +25,8 @@ class StreamingFeatureTpl : public FrontendInterface { typedef typename F::Options Options; StreamingFeatureTpl(const Options& opts, std::unique_ptr base_extractor); - virtual void Accept(const kaldi::VectorBase& waves); - virtual bool Read(kaldi::Vector* feats); + virtual void Accept(const std::vector& waves); + virtual bool Read(std::vector* feats); // the dim_ is the dim of single frame feature virtual size_t Dim() const { return computer_.Dim(); } @@ -37,19 +37,19 @@ class StreamingFeatureTpl : public FrontendInterface { virtual void Reset() { base_extractor_->Reset(); - remained_wav_.Resize(0); + remained_wav_.resize(0); } private: - bool Compute(const kaldi::Vector& waves, - kaldi::Vector* feats); + bool Compute(const std::vector& waves, + std::vector* feats); Options opts_; std::unique_ptr base_extractor_; - 
kaldi::FeatureWindowFunction window_function_; - kaldi::Vector remained_wav_; + knf::FeatureWindowFunction window_function_; + std::vector remained_wav_; F computer_; }; } // namespace ppspeech -#include "frontend/audio/feature_common_inl.h" +#include "frontend/feature_common_inl.h" diff --git a/runtime/engine/common/frontend/feature_common_inl.h b/runtime/engine/common/frontend/feature_common_inl.h new file mode 100644 index 00000000..ac239974 --- /dev/null +++ b/runtime/engine/common/frontend/feature_common_inl.h @@ -0,0 +1,102 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ + +namespace ppspeech { + +template +StreamingFeatureTpl::StreamingFeatureTpl( + const Options& opts, std::unique_ptr base_extractor) + : opts_(opts), computer_(opts), window_function_(opts.frame_opts) { + base_extractor_ = std::move(base_extractor); +} + +template +void StreamingFeatureTpl::Accept( + const std::vector& waves) { + base_extractor_->Accept(waves); +} + +template +bool StreamingFeatureTpl::Read(std::vector* feats) { + std::vector wav(base_extractor_->Dim()); + bool flag = base_extractor_->Read(&wav); + if (flag == false || wav.size() == 0) return false; + + // append remaned waves + int32 wav_len = wav.size(); + int32 left_len = remained_wav_.size(); + std::vector waves(left_len + wav_len); + std::memcpy(waves.data(), + remained_wav_.data(), + left_len * sizeof(kaldi::BaseFloat)); + std::memcpy(waves.data() + left_len, + wav.data(), + wav_len * sizeof(kaldi::BaseFloat)); + + // compute speech feature + Compute(waves, feats); + + // cache remaned waves + knf::FrameExtractionOptions frame_opts = computer_.GetFrameOptions(); + int32 num_frames = knf::NumFrames(waves.size(), frame_opts); + int32 frame_shift = frame_opts.WindowShift(); + int32 left_samples = waves.size() - frame_shift * num_frames; + remained_wav_.resize(left_samples); + std::memcpy(remained_wav_.data(), + waves.data() + frame_shift * num_frames, + left_samples * sizeof(BaseFloat)); + return true; +} + +// Compute feat +template +bool StreamingFeatureTpl::Compute(const std::vector& waves, + std::vector* feats) { + const knf::FrameExtractionOptions& frame_opts = computer_.GetFrameOptions(); + int32 num_samples = waves.size(); + int32 frame_length = frame_opts.WindowSize(); + int32 sample_rate = frame_opts.samp_freq; + if (num_samples < frame_length) { + return true; + } + + int32 num_frames = knf::NumFrames(num_samples, frame_opts); + feats->resize(num_frames * Dim()); + + std::vector window; + bool need_raw_log_energy = computer_.NeedRawLogEnergy(); + for (int32 frame = 0; frame < 
num_frames; frame++) { + std::fill(window.begin(), window.end(), 0); + kaldi::BaseFloat raw_log_energy = 0.0; + kaldi::BaseFloat vtln_warp = 1.0; + knf::ExtractWindow(0, + waves, + frame, + frame_opts, + window_function_, + &window, + need_raw_log_energy ? &raw_log_energy : NULL); + + std::vector this_feature(computer_.Dim()); + computer_.Compute( + raw_log_energy, vtln_warp, &window, this_feature.data()); + std::memcpy(feats->data() + frame * Dim(), + this_feature.data(), + sizeof(BaseFloat) * Dim()); + } + return true; +} + +} // namespace ppspeech diff --git a/speechx/speechx/frontend/audio/feature_pipeline.cc b/runtime/engine/common/frontend/feature_pipeline.cc similarity index 61% rename from speechx/speechx/frontend/audio/feature_pipeline.cc rename to runtime/engine/common/frontend/feature_pipeline.cc index 2931b96b..7d662bc1 100644 --- a/speechx/speechx/frontend/audio/feature_pipeline.cc +++ b/runtime/engine/common/frontend/feature_pipeline.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "frontend/audio/feature_pipeline.h" +#include "frontend/feature_pipeline.h" namespace ppspeech { @@ -21,24 +21,25 @@ using std::unique_ptr; FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) : opts_(opts) { unique_ptr data_source( - new ppspeech::AudioCache(1000 * kint16max, opts.to_float32)); + new ppspeech::AudioCache(1000 * kint16max, false)); unique_ptr base_feature; - if (opts.use_fbank) { - base_feature.reset( - new ppspeech::Fbank(opts.fbank_opts, std::move(data_source))); - } else { - base_feature.reset(new ppspeech::LinearSpectrogram( - opts.linear_spectrogram_opts, std::move(data_source))); - } + base_feature.reset( + new ppspeech::Fbank(opts.fbank_opts, std::move(data_source))); - CHECK_NE(opts.cmvn_file, ""); - unique_ptr cmvn( - new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature))); + // CHECK_NE(opts.cmvn_file, ""); + unique_ptr cache; + if (opts.cmvn_file != ""){ + unique_ptr cmvn( + new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature))); - unique_ptr cache( - new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn))); + cache.reset( + new ppspeech::FeatureCache(kint16max, std::move(cmvn))); + } else { + cache.reset( + new ppspeech::FeatureCache(kint16max, std::move(base_feature))); + } base_extractor_.reset( new ppspeech::Assembler(opts.assembler_opts, std::move(cache))); diff --git a/speechx/speechx/frontend/audio/feature_pipeline.h b/runtime/engine/common/frontend/feature_pipeline.h similarity index 67% rename from speechx/speechx/frontend/audio/feature_pipeline.h rename to runtime/engine/common/frontend/feature_pipeline.h index e83a3f31..7509814f 100644 --- a/speechx/speechx/frontend/audio/feature_pipeline.h +++ b/runtime/engine/common/frontend/feature_pipeline.h @@ -16,17 +16,15 @@ #pragma once -#include "frontend/audio/assembler.h" -#include "frontend/audio/audio_cache.h" -#include "frontend/audio/data_cache.h" -#include "frontend/audio/fbank.h" -#include "frontend/audio/feature_cache.h" 
-#include "frontend/audio/frontend_itf.h" -#include "frontend/audio/linear_spectrogram.h" -#include "frontend/audio/normalizer.h" +#include "frontend/assembler.h" +#include "frontend/audio_cache.h" +#include "frontend/cmvn.h" +#include "frontend/data_cache.h" +#include "frontend/fbank.h" +#include "frontend/feature_cache.h" +#include "frontend/frontend_itf.h" // feature -DECLARE_bool(use_fbank); DECLARE_bool(fill_zero); DECLARE_int32(num_bins); DECLARE_string(cmvn_file); @@ -40,11 +38,7 @@ namespace ppspeech { struct FeaturePipelineOptions { std::string cmvn_file{}; - bool to_float32{false}; // true, only for linear feature - bool use_fbank{true}; - LinearSpectrogramOptions linear_spectrogram_opts{}; - kaldi::FbankOptions fbank_opts{}; - FeatureCacheOptions feature_cache_opts{}; + knf::FbankOptions fbank_opts{}; AssemblerOptions assembler_opts{}; static FeaturePipelineOptions InitFromFlags() { @@ -53,30 +47,17 @@ struct FeaturePipelineOptions { LOG(INFO) << "cmvn file: " << opts.cmvn_file; // frame options - kaldi::FrameExtractionOptions frame_opts; + knf::FrameExtractionOptions frame_opts; frame_opts.dither = 0.0; LOG(INFO) << "dither: " << frame_opts.dither; frame_opts.frame_shift_ms = 10; LOG(INFO) << "frame shift ms: " << frame_opts.frame_shift_ms; - opts.use_fbank = FLAGS_use_fbank; - LOG(INFO) << "feature type: " << (opts.use_fbank ? 
"fbank" : "linear"); - if (opts.use_fbank) { - opts.to_float32 = false; - frame_opts.window_type = "povey"; - frame_opts.frame_length_ms = 25; - opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins; - LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins; - - opts.fbank_opts.frame_opts = frame_opts; - } else { - opts.to_float32 = true; - frame_opts.remove_dc_offset = false; - frame_opts.frame_length_ms = 20; - frame_opts.window_type = "hanning"; - frame_opts.preemph_coeff = 0.0; - - opts.linear_spectrogram_opts.frame_opts = frame_opts; - } + frame_opts.window_type = "povey"; + frame_opts.frame_length_ms = 25; + opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins; + LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins; + + opts.fbank_opts.frame_opts = frame_opts; LOG(INFO) << "frame length ms: " << frame_opts.frame_length_ms; // assembler opts @@ -100,10 +81,10 @@ struct FeaturePipelineOptions { class FeaturePipeline : public FrontendInterface { public: explicit FeaturePipeline(const FeaturePipelineOptions& opts); - virtual void Accept(const kaldi::VectorBase& waves) { + virtual void Accept(const std::vector& waves) { base_extractor_->Accept(waves); } - virtual bool Read(kaldi::Vector* feats) { + virtual bool Read(std::vector* feats) { return base_extractor_->Read(feats); } virtual size_t Dim() const { return base_extractor_->Dim(); } diff --git a/runtime/engine/common/frontend/fftsg.c b/runtime/engine/common/frontend/fftsg.c new file mode 100644 index 00000000..30b81604 --- /dev/null +++ b/runtime/engine/common/frontend/fftsg.c @@ -0,0 +1,3271 @@ +/* This file is copied from + * https://www.kurims.kyoto-u.ac.jp/~ooura/fft.html + */ +/* +Fast Fourier/Cosine/Sine Transform + dimension :one + data length :power of 2 + decimation :frequency + radix :split-radix + data :inplace + table :use +functions + cdft: Complex Discrete Fourier Transform + rdft: Real Discrete Fourier Transform + ddct: Discrete Cosine Transform + ddst: Discrete Sine Transform + 
dfct: Cosine Transform of RDFT (Real Symmetric DFT) + dfst: Sine Transform of RDFT (Real Anti-symmetric DFT) +function prototypes + void cdft(int, int, double *, int *, double *); + void rdft(int, int, double *, int *, double *); + void ddct(int, int, double *, int *, double *); + void ddst(int, int, double *, int *, double *); + void dfct(int, double *, double *, int *, double *); + void dfst(int, double *, double *, int *, double *); +macro definitions + USE_CDFT_PTHREADS : default=not defined + CDFT_THREADS_BEGIN_N : must be >= 512, default=8192 + CDFT_4THREADS_BEGIN_N : must be >= 512, default=65536 + USE_CDFT_WINTHREADS : default=not defined + CDFT_THREADS_BEGIN_N : must be >= 512, default=32768 + CDFT_4THREADS_BEGIN_N : must be >= 512, default=524288 + + +-------- Complex DFT (Discrete Fourier Transform) -------- + [definition] + + X[k] = sum_j=0^n-1 x[j]*exp(2*pi*i*j*k/n), 0<=k + X[k] = sum_j=0^n-1 x[j]*exp(-2*pi*i*j*k/n), 0<=k + ip[0] = 0; // first time only + cdft(2*n, 1, a, ip, w); + + ip[0] = 0; // first time only + cdft(2*n, -1, a, ip, w); + [parameters] + 2*n :data length (int) + n >= 1, n = power of 2 + a[0...2*n-1] :input/output data (double *) + input data + a[2*j] = Re(x[j]), + a[2*j+1] = Im(x[j]), 0<=j= 2+sqrt(n) + strictly, + length of ip >= + 2+(1<<(int)(log(n+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n/2-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + cdft(2*n, -1, a, ip, w); + is + cdft(2*n, 1, a, ip, w); + for (j = 0; j <= 2 * n - 1; j++) { + a[j] *= 1.0 / n; + } + . 
+ + +-------- Real DFT / Inverse of Real DFT -------- + [definition] + RDFT + R[k] = sum_j=0^n-1 a[j]*cos(2*pi*j*k/n), 0<=k<=n/2 + I[k] = sum_j=0^n-1 a[j]*sin(2*pi*j*k/n), 0 IRDFT (excluding scale) + a[k] = (R[0] + R[n/2]*cos(pi*k))/2 + + sum_j=1^n/2-1 R[j]*cos(2*pi*j*k/n) + + sum_j=1^n/2-1 I[j]*sin(2*pi*j*k/n), 0<=k + ip[0] = 0; // first time only + rdft(n, 1, a, ip, w); + + ip[0] = 0; // first time only + rdft(n, -1, a, ip, w); + [parameters] + n :data length (int) + n >= 2, n = power of 2 + a[0...n-1] :input/output data (double *) + + output data + a[2*k] = R[k], 0<=k + input data + a[2*j] = R[j], 0<=j= 2+sqrt(n/2) + strictly, + length of ip >= + 2+(1<<(int)(log(n/2+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n/2-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + rdft(n, 1, a, ip, w); + is + rdft(n, -1, a, ip, w); + for (j = 0; j <= n - 1; j++) { + a[j] *= 2.0 / n; + } + . + + +-------- DCT (Discrete Cosine Transform) / Inverse of DCT -------- + [definition] + IDCT (excluding scale) + C[k] = sum_j=0^n-1 a[j]*cos(pi*j*(k+1/2)/n), 0<=k DCT + C[k] = sum_j=0^n-1 a[j]*cos(pi*(j+1/2)*k/n), 0<=k + ip[0] = 0; // first time only + ddct(n, 1, a, ip, w); + + ip[0] = 0; // first time only + ddct(n, -1, a, ip, w); + [parameters] + n :data length (int) + n >= 2, n = power of 2 + a[0...n-1] :input/output data (double *) + output data + a[k] = C[k], 0<=k= 2+sqrt(n/2) + strictly, + length of ip >= + 2+(1<<(int)(log(n/2+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n*5/4-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + ddct(n, -1, a, ip, w); + is + a[0] *= 0.5; + ddct(n, 1, a, ip, w); + for (j = 0; j <= n - 1; j++) { + a[j] *= 2.0 / n; + } + . 
+ + +-------- DST (Discrete Sine Transform) / Inverse of DST -------- + [definition] + IDST (excluding scale) + S[k] = sum_j=1^n A[j]*sin(pi*j*(k+1/2)/n), 0<=k DST + S[k] = sum_j=0^n-1 a[j]*sin(pi*(j+1/2)*k/n), 0 + ip[0] = 0; // first time only + ddst(n, 1, a, ip, w); + + ip[0] = 0; // first time only + ddst(n, -1, a, ip, w); + [parameters] + n :data length (int) + n >= 2, n = power of 2 + a[0...n-1] :input/output data (double *) + + input data + a[j] = A[j], 0 + output data + a[k] = S[k], 0= 2+sqrt(n/2) + strictly, + length of ip >= + 2+(1<<(int)(log(n/2+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n*5/4-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + ddst(n, -1, a, ip, w); + is + a[0] *= 0.5; + ddst(n, 1, a, ip, w); + for (j = 0; j <= n - 1; j++) { + a[j] *= 2.0 / n; + } + . + + +-------- Cosine Transform of RDFT (Real Symmetric DFT) -------- + [definition] + C[k] = sum_j=0^n a[j]*cos(pi*j*k/n), 0<=k<=n + [usage] + ip[0] = 0; // first time only + dfct(n, a, t, ip, w); + [parameters] + n :data length - 1 (int) + n >= 2, n = power of 2 + a[0...n] :input/output data (double *) + output data + a[k] = C[k], 0<=k<=n + t[0...n/2] :work area (double *) + ip[0...*] :work area for bit reversal (int *) + length of ip >= 2+sqrt(n/4) + strictly, + length of ip >= + 2+(1<<(int)(log(n/4+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n*5/8-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + a[0] *= 0.5; + a[n] *= 0.5; + dfct(n, a, t, ip, w); + is + a[0] *= 0.5; + a[n] *= 0.5; + dfct(n, a, t, ip, w); + for (j = 0; j <= n; j++) { + a[j] *= 2.0 / n; + } + . 
+ + +-------- Sine Transform of RDFT (Real Anti-symmetric DFT) -------- + [definition] + S[k] = sum_j=1^n-1 a[j]*sin(pi*j*k/n), 0= 2, n = power of 2 + a[0...n-1] :input/output data (double *) + output data + a[k] = S[k], 0= 2+sqrt(n/4) + strictly, + length of ip >= + 2+(1<<(int)(log(n/4+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n*5/8-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + dfst(n, a, t, ip, w); + is + dfst(n, a, t, ip, w); + for (j = 1; j <= n - 1; j++) { + a[j] *= 2.0 / n; + } + . + + +Appendix : + The cos/sin table is recalculated when the larger table required. + w[] and ip[] are compatible with all routines. +*/ + + +void cdft(int n, int isgn, double *a, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void cftbsub(int n, double *a, int *ip, int nw, double *w); + int nw; + + nw = ip[0]; + if (n > (nw << 2)) { + nw = n >> 2; + makewt(nw, ip, w); + } + if (isgn >= 0) { + cftfsub(n, a, ip, nw, w); + } else { + cftbsub(n, a, ip, nw, w); + } +} + + +void rdft(int n, int isgn, double *a, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void cftbsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void rftbsub(int n, double *a, int nc, double *c); + int nw, nc; + double xi; + + nw = ip[0]; + if (n > (nw << 2)) { + nw = n >> 2; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > (nc << 2)) { + nc = n >> 2; + makect(nc, ip, w + nw); + } + if (isgn >= 0) { + if (n > 4) { + cftfsub(n, a, ip, nw, w); + rftfsub(n, a, nc, w + nw); + } else if (n == 4) { + cftfsub(n, a, ip, nw, w); + } + xi = a[0] - a[1]; + a[0] += a[1]; + a[1] = xi; + } else { + a[1] = 0.5 * (a[0] - a[1]); + a[0] -= a[1]; + if (n > 4) { + rftbsub(n, a, nc, w + nw); + 
cftbsub(n, a, ip, nw, w); + } else if (n == 4) { + cftbsub(n, a, ip, nw, w); + } + } +} + + +void ddct(int n, int isgn, double *a, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void cftbsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void rftbsub(int n, double *a, int nc, double *c); + void dctsub(int n, double *a, int nc, double *c); + int j, nw, nc; + double xr; + + nw = ip[0]; + if (n > (nw << 2)) { + nw = n >> 2; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > nc) { + nc = n; + makect(nc, ip, w + nw); + } + if (isgn < 0) { + xr = a[n - 1]; + for (j = n - 2; j >= 2; j -= 2) { + a[j + 1] = a[j] - a[j - 1]; + a[j] += a[j - 1]; + } + a[1] = a[0] - xr; + a[0] += xr; + if (n > 4) { + rftbsub(n, a, nc, w + nw); + cftbsub(n, a, ip, nw, w); + } else if (n == 4) { + cftbsub(n, a, ip, nw, w); + } + } + dctsub(n, a, nc, w + nw); + if (isgn >= 0) { + if (n > 4) { + cftfsub(n, a, ip, nw, w); + rftfsub(n, a, nc, w + nw); + } else if (n == 4) { + cftfsub(n, a, ip, nw, w); + } + xr = a[0] - a[1]; + a[0] += a[1]; + for (j = 2; j < n; j += 2) { + a[j - 1] = a[j] - a[j + 1]; + a[j] += a[j + 1]; + } + a[n - 1] = xr; + } +} + + +void ddst(int n, int isgn, double *a, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void cftbsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void rftbsub(int n, double *a, int nc, double *c); + void dstsub(int n, double *a, int nc, double *c); + int j, nw, nc; + double xr; + + nw = ip[0]; + if (n > (nw << 2)) { + nw = n >> 2; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > nc) { + nc = n; + makect(nc, ip, w + nw); + } + if (isgn < 0) { + xr = a[n - 1]; + for (j = n - 2; j >= 2; j -= 2) { + a[j + 
1] = -a[j] - a[j - 1]; + a[j] -= a[j - 1]; + } + a[1] = a[0] + xr; + a[0] -= xr; + if (n > 4) { + rftbsub(n, a, nc, w + nw); + cftbsub(n, a, ip, nw, w); + } else if (n == 4) { + cftbsub(n, a, ip, nw, w); + } + } + dstsub(n, a, nc, w + nw); + if (isgn >= 0) { + if (n > 4) { + cftfsub(n, a, ip, nw, w); + rftfsub(n, a, nc, w + nw); + } else if (n == 4) { + cftfsub(n, a, ip, nw, w); + } + xr = a[0] - a[1]; + a[0] += a[1]; + for (j = 2; j < n; j += 2) { + a[j - 1] = -a[j] - a[j + 1]; + a[j] -= a[j + 1]; + } + a[n - 1] = -xr; + } +} + + +void dfct(int n, double *a, double *t, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void dctsub(int n, double *a, int nc, double *c); + int j, k, l, m, mh, nw, nc; + double xr, xi, yr, yi; + + nw = ip[0]; + if (n > (nw << 3)) { + nw = n >> 3; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > (nc << 1)) { + nc = n >> 1; + makect(nc, ip, w + nw); + } + m = n >> 1; + yi = a[m]; + xi = a[0] + a[n]; + a[0] -= a[n]; + t[0] = xi - yi; + t[m] = xi + yi; + if (n > 2) { + mh = m >> 1; + for (j = 1; j < mh; j++) { + k = m - j; + xr = a[j] - a[n - j]; + xi = a[j] + a[n - j]; + yr = a[k] - a[n - k]; + yi = a[k] + a[n - k]; + a[j] = xr; + a[k] = yr; + t[j] = xi - yi; + t[k] = xi + yi; + } + t[mh] = a[mh] + a[n - mh]; + a[mh] -= a[n - mh]; + dctsub(m, a, nc, w + nw); + if (m > 4) { + cftfsub(m, a, ip, nw, w); + rftfsub(m, a, nc, w + nw); + } else if (m == 4) { + cftfsub(m, a, ip, nw, w); + } + a[n - 1] = a[0] - a[1]; + a[1] = a[0] + a[1]; + for (j = m - 2; j >= 2; j -= 2) { + a[2 * j + 1] = a[j] + a[j + 1]; + a[2 * j - 1] = a[j] - a[j + 1]; + } + l = 2; + m = mh; + while (m >= 2) { + dctsub(m, t, nc, w + nw); + if (m > 4) { + cftfsub(m, t, ip, nw, w); + rftfsub(m, t, nc, w + nw); + } else if (m == 4) { + cftfsub(m, t, ip, nw, w); + } + a[n - l] = t[0] - t[1]; + a[l] = t[0] + 
t[1]; + k = 0; + for (j = 2; j < m; j += 2) { + k += l << 2; + a[k - l] = t[j] - t[j + 1]; + a[k + l] = t[j] + t[j + 1]; + } + l <<= 1; + mh = m >> 1; + for (j = 0; j < mh; j++) { + k = m - j; + t[j] = t[m + k] - t[m + j]; + t[k] = t[m + k] + t[m + j]; + } + t[mh] = t[m + mh]; + m = mh; + } + a[l] = t[0]; + a[n] = t[2] - t[1]; + a[0] = t[2] + t[1]; + } else { + a[1] = a[0]; + a[2] = t[0]; + a[0] = t[1]; + } +} + + +void dfst(int n, double *a, double *t, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void dstsub(int n, double *a, int nc, double *c); + int j, k, l, m, mh, nw, nc; + double xr, xi, yr, yi; + + nw = ip[0]; + if (n > (nw << 3)) { + nw = n >> 3; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > (nc << 1)) { + nc = n >> 1; + makect(nc, ip, w + nw); + } + if (n > 2) { + m = n >> 1; + mh = m >> 1; + for (j = 1; j < mh; j++) { + k = m - j; + xr = a[j] + a[n - j]; + xi = a[j] - a[n - j]; + yr = a[k] + a[n - k]; + yi = a[k] - a[n - k]; + a[j] = xr; + a[k] = yr; + t[j] = xi + yi; + t[k] = xi - yi; + } + t[0] = a[mh] - a[n - mh]; + a[mh] += a[n - mh]; + a[0] = a[m]; + dstsub(m, a, nc, w + nw); + if (m > 4) { + cftfsub(m, a, ip, nw, w); + rftfsub(m, a, nc, w + nw); + } else if (m == 4) { + cftfsub(m, a, ip, nw, w); + } + a[n - 1] = a[1] - a[0]; + a[1] = a[0] + a[1]; + for (j = m - 2; j >= 2; j -= 2) { + a[2 * j + 1] = a[j] - a[j + 1]; + a[2 * j - 1] = -a[j] - a[j + 1]; + } + l = 2; + m = mh; + while (m >= 2) { + dstsub(m, t, nc, w + nw); + if (m > 4) { + cftfsub(m, t, ip, nw, w); + rftfsub(m, t, nc, w + nw); + } else if (m == 4) { + cftfsub(m, t, ip, nw, w); + } + a[n - l] = t[1] - t[0]; + a[l] = t[0] + t[1]; + k = 0; + for (j = 2; j < m; j += 2) { + k += l << 2; + a[k - l] = -t[j] - t[j + 1]; + a[k + l] = t[j] - t[j + 1]; + } + l <<= 1; + mh = m >> 1; + for (j = 1; j < mh; j++) { + k = m - 
j; + t[j] = t[m + k] + t[m + j]; + t[k] = t[m + k] - t[m + j]; + } + t[0] = t[m + mh]; + m = mh; + } + a[l] = t[0]; + } + a[0] = 0; +} + + +/* -------- initializing routines -------- */ + + +#include + +void makewt(int nw, int *ip, double *w) { + void makeipt(int nw, int *ip); + int j, nwh, nw0, nw1; + double delta, wn4r, wk1r, wk1i, wk3r, wk3i; + + ip[0] = nw; + ip[1] = 1; + if (nw > 2) { + nwh = nw >> 1; + delta = atan(1.0) / nwh; + wn4r = cos(delta * nwh); + w[0] = 1; + w[1] = wn4r; + if (nwh == 4) { + w[2] = cos(delta * 2); + w[3] = sin(delta * 2); + } else if (nwh > 4) { + makeipt(nw, ip); + w[2] = 0.5 / cos(delta * 2); + w[3] = 0.5 / cos(delta * 6); + for (j = 4; j < nwh; j += 4) { + w[j] = cos(delta * j); + w[j + 1] = sin(delta * j); + w[j + 2] = cos(3 * delta * j); + w[j + 3] = -sin(3 * delta * j); + } + } + nw0 = 0; + while (nwh > 2) { + nw1 = nw0 + nwh; + nwh >>= 1; + w[nw1] = 1; + w[nw1 + 1] = wn4r; + if (nwh == 4) { + wk1r = w[nw0 + 4]; + wk1i = w[nw0 + 5]; + w[nw1 + 2] = wk1r; + w[nw1 + 3] = wk1i; + } else if (nwh > 4) { + wk1r = w[nw0 + 4]; + wk3r = w[nw0 + 6]; + w[nw1 + 2] = 0.5 / wk1r; + w[nw1 + 3] = 0.5 / wk3r; + for (j = 4; j < nwh; j += 4) { + wk1r = w[nw0 + 2 * j]; + wk1i = w[nw0 + 2 * j + 1]; + wk3r = w[nw0 + 2 * j + 2]; + wk3i = w[nw0 + 2 * j + 3]; + w[nw1 + j] = wk1r; + w[nw1 + j + 1] = wk1i; + w[nw1 + j + 2] = wk3r; + w[nw1 + j + 3] = wk3i; + } + } + nw0 = nw1; + } + } +} + + +void makeipt(int nw, int *ip) { + int j, l, m, m2, p, q; + + ip[2] = 0; + ip[3] = 16; + m = 2; + for (l = nw; l > 32; l >>= 2) { + m2 = m << 1; + q = m2 << 3; + for (j = m; j < m2; j++) { + p = ip[j] << 2; + ip[m + j] = p; + ip[m2 + j] = p + q; + } + m = m2; + } +} + + +void makect(int nc, int *ip, double *c) { + int j, nch; + double delta; + + ip[1] = nc; + if (nc > 1) { + nch = nc >> 1; + delta = atan(1.0) / nch; + c[0] = cos(delta * nch); + c[nch] = 0.5 * c[0]; + for (j = 1; j < nch; j++) { + c[j] = 0.5 * cos(delta * j); + c[nc - j] = 0.5 * sin(delta * j); + } + } 
+} + + +/* -------- child routines -------- */ + + +#ifdef USE_CDFT_PTHREADS +#define USE_CDFT_THREADS +#ifndef CDFT_THREADS_BEGIN_N +#define CDFT_THREADS_BEGIN_N 8192 +#endif +#ifndef CDFT_4THREADS_BEGIN_N +#define CDFT_4THREADS_BEGIN_N 65536 +#endif +#include +#include +#include +#define cdft_thread_t pthread_t +#define cdft_thread_create(thp, func, argp) \ + { \ + if (pthread_create(thp, NULL, func, (void *)argp) != 0) { \ + fprintf(stderr, "cdft thread error\n"); \ + exit(1); \ + } \ + } +#define cdft_thread_wait(th) \ + { \ + if (pthread_join(th, NULL) != 0) { \ + fprintf(stderr, "cdft thread error\n"); \ + exit(1); \ + } \ + } +#endif /* USE_CDFT_PTHREADS */ + + +#ifdef USE_CDFT_WINTHREADS +#define USE_CDFT_THREADS +#ifndef CDFT_THREADS_BEGIN_N +#define CDFT_THREADS_BEGIN_N 32768 +#endif +#ifndef CDFT_4THREADS_BEGIN_N +#define CDFT_4THREADS_BEGIN_N 524288 +#endif +#include +#include +#include +#define cdft_thread_t HANDLE +#define cdft_thread_create(thp, func, argp) \ + { \ + DWORD thid; \ + *(thp) = CreateThread( \ + NULL, 0, (LPTHREAD_START_ROUTINE)func, (LPVOID)argp, 0, &thid); \ + if (*(thp) == 0) { \ + fprintf(stderr, "cdft thread error\n"); \ + exit(1); \ + } \ + } +#define cdft_thread_wait(th) \ + { \ + WaitForSingleObject(th, INFINITE); \ + CloseHandle(th); \ + } +#endif /* USE_CDFT_WINTHREADS */ + + +void cftfsub(int n, double *a, int *ip, int nw, double *w) { + void bitrv2(int n, int *ip, double *a); + void bitrv216(double *a); + void bitrv208(double *a); + void cftf1st(int n, double *a, double *w); + void cftrec4(int n, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftfx41(int n, double *a, int nw, double *w); + void cftf161(double *a, double *w); + void cftf081(double *a, double *w); + void cftf040(double *a); + void cftx020(double *a); +#ifdef USE_CDFT_THREADS + void cftrec4_th(int n, double *a, int nw, double *w); +#endif /* USE_CDFT_THREADS */ + + if (n > 8) { + if (n > 32) { + cftf1st(n, 
a, &w[nw - (n >> 2)]); +#ifdef USE_CDFT_THREADS + if (n > CDFT_THREADS_BEGIN_N) { + cftrec4_th(n, a, nw, w); + } else +#endif /* USE_CDFT_THREADS */ + if (n > 512) { + cftrec4(n, a, nw, w); + } else if (n > 128) { + cftleaf(n, 1, a, nw, w); + } else { + cftfx41(n, a, nw, w); + } + bitrv2(n, ip, a); + } else if (n == 32) { + cftf161(a, &w[nw - 8]); + bitrv216(a); + } else { + cftf081(a, w); + bitrv208(a); + } + } else if (n == 8) { + cftf040(a); + } else if (n == 4) { + cftx020(a); + } +} + + +void cftbsub(int n, double *a, int *ip, int nw, double *w) { + void bitrv2conj(int n, int *ip, double *a); + void bitrv216neg(double *a); + void bitrv208neg(double *a); + void cftb1st(int n, double *a, double *w); + void cftrec4(int n, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftfx41(int n, double *a, int nw, double *w); + void cftf161(double *a, double *w); + void cftf081(double *a, double *w); + void cftb040(double *a); + void cftx020(double *a); +#ifdef USE_CDFT_THREADS + void cftrec4_th(int n, double *a, int nw, double *w); +#endif /* USE_CDFT_THREADS */ + + if (n > 8) { + if (n > 32) { + cftb1st(n, a, &w[nw - (n >> 2)]); +#ifdef USE_CDFT_THREADS + if (n > CDFT_THREADS_BEGIN_N) { + cftrec4_th(n, a, nw, w); + } else +#endif /* USE_CDFT_THREADS */ + if (n > 512) { + cftrec4(n, a, nw, w); + } else if (n > 128) { + cftleaf(n, 1, a, nw, w); + } else { + cftfx41(n, a, nw, w); + } + bitrv2conj(n, ip, a); + } else if (n == 32) { + cftf161(a, &w[nw - 8]); + bitrv216neg(a); + } else { + cftf081(a, w); + bitrv208neg(a); + } + } else if (n == 8) { + cftb040(a); + } else if (n == 4) { + cftx020(a); + } +} + + +void bitrv2(int n, int *ip, double *a) { + int j, j1, k, k1, l, m, nh, nm; + double xr, xi, yr, yi; + + m = 1; + for (l = n >> 2; l > 8; l >>= 2) { + m <<= 1; + } + nh = n >> 1; + nm = 4 * m; + if (l == 8) { + for (k = 0; k < m; k++) { + for (j = 0; j < k; j++) { + j1 = 4 * j + 2 * ip[m + k]; + k1 = 4 * k + 2 * ip[m + 
j]; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh; + k1 += 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += 2; + k1 += nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh; + k1 -= 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 
* nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + k1 = 4 * k + 2 * ip[m + k]; + j1 = k1 + 2; + k1 += nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= 2; + k1 -= nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh + 2; + k1 += nh + 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh - nm; + k1 += 2 * nm - 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + } else { + for (k = 0; k < m; k++) { + for (j = 0; j < k; j++) { + j1 = 4 * j + ip[m + k]; + k1 = 4 * k + ip[m + j]; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh; + k1 += 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= nm; + xr = 
a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += 2; + k1 += nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh; + k1 -= 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + k1 = 4 * k + ip[m + k]; + j1 = k1 + 2; + k1 += nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + } +} + + +void bitrv2conj(int n, int *ip, double *a) { + int j, j1, k, k1, l, m, nh, nm; + double xr, xi, yr, yi; + + m = 1; + for (l = n >> 2; l > 8; l >>= 2) { + m <<= 1; + } + nh = n >> 1; + nm = 4 * m; + if (l == 8) { + for (k = 0; k < m; k++) { + for (j = 0; j < k; j++) { + j1 = 4 * j + 2 * ip[m + k]; + k1 = 4 * k + 2 * ip[m + j]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 
1] = xi; + j1 += nh; + k1 += 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 += nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += 2; + k1 += nh; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh; + k1 -= 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 += nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + k1 = 4 * k + 2 * ip[m + k]; + j1 = k1 + 2; + k1 += nh; + a[j1 - 1] = -a[j1 - 1]; + xr = a[j1]; + xi = -a[j1 + 1]; + 
yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + a[k1 + 3] = -a[k1 + 3]; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= 2; + k1 -= nh; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh + 2; + k1 += nh + 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh - nm; + k1 += 2 * nm - 2; + a[j1 - 1] = -a[j1 - 1]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + a[k1 + 3] = -a[k1 + 3]; + } + } else { + for (k = 0; k < m; k++) { + for (j = 0; j < k; j++) { + j1 = 4 * j + ip[m + k]; + k1 = 4 * k + ip[m + j]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh; + k1 += 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += 2; + k1 += nh; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh; 
+ k1 -= 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + k1 = 4 * k + ip[m + k]; + j1 = k1 + 2; + k1 += nh; + a[j1 - 1] = -a[j1 - 1]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + a[k1 + 3] = -a[k1 + 3]; + j1 += nm; + k1 += nm; + a[j1 - 1] = -a[j1 - 1]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + a[k1 + 3] = -a[k1 + 3]; + } + } +} + + +void bitrv216(double *a) { + double x1r, x1i, x2r, x2i, x3r, x3i, x4r, x4i, x5r, x5i, x7r, x7i, x8r, x8i, + x10r, x10i, x11r, x11i, x12r, x12i, x13r, x13i, x14r, x14i; + + x1r = a[2]; + x1i = a[3]; + x2r = a[4]; + x2i = a[5]; + x3r = a[6]; + x3i = a[7]; + x4r = a[8]; + x4i = a[9]; + x5r = a[10]; + x5i = a[11]; + x7r = a[14]; + x7i = a[15]; + x8r = a[16]; + x8i = a[17]; + x10r = a[20]; + x10i = a[21]; + x11r = a[22]; + x11i = a[23]; + x12r = a[24]; + x12i = a[25]; + x13r = a[26]; + x13i = a[27]; + x14r = a[28]; + x14i = a[29]; + a[2] = x8r; + a[3] = x8i; + a[4] = x4r; + a[5] = x4i; + a[6] = x12r; + a[7] = x12i; + a[8] = x2r; + a[9] = x2i; + a[10] = x10r; + a[11] = x10i; + a[14] = x14r; + a[15] = x14i; + a[16] = x1r; + a[17] = x1i; + a[20] = x5r; + a[21] = x5i; + a[22] = x13r; + a[23] = x13i; + a[24] = x3r; + a[25] = x3i; + a[26] = x11r; + a[27] = x11i; + a[28] = x7r; + a[29] = x7i; +} + + +void bitrv216neg(double *a) { + double x1r, x1i, x2r, x2i, x3r, x3i, x4r, x4i, x5r, x5i, x6r, x6i, x7r, x7i, + x8r, x8i, x9r, x9i, x10r, x10i, x11r, x11i, x12r, x12i, x13r, x13i, + x14r, x14i, x15r, x15i; + + x1r = a[2]; + x1i = a[3]; + x2r = a[4]; + x2i = a[5]; + x3r = a[6]; + x3i = a[7]; + x4r = a[8]; + x4i = a[9]; + x5r = a[10]; + x5i = 
a[11]; + x6r = a[12]; + x6i = a[13]; + x7r = a[14]; + x7i = a[15]; + x8r = a[16]; + x8i = a[17]; + x9r = a[18]; + x9i = a[19]; + x10r = a[20]; + x10i = a[21]; + x11r = a[22]; + x11i = a[23]; + x12r = a[24]; + x12i = a[25]; + x13r = a[26]; + x13i = a[27]; + x14r = a[28]; + x14i = a[29]; + x15r = a[30]; + x15i = a[31]; + a[2] = x15r; + a[3] = x15i; + a[4] = x7r; + a[5] = x7i; + a[6] = x11r; + a[7] = x11i; + a[8] = x3r; + a[9] = x3i; + a[10] = x13r; + a[11] = x13i; + a[12] = x5r; + a[13] = x5i; + a[14] = x9r; + a[15] = x9i; + a[16] = x1r; + a[17] = x1i; + a[18] = x14r; + a[19] = x14i; + a[20] = x6r; + a[21] = x6i; + a[22] = x10r; + a[23] = x10i; + a[24] = x2r; + a[25] = x2i; + a[26] = x12r; + a[27] = x12i; + a[28] = x4r; + a[29] = x4i; + a[30] = x8r; + a[31] = x8i; +} + + +void bitrv208(double *a) { + double x1r, x1i, x3r, x3i, x4r, x4i, x6r, x6i; + + x1r = a[2]; + x1i = a[3]; + x3r = a[6]; + x3i = a[7]; + x4r = a[8]; + x4i = a[9]; + x6r = a[12]; + x6i = a[13]; + a[2] = x4r; + a[3] = x4i; + a[6] = x6r; + a[7] = x6i; + a[8] = x1r; + a[9] = x1i; + a[12] = x3r; + a[13] = x3i; +} + + +void bitrv208neg(double *a) { + double x1r, x1i, x2r, x2i, x3r, x3i, x4r, x4i, x5r, x5i, x6r, x6i, x7r, x7i; + + x1r = a[2]; + x1i = a[3]; + x2r = a[4]; + x2i = a[5]; + x3r = a[6]; + x3i = a[7]; + x4r = a[8]; + x4i = a[9]; + x5r = a[10]; + x5i = a[11]; + x6r = a[12]; + x6i = a[13]; + x7r = a[14]; + x7i = a[15]; + a[2] = x7r; + a[3] = x7i; + a[4] = x3r; + a[5] = x3i; + a[6] = x5r; + a[7] = x5i; + a[8] = x1r; + a[9] = x1i; + a[10] = x6r; + a[11] = x6i; + a[12] = x2r; + a[13] = x2i; + a[14] = x4r; + a[15] = x4i; +} + + +void cftf1st(int n, double *a, double *w) { + int j, j0, j1, j2, j3, k, m, mh; + double wn4r, csc1, csc3, wk1r, wk1i, wk3r, wk3i, wd1r, wd1i, wd3r, wd3i; + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i, y0r, y0i, y1r, y1i, y2r, y2i, + y3r, y3i; + + mh = n >> 3; + m = 2 * mh; + j1 = m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[0] + a[j2]; + x0i = a[1] + a[j2 + 1]; + x1r = a[0] - 
a[j2]; + x1i = a[1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + a[j2] = x1r - x3i; + a[j2 + 1] = x1i + x3r; + a[j3] = x1r + x3i; + a[j3 + 1] = x1i - x3r; + wn4r = w[1]; + csc1 = w[2]; + csc3 = w[3]; + wd1r = 1; + wd1i = 0; + wd3r = 1; + wd3i = 0; + k = 0; + for (j = 2; j < mh - 2; j += 4) { + k += 4; + wk1r = csc1 * (wd1r + w[k]); + wk1i = csc1 * (wd1i + w[k + 1]); + wk3r = csc3 * (wd3r + w[k + 2]); + wk3i = csc3 * (wd3i + w[k + 3]); + wd1r = w[k]; + wd1i = w[k + 1]; + wd3r = w[k + 2]; + wd3i = w[k + 3]; + j1 = j + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j] + a[j2]; + x0i = a[j + 1] + a[j2 + 1]; + x1r = a[j] - a[j2]; + x1i = a[j + 1] - a[j2 + 1]; + y0r = a[j + 2] + a[j2 + 2]; + y0i = a[j + 3] + a[j2 + 3]; + y1r = a[j + 2] - a[j2 + 2]; + y1i = a[j + 3] - a[j2 + 3]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + y2r = a[j1 + 2] + a[j3 + 2]; + y2i = a[j1 + 3] + a[j3 + 3]; + y3r = a[j1 + 2] - a[j3 + 2]; + y3i = a[j1 + 3] - a[j3 + 3]; + a[j] = x0r + x2r; + a[j + 1] = x0i + x2i; + a[j + 2] = y0r + y2r; + a[j + 3] = y0i + y2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + a[j1 + 2] = y0r - y2r; + a[j1 + 3] = y0i - y2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wk1r * x0r - wk1i * x0i; + a[j2 + 1] = wk1r * x0i + wk1i * x0r; + x0r = y1r - y3i; + x0i = y1i + y3r; + a[j2 + 2] = wd1r * x0r - wd1i * x0i; + a[j2 + 3] = wd1r * x0i + wd1i * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = wk3r * x0r + wk3i * x0i; + a[j3 + 1] = wk3r * x0i - wk3i * x0r; + x0r = y1r + y3i; + x0i = y1i - y3r; + a[j3 + 2] = wd3r * x0r + wd3i * x0i; + a[j3 + 3] = wd3r * x0i - wd3i * x0r; + j0 = m - j; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] + a[j2]; + x0i = a[j0 + 1] + a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = a[j0 + 1] - a[j2 + 1]; + y0r = a[j0 
- 2] + a[j2 - 2]; + y0i = a[j0 - 1] + a[j2 - 1]; + y1r = a[j0 - 2] - a[j2 - 2]; + y1i = a[j0 - 1] - a[j2 - 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + y2r = a[j1 - 2] + a[j3 - 2]; + y2i = a[j1 - 1] + a[j3 - 1]; + y3r = a[j1 - 2] - a[j3 - 2]; + y3i = a[j1 - 1] - a[j3 - 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i + x2i; + a[j0 - 2] = y0r + y2r; + a[j0 - 1] = y0i + y2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + a[j1 - 2] = y0r - y2r; + a[j1 - 1] = y0i - y2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wk1i * x0r - wk1r * x0i; + a[j2 + 1] = wk1i * x0i + wk1r * x0r; + x0r = y1r - y3i; + x0i = y1i + y3r; + a[j2 - 2] = wd1i * x0r - wd1r * x0i; + a[j2 - 1] = wd1i * x0i + wd1r * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = wk3i * x0r + wk3r * x0i; + a[j3 + 1] = wk3i * x0i - wk3r * x0r; + x0r = y1r + y3i; + x0i = y1i - y3r; + a[j3 - 2] = wd3i * x0r + wd3r * x0i; + a[j3 - 1] = wd3i * x0i - wd3r * x0r; + } + wk1r = csc1 * (wd1r + wn4r); + wk1i = csc1 * (wd1i + wn4r); + wk3r = csc3 * (wd3r - wn4r); + wk3i = csc3 * (wd3i - wn4r); + j0 = mh; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0 - 2] + a[j2 - 2]; + x0i = a[j0 - 1] + a[j2 - 1]; + x1r = a[j0 - 2] - a[j2 - 2]; + x1i = a[j0 - 1] - a[j2 - 1]; + x2r = a[j1 - 2] + a[j3 - 2]; + x2i = a[j1 - 1] + a[j3 - 1]; + x3r = a[j1 - 2] - a[j3 - 2]; + x3i = a[j1 - 1] - a[j3 - 1]; + a[j0 - 2] = x0r + x2r; + a[j0 - 1] = x0i + x2i; + a[j1 - 2] = x0r - x2r; + a[j1 - 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2 - 2] = wk1r * x0r - wk1i * x0i; + a[j2 - 1] = wk1r * x0i + wk1i * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3 - 2] = wk3r * x0r + wk3i * x0i; + a[j3 - 1] = wk3r * x0i - wk3i * x0r; + x0r = a[j0] + a[j2]; + x0i = a[j0 + 1] + a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = a[j0 + 1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j0] = x0r + x2r; + a[j0 
+ 1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wn4r * (x0r - x0i); + a[j2 + 1] = wn4r * (x0i + x0r); + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = -wn4r * (x0r + x0i); + a[j3 + 1] = -wn4r * (x0i - x0r); + x0r = a[j0 + 2] + a[j2 + 2]; + x0i = a[j0 + 3] + a[j2 + 3]; + x1r = a[j0 + 2] - a[j2 + 2]; + x1i = a[j0 + 3] - a[j2 + 3]; + x2r = a[j1 + 2] + a[j3 + 2]; + x2i = a[j1 + 3] + a[j3 + 3]; + x3r = a[j1 + 2] - a[j3 + 2]; + x3i = a[j1 + 3] - a[j3 + 3]; + a[j0 + 2] = x0r + x2r; + a[j0 + 3] = x0i + x2i; + a[j1 + 2] = x0r - x2r; + a[j1 + 3] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2 + 2] = wk1i * x0r - wk1r * x0i; + a[j2 + 3] = wk1i * x0i + wk1r * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3 + 2] = wk3i * x0r + wk3r * x0i; + a[j3 + 3] = wk3i * x0i - wk3r * x0r; +} + + +void cftb1st(int n, double *a, double *w) { + int j, j0, j1, j2, j3, k, m, mh; + double wn4r, csc1, csc3, wk1r, wk1i, wk3r, wk3i, wd1r, wd1i, wd3r, wd3i; + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i, y0r, y0i, y1r, y1i, y2r, y2i, + y3r, y3i; + + mh = n >> 3; + m = 2 * mh; + j1 = m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[0] + a[j2]; + x0i = -a[1] - a[j2 + 1]; + x1r = a[0] - a[j2]; + x1i = -a[1] + a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[0] = x0r + x2r; + a[1] = x0i - x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i + x2i; + a[j2] = x1r + x3i; + a[j2 + 1] = x1i + x3r; + a[j3] = x1r - x3i; + a[j3 + 1] = x1i - x3r; + wn4r = w[1]; + csc1 = w[2]; + csc3 = w[3]; + wd1r = 1; + wd1i = 0; + wd3r = 1; + wd3i = 0; + k = 0; + for (j = 2; j < mh - 2; j += 4) { + k += 4; + wk1r = csc1 * (wd1r + w[k]); + wk1i = csc1 * (wd1i + w[k + 1]); + wk3r = csc3 * (wd3r + w[k + 2]); + wk3i = csc3 * (wd3i + w[k + 3]); + wd1r = w[k]; + wd1i = w[k + 1]; + wd3r = w[k + 2]; + wd3i = w[k + 3]; + j1 = j + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j] + a[j2]; + x0i = -a[j + 1] - 
a[j2 + 1]; + x1r = a[j] - a[j2]; + x1i = -a[j + 1] + a[j2 + 1]; + y0r = a[j + 2] + a[j2 + 2]; + y0i = -a[j + 3] - a[j2 + 3]; + y1r = a[j + 2] - a[j2 + 2]; + y1i = -a[j + 3] + a[j2 + 3]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + y2r = a[j1 + 2] + a[j3 + 2]; + y2i = a[j1 + 3] + a[j3 + 3]; + y3r = a[j1 + 2] - a[j3 + 2]; + y3i = a[j1 + 3] - a[j3 + 3]; + a[j] = x0r + x2r; + a[j + 1] = x0i - x2i; + a[j + 2] = y0r + y2r; + a[j + 3] = y0i - y2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i + x2i; + a[j1 + 2] = y0r - y2r; + a[j1 + 3] = y0i + y2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2] = wk1r * x0r - wk1i * x0i; + a[j2 + 1] = wk1r * x0i + wk1i * x0r; + x0r = y1r + y3i; + x0i = y1i + y3r; + a[j2 + 2] = wd1r * x0r - wd1i * x0i; + a[j2 + 3] = wd1r * x0i + wd1i * x0r; + x0r = x1r - x3i; + x0i = x1i - x3r; + a[j3] = wk3r * x0r + wk3i * x0i; + a[j3 + 1] = wk3r * x0i - wk3i * x0r; + x0r = y1r - y3i; + x0i = y1i - y3r; + a[j3 + 2] = wd3r * x0r + wd3i * x0i; + a[j3 + 3] = wd3r * x0i - wd3i * x0r; + j0 = m - j; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] + a[j2]; + x0i = -a[j0 + 1] - a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = -a[j0 + 1] + a[j2 + 1]; + y0r = a[j0 - 2] + a[j2 - 2]; + y0i = -a[j0 - 1] - a[j2 - 1]; + y1r = a[j0 - 2] - a[j2 - 2]; + y1i = -a[j0 - 1] + a[j2 - 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + y2r = a[j1 - 2] + a[j3 - 2]; + y2i = a[j1 - 1] + a[j3 - 1]; + y3r = a[j1 - 2] - a[j3 - 2]; + y3i = a[j1 - 1] - a[j3 - 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i - x2i; + a[j0 - 2] = y0r + y2r; + a[j0 - 1] = y0i - y2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i + x2i; + a[j1 - 2] = y0r - y2r; + a[j1 - 1] = y0i + y2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2] = wk1i * x0r - wk1r * x0i; + a[j2 + 1] = wk1i * x0i + wk1r * x0r; + x0r = y1r + y3i; + x0i = y1i + y3r; + a[j2 - 2] = wd1i * x0r - wd1r * x0i; + a[j2 - 1] = wd1i * x0i + 
wd1r * x0r; + x0r = x1r - x3i; + x0i = x1i - x3r; + a[j3] = wk3i * x0r + wk3r * x0i; + a[j3 + 1] = wk3i * x0i - wk3r * x0r; + x0r = y1r - y3i; + x0i = y1i - y3r; + a[j3 - 2] = wd3i * x0r + wd3r * x0i; + a[j3 - 1] = wd3i * x0i - wd3r * x0r; + } + wk1r = csc1 * (wd1r + wn4r); + wk1i = csc1 * (wd1i + wn4r); + wk3r = csc3 * (wd3r - wn4r); + wk3i = csc3 * (wd3i - wn4r); + j0 = mh; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0 - 2] + a[j2 - 2]; + x0i = -a[j0 - 1] - a[j2 - 1]; + x1r = a[j0 - 2] - a[j2 - 2]; + x1i = -a[j0 - 1] + a[j2 - 1]; + x2r = a[j1 - 2] + a[j3 - 2]; + x2i = a[j1 - 1] + a[j3 - 1]; + x3r = a[j1 - 2] - a[j3 - 2]; + x3i = a[j1 - 1] - a[j3 - 1]; + a[j0 - 2] = x0r + x2r; + a[j0 - 1] = x0i - x2i; + a[j1 - 2] = x0r - x2r; + a[j1 - 1] = x0i + x2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2 - 2] = wk1r * x0r - wk1i * x0i; + a[j2 - 1] = wk1r * x0i + wk1i * x0r; + x0r = x1r - x3i; + x0i = x1i - x3r; + a[j3 - 2] = wk3r * x0r + wk3i * x0i; + a[j3 - 1] = wk3r * x0i - wk3i * x0r; + x0r = a[j0] + a[j2]; + x0i = -a[j0 + 1] - a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = -a[j0 + 1] + a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i - x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i + x2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2] = wn4r * (x0r - x0i); + a[j2 + 1] = wn4r * (x0i + x0r); + x0r = x1r - x3i; + x0i = x1i - x3r; + a[j3] = -wn4r * (x0r + x0i); + a[j3 + 1] = -wn4r * (x0i - x0r); + x0r = a[j0 + 2] + a[j2 + 2]; + x0i = -a[j0 + 3] - a[j2 + 3]; + x1r = a[j0 + 2] - a[j2 + 2]; + x1i = -a[j0 + 3] + a[j2 + 3]; + x2r = a[j1 + 2] + a[j3 + 2]; + x2i = a[j1 + 3] + a[j3 + 3]; + x3r = a[j1 + 2] - a[j3 + 2]; + x3i = a[j1 + 3] - a[j3 + 3]; + a[j0 + 2] = x0r + x2r; + a[j0 + 3] = x0i - x2i; + a[j1 + 2] = x0r - x2r; + a[j1 + 3] = x0i + x2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2 + 2] = wk1i * x0r - wk1r * x0i; + a[j2 + 3] = wk1i * x0i + wk1r * x0r; + x0r = x1r 
- x3i; + x0i = x1i - x3r; + a[j3 + 2] = wk3i * x0r + wk3r * x0i; + a[j3 + 3] = wk3i * x0i - wk3r * x0r; +} + + +#ifdef USE_CDFT_THREADS +struct cdft_arg_st { + int n0; + int n; + double *a; + int nw; + double *w; +}; +typedef struct cdft_arg_st cdft_arg_t; + + +void cftrec4_th(int n, double *a, int nw, double *w) { + void *cftrec1_th(void *p); + void *cftrec2_th(void *p); + int i, idiv4, m, nthread; + cdft_thread_t th[4]; + cdft_arg_t ag[4]; + + nthread = 2; + idiv4 = 0; + m = n >> 1; + if (n > CDFT_4THREADS_BEGIN_N) { + nthread = 4; + idiv4 = 1; + m >>= 1; + } + for (i = 0; i < nthread; i++) { + ag[i].n0 = n; + ag[i].n = m; + ag[i].a = &a[i * m]; + ag[i].nw = nw; + ag[i].w = w; + if (i != idiv4) { + cdft_thread_create(&th[i], cftrec1_th, &ag[i]); + } else { + cdft_thread_create(&th[i], cftrec2_th, &ag[i]); + } + } + for (i = 0; i < nthread; i++) { + cdft_thread_wait(th[i]); + } +} + + +void *cftrec1_th(void *p) { + int cfttree(int n, int j, int k, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftmdl1(int n, double *a, double *w); + int isplt, j, k, m, n, n0, nw; + double *a, *w; + + n0 = ((cdft_arg_t *)p)->n0; + n = ((cdft_arg_t *)p)->n; + a = ((cdft_arg_t *)p)->a; + nw = ((cdft_arg_t *)p)->nw; + w = ((cdft_arg_t *)p)->w; + m = n0; + while (m > 512) { + m >>= 2; + cftmdl1(m, &a[n - m], &w[nw - (m >> 1)]); + } + cftleaf(m, 1, &a[n - m], nw, w); + k = 0; + for (j = n - m; j > 0; j -= m) { + k++; + isplt = cfttree(m, j, k, a, nw, w); + cftleaf(m, isplt, &a[j - m], nw, w); + } + return (void *)0; +} + + +void *cftrec2_th(void *p) { + int cfttree(int n, int j, int k, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftmdl2(int n, double *a, double *w); + int isplt, j, k, m, n, n0, nw; + double *a, *w; + + n0 = ((cdft_arg_t *)p)->n0; + n = ((cdft_arg_t *)p)->n; + a = ((cdft_arg_t *)p)->a; + nw = ((cdft_arg_t *)p)->nw; + w = ((cdft_arg_t *)p)->w; + k = 1; + m = 
n0; + while (m > 512) { + m >>= 2; + k <<= 2; + cftmdl2(m, &a[n - m], &w[nw - m]); + } + cftleaf(m, 0, &a[n - m], nw, w); + k >>= 1; + for (j = n - m; j > 0; j -= m) { + k++; + isplt = cfttree(m, j, k, a, nw, w); + cftleaf(m, isplt, &a[j - m], nw, w); + } + return (void *)0; +} +#endif /* USE_CDFT_THREADS */ + + +void cftrec4(int n, double *a, int nw, double *w) { + int cfttree(int n, int j, int k, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftmdl1(int n, double *a, double *w); + int isplt, j, k, m; + + m = n; + while (m > 512) { + m >>= 2; + cftmdl1(m, &a[n - m], &w[nw - (m >> 1)]); + } + cftleaf(m, 1, &a[n - m], nw, w); + k = 0; + for (j = n - m; j > 0; j -= m) { + k++; + isplt = cfttree(m, j, k, a, nw, w); + cftleaf(m, isplt, &a[j - m], nw, w); + } +} + + +int cfttree(int n, int j, int k, double *a, int nw, double *w) { + void cftmdl1(int n, double *a, double *w); + void cftmdl2(int n, double *a, double *w); + int i, isplt, m; + + if ((k & 3) != 0) { + isplt = k & 1; + if (isplt != 0) { + cftmdl1(n, &a[j - n], &w[nw - (n >> 1)]); + } else { + cftmdl2(n, &a[j - n], &w[nw - n]); + } + } else { + m = n; + for (i = k; (i & 3) == 0; i >>= 2) { + m <<= 2; + } + isplt = i & 1; + if (isplt != 0) { + while (m > 128) { + cftmdl1(m, &a[j - m], &w[nw - (m >> 1)]); + m >>= 2; + } + } else { + while (m > 128) { + cftmdl2(m, &a[j - m], &w[nw - m]); + m >>= 2; + } + } + } + return isplt; +} + + +void cftleaf(int n, int isplt, double *a, int nw, double *w) { + void cftmdl1(int n, double *a, double *w); + void cftmdl2(int n, double *a, double *w); + void cftf161(double *a, double *w); + void cftf162(double *a, double *w); + void cftf081(double *a, double *w); + void cftf082(double *a, double *w); + + if (n == 512) { + cftmdl1(128, a, &w[nw - 64]); + cftf161(a, &w[nw - 8]); + cftf162(&a[32], &w[nw - 32]); + cftf161(&a[64], &w[nw - 8]); + cftf161(&a[96], &w[nw - 8]); + cftmdl2(128, &a[128], &w[nw - 128]); + cftf161(&a[128], 
&w[nw - 8]); + cftf162(&a[160], &w[nw - 32]); + cftf161(&a[192], &w[nw - 8]); + cftf162(&a[224], &w[nw - 32]); + cftmdl1(128, &a[256], &w[nw - 64]); + cftf161(&a[256], &w[nw - 8]); + cftf162(&a[288], &w[nw - 32]); + cftf161(&a[320], &w[nw - 8]); + cftf161(&a[352], &w[nw - 8]); + if (isplt != 0) { + cftmdl1(128, &a[384], &w[nw - 64]); + cftf161(&a[480], &w[nw - 8]); + } else { + cftmdl2(128, &a[384], &w[nw - 128]); + cftf162(&a[480], &w[nw - 32]); + } + cftf161(&a[384], &w[nw - 8]); + cftf162(&a[416], &w[nw - 32]); + cftf161(&a[448], &w[nw - 8]); + } else { + cftmdl1(64, a, &w[nw - 32]); + cftf081(a, &w[nw - 8]); + cftf082(&a[16], &w[nw - 8]); + cftf081(&a[32], &w[nw - 8]); + cftf081(&a[48], &w[nw - 8]); + cftmdl2(64, &a[64], &w[nw - 64]); + cftf081(&a[64], &w[nw - 8]); + cftf082(&a[80], &w[nw - 8]); + cftf081(&a[96], &w[nw - 8]); + cftf082(&a[112], &w[nw - 8]); + cftmdl1(64, &a[128], &w[nw - 32]); + cftf081(&a[128], &w[nw - 8]); + cftf082(&a[144], &w[nw - 8]); + cftf081(&a[160], &w[nw - 8]); + cftf081(&a[176], &w[nw - 8]); + if (isplt != 0) { + cftmdl1(64, &a[192], &w[nw - 32]); + cftf081(&a[240], &w[nw - 8]); + } else { + cftmdl2(64, &a[192], &w[nw - 64]); + cftf082(&a[240], &w[nw - 8]); + } + cftf081(&a[192], &w[nw - 8]); + cftf082(&a[208], &w[nw - 8]); + cftf081(&a[224], &w[nw - 8]); + } +} + + +void cftmdl1(int n, double *a, double *w) { + int j, j0, j1, j2, j3, k, m, mh; + double wn4r, wk1r, wk1i, wk3r, wk3i; + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i; + + mh = n >> 3; + m = 2 * mh; + j1 = m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[0] + a[j2]; + x0i = a[1] + a[j2 + 1]; + x1r = a[0] - a[j2]; + x1i = a[1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + a[j2] = x1r - x3i; + a[j2 + 1] = x1i + x3r; + a[j3] = x1r + x3i; + a[j3 + 1] = x1i - x3r; + wn4r = w[1]; + k = 0; + for (j = 2; j < mh; j += 2) { + k 
+= 4; + wk1r = w[k]; + wk1i = w[k + 1]; + wk3r = w[k + 2]; + wk3i = w[k + 3]; + j1 = j + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j] + a[j2]; + x0i = a[j + 1] + a[j2 + 1]; + x1r = a[j] - a[j2]; + x1i = a[j + 1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j] = x0r + x2r; + a[j + 1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wk1r * x0r - wk1i * x0i; + a[j2 + 1] = wk1r * x0i + wk1i * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = wk3r * x0r + wk3i * x0i; + a[j3 + 1] = wk3r * x0i - wk3i * x0r; + j0 = m - j; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] + a[j2]; + x0i = a[j0 + 1] + a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = a[j0 + 1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wk1i * x0r - wk1r * x0i; + a[j2 + 1] = wk1i * x0i + wk1r * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = wk3i * x0r + wk3r * x0i; + a[j3 + 1] = wk3i * x0i - wk3r * x0r; + } + j0 = mh; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] + a[j2]; + x0i = a[j0 + 1] + a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = a[j0 + 1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wn4r * (x0r - x0i); + a[j2 + 1] = wn4r * (x0i + x0r); + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = -wn4r * (x0r + x0i); + a[j3 + 1] = -wn4r * (x0i - x0r); +} + + +void cftmdl2(int n, double *a, double *w) { + int j, j0, j1, j2, j3, k, kr, m, mh; + double wn4r, wk1r, wk1i, wk3r, wk3i, wd1r, wd1i, wd3r, wd3i; + double x0r, x0i, x1r, x1i, 
x2r, x2i, x3r, x3i, y0r, y0i, y2r, y2i; + + mh = n >> 3; + m = 2 * mh; + wn4r = w[1]; + j1 = m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[0] - a[j2 + 1]; + x0i = a[1] + a[j2]; + x1r = a[0] + a[j2 + 1]; + x1i = a[1] - a[j2]; + x2r = a[j1] - a[j3 + 1]; + x2i = a[j1 + 1] + a[j3]; + x3r = a[j1] + a[j3 + 1]; + x3i = a[j1 + 1] - a[j3]; + y0r = wn4r * (x2r - x2i); + y0i = wn4r * (x2i + x2r); + a[0] = x0r + y0r; + a[1] = x0i + y0i; + a[j1] = x0r - y0r; + a[j1 + 1] = x0i - y0i; + y0r = wn4r * (x3r - x3i); + y0i = wn4r * (x3i + x3r); + a[j2] = x1r - y0i; + a[j2 + 1] = x1i + y0r; + a[j3] = x1r + y0i; + a[j3 + 1] = x1i - y0r; + k = 0; + kr = 2 * m; + for (j = 2; j < mh; j += 2) { + k += 4; + wk1r = w[k]; + wk1i = w[k + 1]; + wk3r = w[k + 2]; + wk3i = w[k + 3]; + kr -= 4; + wd1i = w[kr]; + wd1r = w[kr + 1]; + wd3i = w[kr + 2]; + wd3r = w[kr + 3]; + j1 = j + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j] - a[j2 + 1]; + x0i = a[j + 1] + a[j2]; + x1r = a[j] + a[j2 + 1]; + x1i = a[j + 1] - a[j2]; + x2r = a[j1] - a[j3 + 1]; + x2i = a[j1 + 1] + a[j3]; + x3r = a[j1] + a[j3 + 1]; + x3i = a[j1 + 1] - a[j3]; + y0r = wk1r * x0r - wk1i * x0i; + y0i = wk1r * x0i + wk1i * x0r; + y2r = wd1r * x2r - wd1i * x2i; + y2i = wd1r * x2i + wd1i * x2r; + a[j] = y0r + y2r; + a[j + 1] = y0i + y2i; + a[j1] = y0r - y2r; + a[j1 + 1] = y0i - y2i; + y0r = wk3r * x1r + wk3i * x1i; + y0i = wk3r * x1i - wk3i * x1r; + y2r = wd3r * x3r + wd3i * x3i; + y2i = wd3r * x3i - wd3i * x3r; + a[j2] = y0r + y2r; + a[j2 + 1] = y0i + y2i; + a[j3] = y0r - y2r; + a[j3 + 1] = y0i - y2i; + j0 = m - j; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] - a[j2 + 1]; + x0i = a[j0 + 1] + a[j2]; + x1r = a[j0] + a[j2 + 1]; + x1i = a[j0 + 1] - a[j2]; + x2r = a[j1] - a[j3 + 1]; + x2i = a[j1 + 1] + a[j3]; + x3r = a[j1] + a[j3 + 1]; + x3i = a[j1 + 1] - a[j3]; + y0r = wd1i * x0r - wd1r * x0i; + y0i = wd1i * x0i + wd1r * x0r; + y2r = wk1i * x2r - wk1r * x2i; + y2i = wk1i * x2i + wk1r * x2r; + a[j0] = y0r + y2r; + a[j0 + 1] = y0i + y2i; 
+ a[j1] = y0r - y2r; + a[j1 + 1] = y0i - y2i; + y0r = wd3i * x1r + wd3r * x1i; + y0i = wd3i * x1i - wd3r * x1r; + y2r = wk3i * x3r + wk3r * x3i; + y2i = wk3i * x3i - wk3r * x3r; + a[j2] = y0r + y2r; + a[j2 + 1] = y0i + y2i; + a[j3] = y0r - y2r; + a[j3 + 1] = y0i - y2i; + } + wk1r = w[m]; + wk1i = w[m + 1]; + j0 = mh; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] - a[j2 + 1]; + x0i = a[j0 + 1] + a[j2]; + x1r = a[j0] + a[j2 + 1]; + x1i = a[j0 + 1] - a[j2]; + x2r = a[j1] - a[j3 + 1]; + x2i = a[j1 + 1] + a[j3]; + x3r = a[j1] + a[j3 + 1]; + x3i = a[j1 + 1] - a[j3]; + y0r = wk1r * x0r - wk1i * x0i; + y0i = wk1r * x0i + wk1i * x0r; + y2r = wk1i * x2r - wk1r * x2i; + y2i = wk1i * x2i + wk1r * x2r; + a[j0] = y0r + y2r; + a[j0 + 1] = y0i + y2i; + a[j1] = y0r - y2r; + a[j1 + 1] = y0i - y2i; + y0r = wk1i * x1r - wk1r * x1i; + y0i = wk1i * x1i + wk1r * x1r; + y2r = wk1r * x3r - wk1i * x3i; + y2i = wk1r * x3i + wk1i * x3r; + a[j2] = y0r - y2r; + a[j2 + 1] = y0i - y2i; + a[j3] = y0r + y2r; + a[j3 + 1] = y0i + y2i; +} + + +void cftfx41(int n, double *a, int nw, double *w) { + void cftf161(double *a, double *w); + void cftf162(double *a, double *w); + void cftf081(double *a, double *w); + void cftf082(double *a, double *w); + + if (n == 128) { + cftf161(a, &w[nw - 8]); + cftf162(&a[32], &w[nw - 32]); + cftf161(&a[64], &w[nw - 8]); + cftf161(&a[96], &w[nw - 8]); + } else { + cftf081(a, &w[nw - 8]); + cftf082(&a[16], &w[nw - 8]); + cftf081(&a[32], &w[nw - 8]); + cftf081(&a[48], &w[nw - 8]); + } +} + + +void cftf161(double *a, double *w) { + double wn4r, wk1r, wk1i, x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i, y0r, y0i, + y1r, y1i, y2r, y2i, y3r, y3i, y4r, y4i, y5r, y5i, y6r, y6i, y7r, y7i, + y8r, y8i, y9r, y9i, y10r, y10i, y11r, y11i, y12r, y12i, y13r, y13i, + y14r, y14i, y15r, y15i; + + wn4r = w[1]; + wk1r = w[2]; + wk1i = w[3]; + x0r = a[0] + a[16]; + x0i = a[1] + a[17]; + x1r = a[0] - a[16]; + x1i = a[1] - a[17]; + x2r = a[8] + a[24]; + x2i = a[9] + a[25]; + x3r = a[8] - 
a[24]; + x3i = a[9] - a[25]; + y0r = x0r + x2r; + y0i = x0i + x2i; + y4r = x0r - x2r; + y4i = x0i - x2i; + y8r = x1r - x3i; + y8i = x1i + x3r; + y12r = x1r + x3i; + y12i = x1i - x3r; + x0r = a[2] + a[18]; + x0i = a[3] + a[19]; + x1r = a[2] - a[18]; + x1i = a[3] - a[19]; + x2r = a[10] + a[26]; + x2i = a[11] + a[27]; + x3r = a[10] - a[26]; + x3i = a[11] - a[27]; + y1r = x0r + x2r; + y1i = x0i + x2i; + y5r = x0r - x2r; + y5i = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + y9r = wk1r * x0r - wk1i * x0i; + y9i = wk1r * x0i + wk1i * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + y13r = wk1i * x0r - wk1r * x0i; + y13i = wk1i * x0i + wk1r * x0r; + x0r = a[4] + a[20]; + x0i = a[5] + a[21]; + x1r = a[4] - a[20]; + x1i = a[5] - a[21]; + x2r = a[12] + a[28]; + x2i = a[13] + a[29]; + x3r = a[12] - a[28]; + x3i = a[13] - a[29]; + y2r = x0r + x2r; + y2i = x0i + x2i; + y6r = x0r - x2r; + y6i = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + y10r = wn4r * (x0r - x0i); + y10i = wn4r * (x0i + x0r); + x0r = x1r + x3i; + x0i = x1i - x3r; + y14r = wn4r * (x0r + x0i); + y14i = wn4r * (x0i - x0r); + x0r = a[6] + a[22]; + x0i = a[7] + a[23]; + x1r = a[6] - a[22]; + x1i = a[7] - a[23]; + x2r = a[14] + a[30]; + x2i = a[15] + a[31]; + x3r = a[14] - a[30]; + x3i = a[15] - a[31]; + y3r = x0r + x2r; + y3i = x0i + x2i; + y7r = x0r - x2r; + y7i = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + y11r = wk1i * x0r - wk1r * x0i; + y11i = wk1i * x0i + wk1r * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + y15r = wk1r * x0r - wk1i * x0i; + y15i = wk1r * x0i + wk1i * x0r; + x0r = y12r - y14r; + x0i = y12i - y14i; + x1r = y12r + y14r; + x1i = y12i + y14i; + x2r = y13r - y15r; + x2i = y13i - y15i; + x3r = y13r + y15r; + x3i = y13i + y15i; + a[24] = x0r + x2r; + a[25] = x0i + x2i; + a[26] = x0r - x2r; + a[27] = x0i - x2i; + a[28] = x1r - x3i; + a[29] = x1i + x3r; + a[30] = x1r + x3i; + a[31] = x1i - x3r; + x0r = y8r + y10r; + x0i = y8i + y10i; + x1r = y8r - y10r; + x1i = y8i - y10i; + x2r = y9r + 
y11r; + x2i = y9i + y11i; + x3r = y9r - y11r; + x3i = y9i - y11i; + a[16] = x0r + x2r; + a[17] = x0i + x2i; + a[18] = x0r - x2r; + a[19] = x0i - x2i; + a[20] = x1r - x3i; + a[21] = x1i + x3r; + a[22] = x1r + x3i; + a[23] = x1i - x3r; + x0r = y5r - y7i; + x0i = y5i + y7r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + x0r = y5r + y7i; + x0i = y5i - y7r; + x3r = wn4r * (x0r - x0i); + x3i = wn4r * (x0i + x0r); + x0r = y4r - y6i; + x0i = y4i + y6r; + x1r = y4r + y6i; + x1i = y4i - y6r; + a[8] = x0r + x2r; + a[9] = x0i + x2i; + a[10] = x0r - x2r; + a[11] = x0i - x2i; + a[12] = x1r - x3i; + a[13] = x1i + x3r; + a[14] = x1r + x3i; + a[15] = x1i - x3r; + x0r = y0r + y2r; + x0i = y0i + y2i; + x1r = y0r - y2r; + x1i = y0i - y2i; + x2r = y1r + y3r; + x2i = y1i + y3i; + x3r = y1r - y3r; + x3i = y1i - y3i; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[2] = x0r - x2r; + a[3] = x0i - x2i; + a[4] = x1r - x3i; + a[5] = x1i + x3r; + a[6] = x1r + x3i; + a[7] = x1i - x3r; +} + + +void cftf162(double *a, double *w) { + double wn4r, wk1r, wk1i, wk2r, wk2i, wk3r, wk3i, x0r, x0i, x1r, x1i, x2r, + x2i, y0r, y0i, y1r, y1i, y2r, y2i, y3r, y3i, y4r, y4i, y5r, y5i, y6r, + y6i, y7r, y7i, y8r, y8i, y9r, y9i, y10r, y10i, y11r, y11i, y12r, y12i, + y13r, y13i, y14r, y14i, y15r, y15i; + + wn4r = w[1]; + wk1r = w[4]; + wk1i = w[5]; + wk3r = w[6]; + wk3i = -w[7]; + wk2r = w[8]; + wk2i = w[9]; + x1r = a[0] - a[17]; + x1i = a[1] + a[16]; + x0r = a[8] - a[25]; + x0i = a[9] + a[24]; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + y0r = x1r + x2r; + y0i = x1i + x2i; + y4r = x1r - x2r; + y4i = x1i - x2i; + x1r = a[0] + a[17]; + x1i = a[1] - a[16]; + x0r = a[8] + a[25]; + x0i = a[9] - a[24]; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + y8r = x1r - x2i; + y8i = x1i + x2r; + y12r = x1r + x2i; + y12i = x1i - x2r; + x0r = a[2] - a[19]; + x0i = a[3] + a[18]; + x1r = wk1r * x0r - wk1i * x0i; + x1i = wk1r * x0i + wk1i * x0r; + x0r = a[10] - a[27]; + x0i = a[11] + a[26]; + x2r = wk3i 
* x0r - wk3r * x0i; + x2i = wk3i * x0i + wk3r * x0r; + y1r = x1r + x2r; + y1i = x1i + x2i; + y5r = x1r - x2r; + y5i = x1i - x2i; + x0r = a[2] + a[19]; + x0i = a[3] - a[18]; + x1r = wk3r * x0r - wk3i * x0i; + x1i = wk3r * x0i + wk3i * x0r; + x0r = a[10] + a[27]; + x0i = a[11] - a[26]; + x2r = wk1r * x0r + wk1i * x0i; + x2i = wk1r * x0i - wk1i * x0r; + y9r = x1r - x2r; + y9i = x1i - x2i; + y13r = x1r + x2r; + y13i = x1i + x2i; + x0r = a[4] - a[21]; + x0i = a[5] + a[20]; + x1r = wk2r * x0r - wk2i * x0i; + x1i = wk2r * x0i + wk2i * x0r; + x0r = a[12] - a[29]; + x0i = a[13] + a[28]; + x2r = wk2i * x0r - wk2r * x0i; + x2i = wk2i * x0i + wk2r * x0r; + y2r = x1r + x2r; + y2i = x1i + x2i; + y6r = x1r - x2r; + y6i = x1i - x2i; + x0r = a[4] + a[21]; + x0i = a[5] - a[20]; + x1r = wk2i * x0r - wk2r * x0i; + x1i = wk2i * x0i + wk2r * x0r; + x0r = a[12] + a[29]; + x0i = a[13] - a[28]; + x2r = wk2r * x0r - wk2i * x0i; + x2i = wk2r * x0i + wk2i * x0r; + y10r = x1r - x2r; + y10i = x1i - x2i; + y14r = x1r + x2r; + y14i = x1i + x2i; + x0r = a[6] - a[23]; + x0i = a[7] + a[22]; + x1r = wk3r * x0r - wk3i * x0i; + x1i = wk3r * x0i + wk3i * x0r; + x0r = a[14] - a[31]; + x0i = a[15] + a[30]; + x2r = wk1i * x0r - wk1r * x0i; + x2i = wk1i * x0i + wk1r * x0r; + y3r = x1r + x2r; + y3i = x1i + x2i; + y7r = x1r - x2r; + y7i = x1i - x2i; + x0r = a[6] + a[23]; + x0i = a[7] - a[22]; + x1r = wk1i * x0r + wk1r * x0i; + x1i = wk1i * x0i - wk1r * x0r; + x0r = a[14] + a[31]; + x0i = a[15] - a[30]; + x2r = wk3i * x0r - wk3r * x0i; + x2i = wk3i * x0i + wk3r * x0r; + y11r = x1r + x2r; + y11i = x1i + x2i; + y15r = x1r - x2r; + y15i = x1i - x2i; + x1r = y0r + y2r; + x1i = y0i + y2i; + x2r = y1r + y3r; + x2i = y1i + y3i; + a[0] = x1r + x2r; + a[1] = x1i + x2i; + a[2] = x1r - x2r; + a[3] = x1i - x2i; + x1r = y0r - y2r; + x1i = y0i - y2i; + x2r = y1r - y3r; + x2i = y1i - y3i; + a[4] = x1r - x2i; + a[5] = x1i + x2r; + a[6] = x1r + x2i; + a[7] = x1i - x2r; + x1r = y4r - y6i; + x1i = y4i + y6r; + x0r = y5r - y7i; + 
x0i = y5i + y7r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + a[8] = x1r + x2r; + a[9] = x1i + x2i; + a[10] = x1r - x2r; + a[11] = x1i - x2i; + x1r = y4r + y6i; + x1i = y4i - y6r; + x0r = y5r + y7i; + x0i = y5i - y7r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + a[12] = x1r - x2i; + a[13] = x1i + x2r; + a[14] = x1r + x2i; + a[15] = x1i - x2r; + x1r = y8r + y10r; + x1i = y8i + y10i; + x2r = y9r - y11r; + x2i = y9i - y11i; + a[16] = x1r + x2r; + a[17] = x1i + x2i; + a[18] = x1r - x2r; + a[19] = x1i - x2i; + x1r = y8r - y10r; + x1i = y8i - y10i; + x2r = y9r + y11r; + x2i = y9i + y11i; + a[20] = x1r - x2i; + a[21] = x1i + x2r; + a[22] = x1r + x2i; + a[23] = x1i - x2r; + x1r = y12r - y14i; + x1i = y12i + y14r; + x0r = y13r + y15i; + x0i = y13i - y15r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + a[24] = x1r + x2r; + a[25] = x1i + x2i; + a[26] = x1r - x2r; + a[27] = x1i - x2i; + x1r = y12r + y14i; + x1i = y12i - y14r; + x0r = y13r - y15i; + x0i = y13i + y15r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + a[28] = x1r - x2i; + a[29] = x1i + x2r; + a[30] = x1r + x2i; + a[31] = x1i - x2r; +} + + +void cftf081(double *a, double *w) { + double wn4r, x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i, y0r, y0i, y1r, y1i, + y2r, y2i, y3r, y3i, y4r, y4i, y5r, y5i, y6r, y6i, y7r, y7i; + + wn4r = w[1]; + x0r = a[0] + a[8]; + x0i = a[1] + a[9]; + x1r = a[0] - a[8]; + x1i = a[1] - a[9]; + x2r = a[4] + a[12]; + x2i = a[5] + a[13]; + x3r = a[4] - a[12]; + x3i = a[5] - a[13]; + y0r = x0r + x2r; + y0i = x0i + x2i; + y2r = x0r - x2r; + y2i = x0i - x2i; + y1r = x1r - x3i; + y1i = x1i + x3r; + y3r = x1r + x3i; + y3i = x1i - x3r; + x0r = a[2] + a[10]; + x0i = a[3] + a[11]; + x1r = a[2] - a[10]; + x1i = a[3] - a[11]; + x2r = a[6] + a[14]; + x2i = a[7] + a[15]; + x3r = a[6] - a[14]; + x3i = a[7] - a[15]; + y4r = x0r + x2r; + y4i = x0i + x2i; + y6r = x0r - x2r; + y6i = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + x2r = x1r + x3i; + x2i = x1i - x3r; + 
y5r = wn4r * (x0r - x0i); + y5i = wn4r * (x0r + x0i); + y7r = wn4r * (x2r - x2i); + y7i = wn4r * (x2r + x2i); + a[8] = y1r + y5r; + a[9] = y1i + y5i; + a[10] = y1r - y5r; + a[11] = y1i - y5i; + a[12] = y3r - y7i; + a[13] = y3i + y7r; + a[14] = y3r + y7i; + a[15] = y3i - y7r; + a[0] = y0r + y4r; + a[1] = y0i + y4i; + a[2] = y0r - y4r; + a[3] = y0i - y4i; + a[4] = y2r - y6i; + a[5] = y2i + y6r; + a[6] = y2r + y6i; + a[7] = y2i - y6r; +} + + +void cftf082(double *a, double *w) { + double wn4r, wk1r, wk1i, x0r, x0i, x1r, x1i, y0r, y0i, y1r, y1i, y2r, y2i, + y3r, y3i, y4r, y4i, y5r, y5i, y6r, y6i, y7r, y7i; + + wn4r = w[1]; + wk1r = w[2]; + wk1i = w[3]; + y0r = a[0] - a[9]; + y0i = a[1] + a[8]; + y1r = a[0] + a[9]; + y1i = a[1] - a[8]; + x0r = a[4] - a[13]; + x0i = a[5] + a[12]; + y2r = wn4r * (x0r - x0i); + y2i = wn4r * (x0i + x0r); + x0r = a[4] + a[13]; + x0i = a[5] - a[12]; + y3r = wn4r * (x0r - x0i); + y3i = wn4r * (x0i + x0r); + x0r = a[2] - a[11]; + x0i = a[3] + a[10]; + y4r = wk1r * x0r - wk1i * x0i; + y4i = wk1r * x0i + wk1i * x0r; + x0r = a[2] + a[11]; + x0i = a[3] - a[10]; + y5r = wk1i * x0r - wk1r * x0i; + y5i = wk1i * x0i + wk1r * x0r; + x0r = a[6] - a[15]; + x0i = a[7] + a[14]; + y6r = wk1i * x0r - wk1r * x0i; + y6i = wk1i * x0i + wk1r * x0r; + x0r = a[6] + a[15]; + x0i = a[7] - a[14]; + y7r = wk1r * x0r - wk1i * x0i; + y7i = wk1r * x0i + wk1i * x0r; + x0r = y0r + y2r; + x0i = y0i + y2i; + x1r = y4r + y6r; + x1i = y4i + y6i; + a[0] = x0r + x1r; + a[1] = x0i + x1i; + a[2] = x0r - x1r; + a[3] = x0i - x1i; + x0r = y0r - y2r; + x0i = y0i - y2i; + x1r = y4r - y6r; + x1i = y4i - y6i; + a[4] = x0r - x1i; + a[5] = x0i + x1r; + a[6] = x0r + x1i; + a[7] = x0i - x1r; + x0r = y1r - y3i; + x0i = y1i + y3r; + x1r = y5r - y7r; + x1i = y5i - y7i; + a[8] = x0r + x1r; + a[9] = x0i + x1i; + a[10] = x0r - x1r; + a[11] = x0i - x1i; + x0r = y1r + y3i; + x0i = y1i - y3r; + x1r = y5r + y7r; + x1i = y5i + y7i; + a[12] = x0r - x1i; + a[13] = x0i + x1r; + a[14] = x0r + x1i; + a[15] = 
x0i - x1r; +} + + +void cftf040(double *a) { + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i; + + x0r = a[0] + a[4]; + x0i = a[1] + a[5]; + x1r = a[0] - a[4]; + x1i = a[1] - a[5]; + x2r = a[2] + a[6]; + x2i = a[3] + a[7]; + x3r = a[2] - a[6]; + x3i = a[3] - a[7]; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[2] = x1r - x3i; + a[3] = x1i + x3r; + a[4] = x0r - x2r; + a[5] = x0i - x2i; + a[6] = x1r + x3i; + a[7] = x1i - x3r; +} + + +void cftb040(double *a) { + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i; + + x0r = a[0] + a[4]; + x0i = a[1] + a[5]; + x1r = a[0] - a[4]; + x1i = a[1] - a[5]; + x2r = a[2] + a[6]; + x2i = a[3] + a[7]; + x3r = a[2] - a[6]; + x3i = a[3] - a[7]; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[2] = x1r + x3i; + a[3] = x1i - x3r; + a[4] = x0r - x2r; + a[5] = x0i - x2i; + a[6] = x1r - x3i; + a[7] = x1i + x3r; +} + + +void cftx020(double *a) { + double x0r, x0i; + + x0r = a[0] - a[2]; + x0i = a[1] - a[3]; + a[0] += a[2]; + a[1] += a[3]; + a[2] = x0r; + a[3] = x0i; +} + + +void rftfsub(int n, double *a, int nc, double *c) { + int j, k, kk, ks, m; + double wkr, wki, xr, xi, yr, yi; + + m = n >> 1; + ks = 2 * nc / m; + kk = 0; + for (j = 2; j < m; j += 2) { + k = n - j; + kk += ks; + wkr = 0.5 - c[nc - kk]; + wki = c[kk]; + xr = a[j] - a[k]; + xi = a[j + 1] + a[k + 1]; + yr = wkr * xr - wki * xi; + yi = wkr * xi + wki * xr; + a[j] -= yr; + a[j + 1] -= yi; + a[k] += yr; + a[k + 1] -= yi; + } +} + + +void rftbsub(int n, double *a, int nc, double *c) { + int j, k, kk, ks, m; + double wkr, wki, xr, xi, yr, yi; + + m = n >> 1; + ks = 2 * nc / m; + kk = 0; + for (j = 2; j < m; j += 2) { + k = n - j; + kk += ks; + wkr = 0.5 - c[nc - kk]; + wki = c[kk]; + xr = a[j] - a[k]; + xi = a[j + 1] + a[k + 1]; + yr = wkr * xr + wki * xi; + yi = wkr * xi - wki * xr; + a[j] -= yr; + a[j + 1] -= yi; + a[k] += yr; + a[k + 1] -= yi; + } +} + + +void dctsub(int n, double *a, int nc, double *c) { + int j, k, kk, ks, m; + double wkr, wki, xr; + + m = n >> 1; + ks = nc / n; + kk 
= 0; + for (j = 1; j < m; j++) { + k = n - j; + kk += ks; + wkr = c[kk] - c[nc - kk]; + wki = c[kk] + c[nc - kk]; + xr = wki * a[j] - wkr * a[k]; + a[j] = wkr * a[j] + wki * a[k]; + a[k] = xr; + } + a[m] *= c[0]; +} + + +void dstsub(int n, double *a, int nc, double *c) { + int j, k, kk, ks, m; + double wkr, wki, xr; + + m = n >> 1; + ks = nc / n; + kk = 0; + for (j = 1; j < m; j++) { + k = n - j; + kk += ks; + wkr = c[kk] - c[nc - kk]; + wki = c[kk] + c[nc - kk]; + xr = wki * a[k] - wkr * a[j]; + a[k] = wkr * a[k] + wki * a[j]; + a[j] = xr; + } + a[m] *= c[0]; +} diff --git a/speechx/speechx/frontend/audio/frontend_itf.h b/runtime/engine/common/frontend/frontend_itf.h similarity index 88% rename from speechx/speechx/frontend/audio/frontend_itf.h rename to runtime/engine/common/frontend/frontend_itf.h index 7913cc7c..57186ec4 100644 --- a/speechx/speechx/frontend/audio/frontend_itf.h +++ b/runtime/engine/common/frontend/frontend_itf.h @@ -15,20 +15,20 @@ #pragma once #include "base/basic_types.h" -#include "kaldi/matrix/kaldi-vector.h" +#include "matrix/kaldi-vector.h" namespace ppspeech { class FrontendInterface { public: // Feed inputs: features(2D saved in 1D) or waveforms(1D). - virtual void Accept(const kaldi::VectorBase& inputs) = 0; + virtual void Accept(const std::vector& inputs) = 0; // Fetch processed data: features or waveforms. // For features(2D saved in 1D), the Matrix is squashed into Vector, // the length of output = feature_row * feature_dim. // For waveforms(1D), samples saved in vector. - virtual bool Read(kaldi::Vector* outputs) = 0; + virtual bool Read(std::vector* outputs) = 0; // Dim is the feature dim. For waveforms(1D), Dim is zero; else is specific, // e.g 80 for fbank. 
diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.cc b/runtime/engine/common/frontend/linear_spectrogram.cc similarity index 100% rename from speechx/speechx/frontend/audio/linear_spectrogram.cc rename to runtime/engine/common/frontend/linear_spectrogram.cc diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.h b/runtime/engine/common/frontend/linear_spectrogram.h similarity index 100% rename from speechx/speechx/frontend/audio/linear_spectrogram.h rename to runtime/engine/common/frontend/linear_spectrogram.h diff --git a/runtime/engine/common/frontend/mel-computations.cc b/runtime/engine/common/frontend/mel-computations.cc new file mode 100644 index 00000000..3998af22 --- /dev/null +++ b/runtime/engine/common/frontend/mel-computations.cc @@ -0,0 +1,277 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is copied/modified from kaldi/src/feat/mel-computations.cc + +#include "frontend/mel-computations.h" + +#include +#include + +#include "frontend/feature-window.h" + +namespace knf { + +std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts) { + os << opts.ToString(); + return os; +} + +float MelBanks::VtlnWarpFreq( + float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN. 
+ float vtln_high_cutoff, + float low_freq, // upper+lower frequency cutoffs in mel computation + float high_freq, + float vtln_warp_factor, + float freq) { + /// This computes a VTLN warping function that is not the same as HTK's one, + /// but has similar inputs (this function has the advantage of never + /// producing + /// empty bins). + + /// This function computes a warp function F(freq), defined between low_freq + /// and high_freq inclusive, with the following properties: + /// F(low_freq) == low_freq + /// F(high_freq) == high_freq + /// The function is continuous and piecewise linear with two inflection + /// points. + /// The lower inflection point (measured in terms of the unwarped + /// frequency) is at frequency l, determined as described below. + /// The higher inflection point is at a frequency h, determined as + /// described below. + /// If l <= f <= h, then F(f) = f/vtln_warp_factor. + /// If the higher inflection point (measured in terms of the unwarped + /// frequency) is at h, then max(h, F(h)) == vtln_high_cutoff. + /// Since (by the last point) F(h) == h/vtln_warp_factor, then + /// max(h, h/vtln_warp_factor) == vtln_high_cutoff, so + /// h = vtln_high_cutoff / max(1, 1/vtln_warp_factor). + /// = vtln_high_cutoff * min(1, vtln_warp_factor). + /// If the lower inflection point (measured in terms of the unwarped + /// frequency) is at l, then min(l, F(l)) == vtln_low_cutoff + /// This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor) + /// = vtln_low_cutoff * max(1, vtln_warp_factor) + + if (freq < low_freq || freq > high_freq) + return freq; // in case this gets called + // for out-of-range frequencies, just return the freq. 
+ + CHECK_GT(vtln_low_cutoff, low_freq); + CHECK_LT(vtln_high_cutoff, high_freq); + + float one = 1.0f; + float l = vtln_low_cutoff * std::max(one, vtln_warp_factor); + float h = vtln_high_cutoff * std::min(one, vtln_warp_factor); + float scale = 1.0f / vtln_warp_factor; + float Fl = scale * l; // F(l); + float Fh = scale * h; // F(h); + CHECK(l > low_freq && h < high_freq); + // slope of left part of the 3-piece linear function + float scale_left = (Fl - low_freq) / (l - low_freq); + // [slope of center part is just "scale"] + + // slope of right part of the 3-piece linear function + float scale_right = (high_freq - Fh) / (high_freq - h); + + if (freq < l) { + return low_freq + scale_left * (freq - low_freq); + } else if (freq < h) { + return scale * freq; + } else { // freq >= h + return high_freq + scale_right * (freq - high_freq); + } +} + +float MelBanks::VtlnWarpMelFreq( + float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN. + float vtln_high_cutoff, + float low_freq, // upper+lower frequency cutoffs in mel computation + float high_freq, + float vtln_warp_factor, + float mel_freq) { + return MelScale(VtlnWarpFreq(vtln_low_cutoff, + vtln_high_cutoff, + low_freq, + high_freq, + vtln_warp_factor, + InverseMelScale(mel_freq))); +} + +MelBanks::MelBanks(const MelBanksOptions &opts, + const FrameExtractionOptions &frame_opts, + float vtln_warp_factor) + : htk_mode_(opts.htk_mode) { + int32_t num_bins = opts.num_bins; + if (num_bins < 3) LOG(FATAL) << "Must have at least 3 mel bins"; + + float sample_freq = frame_opts.samp_freq; + int32_t window_length_padded = frame_opts.PaddedWindowSize(); + CHECK_EQ(window_length_padded % 2, 0); + + int32_t num_fft_bins = window_length_padded / 2; + float nyquist = 0.5f * sample_freq; + + float low_freq = opts.low_freq, high_freq; + if (opts.high_freq > 0.0f) + high_freq = opts.high_freq; + else + high_freq = nyquist + opts.high_freq; + + if (low_freq < 0.0f || low_freq >= nyquist || high_freq <= 0.0f || + high_freq 
> nyquist || high_freq <= low_freq) { + LOG(FATAL) << "Bad values in options: low-freq " << low_freq + << " and high-freq " << high_freq << " vs. nyquist " + << nyquist; + } + + float fft_bin_width = sample_freq / window_length_padded; + // fft-bin width [think of it as Nyquist-freq / half-window-length] + + float mel_low_freq = MelScale(low_freq); + float mel_high_freq = MelScale(high_freq); + + debug_ = opts.debug_mel; + + // divide by num_bins+1 in next line because of end-effects where the bins + // spread out to the sides. + float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1); + + float vtln_low = opts.vtln_low, vtln_high = opts.vtln_high; + if (vtln_high < 0.0f) { + vtln_high += nyquist; + } + + if (vtln_warp_factor != 1.0f && + (vtln_low < 0.0f || vtln_low <= low_freq || vtln_low >= high_freq || + vtln_high <= 0.0f || vtln_high >= high_freq || + vtln_high <= vtln_low)) { + LOG(FATAL) << "Bad values in options: vtln-low " << vtln_low + << " and vtln-high " << vtln_high << ", versus " + << "low-freq " << low_freq << " and high-freq " << high_freq; + } + + bins_.resize(num_bins); + center_freqs_.resize(num_bins); + + for (int32_t bin = 0; bin < num_bins; ++bin) { + float left_mel = mel_low_freq + bin * mel_freq_delta, + center_mel = mel_low_freq + (bin + 1) * mel_freq_delta, + right_mel = mel_low_freq + (bin + 2) * mel_freq_delta; + + if (vtln_warp_factor != 1.0f) { + left_mel = VtlnWarpMelFreq(vtln_low, + vtln_high, + low_freq, + high_freq, + vtln_warp_factor, + left_mel); + center_mel = VtlnWarpMelFreq(vtln_low, + vtln_high, + low_freq, + high_freq, + vtln_warp_factor, + center_mel); + right_mel = VtlnWarpMelFreq(vtln_low, + vtln_high, + low_freq, + high_freq, + vtln_warp_factor, + right_mel); + } + center_freqs_[bin] = InverseMelScale(center_mel); + + // this_bin will be a vector of coefficients that is only + // nonzero where this mel bin is active. 
+ std::vector this_bin(num_fft_bins); + + int32_t first_index = -1, last_index = -1; + for (int32_t i = 0; i < num_fft_bins; ++i) { + float freq = (fft_bin_width * i); // Center frequency of this fft + // bin. + float mel = MelScale(freq); + if (mel > left_mel && mel < right_mel) { + float weight; + if (mel <= center_mel) + weight = (mel - left_mel) / (center_mel - left_mel); + else + weight = (right_mel - mel) / (right_mel - center_mel); + this_bin[i] = weight; + if (first_index == -1) first_index = i; + last_index = i; + } + } + CHECK(first_index != -1 && last_index >= first_index && + "You may have set num_mel_bins too large."); + + bins_[bin].first = first_index; + int32_t size = last_index + 1 - first_index; + bins_[bin].second.insert(bins_[bin].second.end(), + this_bin.begin() + first_index, + this_bin.begin() + first_index + size); + + // Replicate a bug in HTK, for testing purposes. + if (opts.htk_mode && bin == 0 && mel_low_freq != 0.0f) { + bins_[bin].second[0] = 0.0; + } + } // for (int32_t bin = 0; bin < num_bins; ++bin) { + + if (debug_) { + std::ostringstream os; + for (size_t i = 0; i < bins_.size(); i++) { + os << "bin " << i << ", offset = " << bins_[i].first << ", vec = "; + for (auto k : bins_[i].second) os << k << ", "; + os << "\n"; + } + LOG(INFO) << os.str(); + } +} + +// "power_spectrum" contains fft energies. +void MelBanks::Compute(const float *power_spectrum, + float *mel_energies_out) const { + int32_t num_bins = bins_.size(); + + for (int32_t i = 0; i < num_bins; i++) { + int32_t offset = bins_[i].first; + const auto &v = bins_[i].second; + float energy = 0; + for (int32_t k = 0; k != v.size(); ++k) { + energy += v[k] * power_spectrum[k + offset]; + } + + // HTK-like flooring- for testing purposes (we prefer dither) + if (htk_mode_ && energy < 1.0) { + energy = 1.0; + } + + mel_energies_out[i] = energy; + + // The following assert was added due to a problem with OpenBlas that + // we had at one point (it was a bug in that library). 
Just to detect + // it early. + CHECK_EQ(energy, energy); // check that energy is not nan + } + + if (debug_) { + fprintf(stderr, "MEL BANKS:\n"); + for (int32_t i = 0; i < num_bins; i++) + fprintf(stderr, " %f", mel_energies_out[i]); + fprintf(stderr, "\n"); + } +} + +} // namespace knf diff --git a/runtime/engine/common/frontend/mel-computations.h b/runtime/engine/common/frontend/mel-computations.h new file mode 100644 index 00000000..2f9938bc --- /dev/null +++ b/runtime/engine/common/frontend/mel-computations.h @@ -0,0 +1,120 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// This file is copied/modified from kaldi/src/feat/mel-computations.h +#ifndef KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_ +#define KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_ + +#include +#include + +#include "frontend/feature-window.h" + +namespace knf { + +struct MelBanksOptions { + int32_t num_bins = 25; // e.g. 25; number of triangular bins + float low_freq = 20; // e.g. 20; lower frequency cutoff + + // an upper frequency cutoff; 0 -> no cutoff, negative + // ->added to the Nyquist frequency to get the cutoff. + float high_freq = 0; + + float vtln_low = 100; // vtln lower cutoff of warping function. + + // vtln upper cutoff of warping function: if negative, added + // to the Nyquist frequency to get the cutoff. 
+ float vtln_high = -500; + + bool debug_mel = false; + // htk_mode is a "hidden" config, it does not show up on command line. + // Enables more exact compatibility with HTK, for testing purposes. Affects + // mel-energy flooring and reproduces a bug in HTK. + bool htk_mode = false; + + std::string ToString() const { + std::ostringstream os; + os << "num_bins: " << num_bins << "\n"; + os << "low_freq: " << low_freq << "\n"; + os << "high_freq: " << high_freq << "\n"; + os << "vtln_low: " << vtln_low << "\n"; + os << "vtln_high: " << vtln_high << "\n"; + os << "debug_mel: " << debug_mel << "\n"; + os << "htk_mode: " << htk_mode << "\n"; + return os.str(); + } +}; + +std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts); + +class MelBanks { + public: + static inline float InverseMelScale(float mel_freq) { + return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f); + } + + static inline float MelScale(float freq) { + return 1127.0f * logf(1.0f + freq / 700.0f); + } + + static float VtlnWarpFreq( + float vtln_low_cutoff, + float vtln_high_cutoff, // discontinuities in warp func + float low_freq, + float high_freq, // upper+lower frequency cutoffs in + // the mel computation + float vtln_warp_factor, + float freq); + + static float VtlnWarpMelFreq(float vtln_low_cutoff, + float vtln_high_cutoff, + float low_freq, + float high_freq, + float vtln_warp_factor, + float mel_freq); + + // TODO(fangjun): Remove vtln_warp_factor + MelBanks(const MelBanksOptions &opts, + const FrameExtractionOptions &frame_opts, + float vtln_warp_factor); + + /// Compute Mel energies (note: not log energies). + /// At input, "fft_energies" contains the FFT energies (not log). 
+ /// + /// @param fft_energies 1-D array of size num_fft_bins/2+1 + /// @param mel_energies_out 1-D array of size num_mel_bins + void Compute(const float *fft_energies, float *mel_energies_out) const; + + int32_t NumBins() const { return bins_.size(); } + + private: + // center frequencies of bins, numbered from 0 ... num_bins-1. + // Needed by GetCenterFreqs(). + std::vector center_freqs_; + + // the "bins_" vector is a vector, one for each bin, of a pair: + // (the first nonzero fft-bin), (the vector of weights). + std::vector>> bins_; + + // TODO(fangjun): Remove debug_ and htk_mode_ + bool debug_; + bool htk_mode_; +}; + +} // namespace knf + +#endif // KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_ diff --git a/speechx/speechx/frontend/audio/normalizer.h b/runtime/engine/common/frontend/normalizer.h similarity index 90% rename from speechx/speechx/frontend/audio/normalizer.h rename to runtime/engine/common/frontend/normalizer.h index dcf721dd..5a6ca573 100644 --- a/speechx/speechx/frontend/audio/normalizer.h +++ b/runtime/engine/common/frontend/normalizer.h @@ -14,5 +14,4 @@ #pragma once -#include "frontend/audio/cmvn.h" -#include "frontend/audio/db_norm.h" \ No newline at end of file +#include "frontend/cmvn.h" \ No newline at end of file diff --git a/runtime/engine/common/frontend/rfft.cc b/runtime/engine/common/frontend/rfft.cc new file mode 100644 index 00000000..9ce6a172 --- /dev/null +++ b/runtime/engine/common/frontend/rfft.cc @@ -0,0 +1,67 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/rfft.h" + +#include +#include +#include + +#include "base/log.h" + +// see fftsg.c +#ifdef __cplusplus +extern "C" void rdft(int n, int isgn, double *a, int *ip, double *w); +#else +void rdft(int n, int isgn, double *a, int *ip, double *w); +#endif + +namespace knf { +class Rfft::RfftImpl { + public: + explicit RfftImpl(int32_t n) : n_(n), ip_(2 + std::sqrt(n / 2)), w_(n / 2) { + CHECK_EQ(n & (n - 1), 0); + } + + void Compute(float *in_out) { + std::vector d(in_out, in_out + n_); + + Compute(d.data()); + + std::copy(d.begin(), d.end(), in_out); + } + + void Compute(double *in_out) { + // 1 means forward fft + rdft(n_, 1, in_out, ip_.data(), w_.data()); + } + + private: + int32_t n_; + std::vector ip_; + std::vector w_; +}; + +Rfft::Rfft(int32_t n) : impl_(std::make_unique(n)) {} + +Rfft::~Rfft() = default; + +void Rfft::Compute(float *in_out) { impl_->Compute(in_out); } +void Rfft::Compute(double *in_out) { impl_->Compute(in_out); } + +} // namespace knf diff --git a/runtime/engine/common/frontend/rfft.h b/runtime/engine/common/frontend/rfft.h new file mode 100644 index 00000000..52da2626 --- /dev/null +++ b/runtime/engine/common/frontend/rfft.h @@ -0,0 +1,56 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef KALDI_NATIVE_FBANK_CSRC_RFFT_H_ +#define KALDI_NATIVE_FBANK_CSRC_RFFT_H_ + +#include + +namespace knf { + +// n-point Real discrete Fourier transform +// where n is a power of 2. n >= 2 +// +// R[k] = sum_j=0^n-1 in[j]*cos(2*pi*j*k/n), 0<=k<=n/2 +// I[k] = sum_j=0^n-1 in[j]*sin(2*pi*j*k/n), 0 impl_; +}; + +} // namespace knf + +#endif // KALDI_NATIVE_FBANK_CSRC_RFFT_H_ diff --git a/runtime/engine/common/frontend/wave-reader.cc b/runtime/engine/common/frontend/wave-reader.cc new file mode 100644 index 00000000..e94aafef --- /dev/null +++ b/runtime/engine/common/frontend/wave-reader.cc @@ -0,0 +1,376 @@ +// feat/wave-reader.cc + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek +// 2013 Florent Masson +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "frontend/wave-reader.h" + +#include +#include +#include +#include +#include + +#include "base/kaldi-error.h" +#include "base/kaldi-utils.h" + +namespace kaldi { + +// A utility class for reading wave header. +struct WaveHeaderReadGofer { + std::istream &is; + bool swap; + char tag[5]; + + WaveHeaderReadGofer(std::istream &is) : is(is), swap(false) { + memset(tag, '\0', sizeof tag); + } + + void Expect4ByteTag(const char *expected) { + is.read(tag, 4); + if (is.fail()) + KALDI_ERR << "WaveData: expected " << expected + << ", failed to read anything"; + if (strcmp(tag, expected)) + KALDI_ERR << "WaveData: expected " << expected << ", got " << tag; + } + + void Read4ByteTag() { + is.read(tag, 4); + if (is.fail()) + KALDI_ERR << "WaveData: expected 4-byte chunk-name, got read error"; + } + + uint32 ReadUint32() { + union { + char result[4]; + uint32 ans; + } u; + is.read(u.result, 4); + if (swap) KALDI_SWAP4(u.result); + if (is.fail()) + KALDI_ERR << "WaveData: unexpected end of file or read error"; + return u.ans; + } + + uint16 ReadUint16() { + union { + char result[2]; + int16 ans; + } u; + is.read(u.result, 2); + if (swap) KALDI_SWAP2(u.result); + if (is.fail()) + KALDI_ERR << "WaveData: unexpected end of file or read error"; + return u.ans; + } +}; + +static void WriteUint32(std::ostream &os, int32 i) { + union { + char buf[4]; + int i; + } u; + u.i = i; +#ifdef __BIG_ENDIAN__ + KALDI_SWAP4(u.buf); +#endif + os.write(u.buf, 4); + if (os.fail()) KALDI_ERR << "WaveData: error writing to stream."; +} + +static void WriteUint16(std::ostream &os, int16 i) { + union { + char buf[2]; + int16 i; + } u; + u.i = i; +#ifdef __BIG_ENDIAN__ + KALDI_SWAP2(u.buf); +#endif + os.write(u.buf, 2); + if (os.fail()) KALDI_ERR << "WaveData: error writing to stream."; +} + +void WaveInfo::Read(std::istream &is) { + WaveHeaderReadGofer reader(is); + reader.Read4ByteTag(); + if (strcmp(reader.tag, "RIFF") == 0) + reverse_bytes_ = false; + else if (strcmp(reader.tag, "RIFX") 
== 0) + reverse_bytes_ = true; + else + KALDI_ERR << "WaveData: expected RIFF or RIFX, got " << reader.tag; + +#ifdef __BIG_ENDIAN__ + reverse_bytes_ = !reverse_bytes_; +#endif + reader.swap = reverse_bytes_; + + uint32 riff_chunk_size = reader.ReadUint32(); + reader.Expect4ByteTag("WAVE"); + + uint32 riff_chunk_read = 0; + riff_chunk_read += 4; // WAVE included in riff_chunk_size. + + // Possibly skip any RIFF tags between 'WAVE' and 'fmt '. + // Apple devices produce a filler tag 'JUNK' for memory alignment. + reader.Read4ByteTag(); + riff_chunk_read += 4; + while (strcmp(reader.tag, "fmt ") != 0) { + uint32 filler_size = reader.ReadUint32(); + riff_chunk_read += 4; + for (uint32 i = 0; i < filler_size; i++) { + is.get(); // read 1 byte, + } + riff_chunk_read += filler_size; + // get next RIFF tag, + reader.Read4ByteTag(); + riff_chunk_read += 4; + } + + KALDI_ASSERT(strcmp(reader.tag, "fmt ") == 0); + uint32 subchunk1_size = reader.ReadUint32(); + uint16 audio_format = reader.ReadUint16(); + num_channels_ = reader.ReadUint16(); + uint32 sample_rate = reader.ReadUint32(), byte_rate = reader.ReadUint32(), + block_align = reader.ReadUint16(), + bits_per_sample = reader.ReadUint16(); + samp_freq_ = static_cast(sample_rate); + + uint32 fmt_chunk_read = 16; + if (audio_format == 1) { + if (subchunk1_size < 16) { + KALDI_ERR << "WaveData: expect PCM format data to have fmt chunk " + << "of at least size 16."; + } + } else if (audio_format == 0xFFFE) { // WAVE_FORMAT_EXTENSIBLE + uint16 extra_size = reader.ReadUint16(); + if (subchunk1_size < 40 || extra_size < 22) { + KALDI_ERR + << "WaveData: malformed WAVE_FORMAT_EXTENSIBLE format data."; + } + reader.ReadUint16(); // Unused for PCM. + reader.ReadUint32(); // Channel map: we do not care. + uint32 guid1 = reader.ReadUint32(), guid2 = reader.ReadUint32(), + guid3 = reader.ReadUint32(), guid4 = reader.ReadUint32(); + fmt_chunk_read = 40; + + // Support only KSDATAFORMAT_SUBTYPE_PCM for now. 
Interesting formats: + // ("00000001-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_PCM) + // ("00000003-0000-0010-8000-00aa00389b71", + // KSDATAFORMAT_SUBTYPE_IEEE_FLOAT) + // ("00000006-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_ALAW) + // ("00000007-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_MULAW) + if (guid1 != 0x00000001 || guid2 != 0x00100000 || guid3 != 0xAA000080 || + guid4 != 0x719B3800) { + KALDI_ERR << "WaveData: unsupported WAVE_FORMAT_EXTENSIBLE format."; + } + } else { + KALDI_ERR << "WaveData: can read only PCM data, format id in file is: " + << audio_format; + } + + for (uint32 i = fmt_chunk_read; i < subchunk1_size; ++i) + is.get(); // use up extra data. + + if (num_channels_ == 0) KALDI_ERR << "WaveData: no channels present"; + if (bits_per_sample != 16) + KALDI_ERR << "WaveData: unsupported bits_per_sample = " + << bits_per_sample; + if (byte_rate != sample_rate * bits_per_sample / 8 * num_channels_) + KALDI_ERR << "Unexpected byte rate " << byte_rate << " vs. " + << sample_rate << " * " << (bits_per_sample / 8) << " * " + << num_channels_; + if (block_align != num_channels_ * bits_per_sample / 8) + KALDI_ERR << "Unexpected block_align: " << block_align << " vs. " + << num_channels_ << " * " << (bits_per_sample / 8); + + riff_chunk_read += 4 + subchunk1_size; + // size of what we just read, 4 for subchunk1_size + subchunk1_size itself. + + // We support an optional "fact" chunk (which is useless but which + // we encountered), and then a single "data" chunk. + + reader.Read4ByteTag(); + riff_chunk_read += 4; + + // Skip any subchunks between "fmt" and "data". Usually there will + // be a single "fact" subchunk, but on Windows there can also be a + // "list" subchunk. + while (strcmp(reader.tag, "data") != 0) { + // We will just ignore the data in these chunks. 
+ uint32 chunk_sz = reader.ReadUint32(); + if (chunk_sz != 4 && strcmp(reader.tag, "fact") == 0) + KALDI_WARN << "Expected fact chunk to be 4 bytes long."; + for (uint32 i = 0; i < chunk_sz; i++) is.get(); + riff_chunk_read += + 4 + chunk_sz; // for chunk_sz (4) + chunk contents (chunk-sz) + + // Now read the next chunk name. + reader.Read4ByteTag(); + riff_chunk_read += 4; + } + + KALDI_ASSERT(strcmp(reader.tag, "data") == 0); + uint32 data_chunk_size = reader.ReadUint32(); + riff_chunk_read += 4; + + // Figure out if the file is going to be read to the end. Values as + // observed in the wild: + bool is_stream_mode = + riff_chunk_size == 0 || riff_chunk_size == 0xFFFFFFFF || + data_chunk_size == 0 || data_chunk_size == 0xFFFFFFFF || + data_chunk_size == 0x7FFFF000; // This value is used by SoX. + + if (is_stream_mode) + KALDI_VLOG(1) << "Read in RIFF chunk size: " << riff_chunk_size + << ", data chunk size: " << data_chunk_size + << ". Assume 'stream mode' (reading data to EOF)."; + + if (!is_stream_mode && std::abs(static_cast(riff_chunk_read) + + static_cast(data_chunk_size) - + static_cast(riff_chunk_size)) > 1) { + // We allow the size to be off by one without warning, because there is + // a + // weirdness in the format of RIFF files that means that the input may + // sometimes be padded with 1 unused byte to make the total size even. + KALDI_WARN << "Expected " << riff_chunk_size + << " bytes in RIFF chunk, but " + << "after first data block there will be " << riff_chunk_read + << " + " << data_chunk_size << " bytes " + << "(we do not support reading multiple data chunks)."; + } + + if (is_stream_mode) + samp_count_ = -1; + else + samp_count_ = data_chunk_size / block_align; +} + +void WaveData::Read(std::istream &is) { + const uint32 kBlockSize = 1024 * 1024; + + WaveInfo header; + header.Read(is); + + data_.Resize(0, 0); // clear the data. + samp_freq_ = header.SampFreq(); + + std::vector buffer; + uint32 bytes_to_go = header.IsStreamed() ? 
kBlockSize : header.DataBytes(); + + // Once in a while header.DataBytes() will report an insane value; + // read the file to the end + while (is && bytes_to_go > 0) { + uint32 block_bytes = std::min(bytes_to_go, kBlockSize); + uint32 offset = buffer.size(); + buffer.resize(offset + block_bytes); + is.read(&buffer[offset], block_bytes); + uint32 bytes_read = is.gcount(); + buffer.resize(offset + bytes_read); + if (!header.IsStreamed()) bytes_to_go -= bytes_read; + } + + if (is.bad()) KALDI_ERR << "WaveData: file read error"; + + if (buffer.size() == 0) KALDI_ERR << "WaveData: empty file (no data)"; + + if (!header.IsStreamed() && buffer.size() < header.DataBytes()) { + KALDI_WARN << "Expected " << header.DataBytes() + << " bytes of wave data, " + << "but read only " << buffer.size() << " bytes. " + << "Truncated file?"; + } + + uint16 *data_ptr = reinterpret_cast(&buffer[0]); + + // The matrix is arranged row per channel, column per sample. + data_.Resize(header.NumChannels(), buffer.size() / header.BlockAlign()); + for (uint32 i = 0; i < data_.NumCols(); ++i) { + for (uint32 j = 0; j < data_.NumRows(); ++j) { + int16 k = *data_ptr++; + if (header.ReverseBytes()) KALDI_SWAP2(k); + data_(j, i) = k; + } + } +} + + +// Write 16-bit PCM. + +// note: the WAVE chunk contains 2 subchunks. +// +// subchunk2size = data.NumRows() * data.NumCols() * 2. 
+ + +void WaveData::Write(std::ostream &os) const { + os << "RIFF"; + if (data_.NumRows() == 0) + KALDI_ERR << "Error: attempting to write empty WAVE file"; + + int32 num_chan = data_.NumRows(), num_samp = data_.NumCols(), + bytes_per_samp = 2; + + int32 subchunk2size = (num_chan * num_samp * bytes_per_samp); + int32 chunk_size = 36 + subchunk2size; + WriteUint32(os, chunk_size); + os << "WAVE"; + os << "fmt "; + WriteUint32(os, 16); + WriteUint16(os, 1); + WriteUint16(os, num_chan); + KALDI_ASSERT(samp_freq_ > 0); + WriteUint32(os, static_cast(samp_freq_)); + WriteUint32(os, static_cast(samp_freq_) * num_chan * bytes_per_samp); + WriteUint16(os, num_chan * bytes_per_samp); + WriteUint16(os, 8 * bytes_per_samp); + os << "data"; + WriteUint32(os, subchunk2size); + + const BaseFloat *data_ptr = data_.Data(); + int32 stride = data_.Stride(); + + int num_clipped = 0; + for (int32 i = 0; i < num_samp; i++) { + for (int32 j = 0; j < num_chan; j++) { + int32 elem = static_cast(trunc(data_ptr[j * stride + i])); + int16 elem_16 = static_cast(elem); + if (elem < std::numeric_limits::min()) { + elem_16 = std::numeric_limits::min(); + ++num_clipped; + } else if (elem > std::numeric_limits::max()) { + elem_16 = std::numeric_limits::max(); + ++num_clipped; + } +#ifdef __BIG_ENDIAN__ + KALDI_SWAP2(elem_16); +#endif + os.write(reinterpret_cast(&elem_16), 2); + } + } + if (os.fail()) KALDI_ERR << "Error writing wave data to stream."; + if (num_clipped > 0) + KALDI_WARN << "WARNING: clipped " << num_clipped + << " samples out of total " << num_chan * num_samp + << ". 
Reduce volume?"; +} + + +} // end namespace kaldi diff --git a/runtime/engine/common/frontend/wave-reader.h b/runtime/engine/common/frontend/wave-reader.h new file mode 100644 index 00000000..6cd471b8 --- /dev/null +++ b/runtime/engine/common/frontend/wave-reader.h @@ -0,0 +1,248 @@ +// feat/wave-reader.h + +// Copyright 2009-2011 Karel Vesely; Microsoft Corporation +// 2013 Florent Masson +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +/* +// THE WAVE FORMAT IS SPECIFIED IN: +// https:// ccrma.stanford.edu/courses/422/projects/WaveFormat/ +// +// +// +// RIFF +// | +// WAVE +// | \ \ \ +// fmt_ data ... data +// +// +// Riff is a general container, which usually contains one WAVE chunk +// each WAVE chunk has header sub-chunk 'fmt_' +// and one or more data sub-chunks 'data' +// +// [Note from Dan: to say that the wave format was ever "specified" anywhere is +// not quite right. The guy who invented the wave format attempted to create +// a formal specification but it did not completely make sense. And there +// doesn't seem to be a consensus on what makes a valid wave file, +// particularly where the accuracy of header information is concerned.] 
+*/ + + +#ifndef KALDI_FEAT_WAVE_READER_H_ +#define KALDI_FEAT_WAVE_READER_H_ + +#include + +#include "base/kaldi-types.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/kaldi-vector.h" + + +namespace kaldi { + +/// For historical reasons, we scale waveforms to the range +/// (2^15-1)*[-1, 1], not the usual default DSP range [-1, 1]. +const BaseFloat kWaveSampleMax = 32768.0; + +/// This class reads and hold wave file header information. +class WaveInfo { + public: + WaveInfo() + : samp_freq_(0), samp_count_(0), num_channels_(0), reverse_bytes_(0) {} + + /// Is stream size unknown? Duration and SampleCount not valid if true. + bool IsStreamed() const { return samp_count_ < 0; } + + /// Sample frequency, Hz. + BaseFloat SampFreq() const { return samp_freq_; } + + /// Number of samples in stream. Invalid if IsStreamed() is true. + uint32 SampleCount() const { return samp_count_; } + + /// Approximate duration, seconds. Invalid if IsStreamed() is true. + BaseFloat Duration() const { return samp_count_ / samp_freq_; } + + /// Number of channels, 1 to 16. + int32 NumChannels() const { return num_channels_; } + + /// Bytes per sample. + size_t BlockAlign() const { return 2 * num_channels_; } + + /// Wave data bytes. Invalid if IsStreamed() is true. + size_t DataBytes() const { return samp_count_ * BlockAlign(); } + + /// Is data file byte order different from machine byte order? + bool ReverseBytes() const { return reverse_bytes_; } + + /// 'is' should be opened in binary mode. Read() will throw on error. + /// On success 'is' will be positioned at the beginning of wave data. + void Read(std::istream &is); + + private: + BaseFloat samp_freq_; + int32 samp_count_; // 0 if empty, -1 if undefined length. + uint8 num_channels_; + bool reverse_bytes_; // File endianness differs from host. +}; + +/// This class's purpose is to read in Wave files. 
+class WaveData { + public: + WaveData(BaseFloat samp_freq, const MatrixBase &data) + : data_(data), samp_freq_(samp_freq) {} + + WaveData() : samp_freq_(0.0) {} + + /// Read() will throw on error. It's valid to call Read() more than once-- + /// in this case it will destroy what was there before. + /// "is" should be opened in binary mode. + void Read(std::istream &is); + + /// Write() will throw on error. os should be opened in binary mode. + void Write(std::ostream &os) const; + + // This function returns the wave data-- it's in a matrix + // because there may be multiple channels. In the normal case + // there's just one channel so Data() will have one row. + const Matrix &Data() const { return data_; } + + BaseFloat SampFreq() const { return samp_freq_; } + + // Returns the duration in seconds + BaseFloat Duration() const { return data_.NumCols() / samp_freq_; } + + void CopyFrom(const WaveData &other) { + samp_freq_ = other.samp_freq_; + data_.CopyFromMat(other.data_); + } + + void Clear() { + data_.Resize(0, 0); + samp_freq_ = 0.0; + } + + void Swap(WaveData *other) { + data_.Swap(&(other->data_)); + std::swap(samp_freq_, other->samp_freq_); + } + + private: + static const uint32 kBlockSize = 1024 * 1024; // Use 1M bytes. + Matrix data_; + BaseFloat samp_freq_; +}; + + +// Holder class for .wav files that enables us to read (but not write) .wav +// files. c.f. util/kaldi-holder.h we don't use the KaldiObjectHolder template +// because we don't want to check for the \0B binary header. We could have faked +// it by pretending to read in the wave data in text mode after failing to find +// the \0B header, but that would have been a little ugly. +class WaveHolder { + public: + typedef WaveData T; + + static bool Write(std::ostream &os, bool binary, const T &t) { + // We don't write the binary-mode header here [always binary]. + if (!binary) + KALDI_ERR << "Wave data can only be written in binary mode."; + try { + t.Write(os); // throws exception on failure. 
+ return true; + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught in WaveHolder object (writing). " + << e.what(); + return false; // write failure. + } + } + void Copy(const T &t) { t_.CopyFrom(t); } + + static bool IsReadInBinary() { return true; } + + void Clear() { t_.Clear(); } + + T &Value() { return t_; } + + WaveHolder &operator=(const WaveHolder &other) { + t_.CopyFrom(other.t_); + return *this; + } + WaveHolder(const WaveHolder &other) : t_(other.t_) {} + + WaveHolder() {} + + bool Read(std::istream &is) { + // We don't look for the binary-mode header here [always binary] + try { + t_.Read(is); // Throws exception on failure. + return true; + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught in WaveHolder::Read(). " + << e.what(); + return false; + } + } + + void Swap(WaveHolder *other) { t_.Swap(&(other->t_)); } + + bool ExtractRange(const WaveHolder &other, const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + private: + T t_; +}; + +// This is like WaveHolder but when you just want the metadata- +// it leaves the actual data undefined, it doesn't read it. +class WaveInfoHolder { + public: + typedef WaveInfo T; + + void Clear() { info_ = WaveInfo(); } + void Swap(WaveInfoHolder *other) { std::swap(info_, other->info_); } + T &Value() { return info_; } + static bool IsReadInBinary() { return true; } + + bool Read(std::istream &is) { + try { + info_.Read(is); // Throws exception on failure. + return true; + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught in WaveInfoHolder::Read(). 
" + << e.what(); + return false; + } + } + + bool ExtractRange(const WaveInfoHolder &other, const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + private: + WaveInfo info_; +}; + + +} // namespace kaldi + +#endif // KALDI_FEAT_WAVE_READER_H_ diff --git a/runtime/engine/common/matrix/CMakeLists.txt b/runtime/engine/common/matrix/CMakeLists.txt new file mode 100644 index 00000000..a4b34d54 --- /dev/null +++ b/runtime/engine/common/matrix/CMakeLists.txt @@ -0,0 +1,7 @@ + +add_library(kaldi-matrix +kaldi-matrix.cc +kaldi-vector.cc +) + +target_link_libraries(kaldi-matrix kaldi-base) diff --git a/speechx/speechx/kaldi/matrix/kaldi-matrix-inl.h b/runtime/engine/common/matrix/kaldi-matrix-inl.h similarity index 65% rename from speechx/speechx/kaldi/matrix/kaldi-matrix-inl.h rename to runtime/engine/common/matrix/kaldi-matrix-inl.h index c2ff0079..ed18859d 100644 --- a/speechx/speechx/kaldi/matrix/kaldi-matrix-inl.h +++ b/runtime/engine/common/matrix/kaldi-matrix-inl.h @@ -25,39 +25,41 @@ namespace kaldi { /// Empty constructor -template -Matrix::Matrix(): MatrixBase(NULL, 0, 0, 0) { } - +template +Matrix::Matrix() : MatrixBase(NULL, 0, 0, 0) {} +/* template<> template<> -void MatrixBase::AddVecVec(const float alpha, const VectorBase &ra, const VectorBase &rb); +void MatrixBase::AddVecVec(const float alpha, const VectorBase +&ra, const VectorBase &rb); template<> template<> -void MatrixBase::AddVecVec(const double alpha, const VectorBase &ra, const VectorBase &rb); - -template -inline std::ostream & operator << (std::ostream & os, const MatrixBase & M) { - M.Write(os, false); - return os; +void MatrixBase::AddVecVec(const double alpha, const VectorBase +&ra, const VectorBase &rb); +*/ + +template +inline std::ostream& operator<<(std::ostream& os, const MatrixBase& M) { + M.Write(os, false); + return os; } -template -inline std::istream & operator >> (std::istream & is, Matrix & M) { - M.Read(is, false); - 
return is; +template +inline std::istream& operator>>(std::istream& is, Matrix& M) { + M.Read(is, false); + return is; } -template -inline std::istream & operator >> (std::istream & is, MatrixBase & M) { - M.Read(is, false); - return is; +template +inline std::istream& operator>>(std::istream& is, MatrixBase& M) { + M.Read(is, false); + return is; } -}// namespace kaldi +} // namespace kaldi #endif // KALDI_MATRIX_KALDI_MATRIX_INL_H_ - diff --git a/speechx/speechx/kaldi/matrix/kaldi-matrix.cc b/runtime/engine/common/matrix/kaldi-matrix.cc similarity index 75% rename from speechx/speechx/kaldi/matrix/kaldi-matrix.cc rename to runtime/engine/common/matrix/kaldi-matrix.cc index 85e6fecc..6f65fb0a 100644 --- a/speechx/speechx/kaldi/matrix/kaldi-matrix.cc +++ b/runtime/engine/common/matrix/kaldi-matrix.cc @@ -23,17 +23,9 @@ // limitations under the License. #include "matrix/kaldi-matrix.h" -#include "matrix/sp-matrix.h" -#include "matrix/jama-svd.h" -#include "matrix/jama-eig.h" -#include "matrix/compressed-matrix.h" -#include "matrix/sparse-matrix.h" - -static_assert(int(kaldi::kNoTrans) == int(CblasNoTrans) && int(kaldi::kTrans) == int(CblasTrans), - "kaldi::kNoTrans and kaldi::kTrans must be equal to the appropriate CBLAS library constants!"); namespace kaldi { - +/* template void MatrixBase::Invert(Real *log_det, Real *det_sign, bool inverse_needed) { @@ -174,14 +166,19 @@ void MatrixBase::AddMatMat(const Real alpha, const MatrixBase& B, MatrixTransposeType transB, const Real beta) { - KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kTrans && transB == kNoTrans && A.num_rows_ == B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kNoTrans && transB == kTrans && A.num_cols_ == B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) - || (transA == kTrans && transB == kTrans && A.num_rows_ == B.num_cols_ && 
A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); + KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == +B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kTrans && transB == kNoTrans && A.num_rows_ == +B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kNoTrans && transB == kTrans && A.num_cols_ == +B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) + || (transA == kTrans && transB == kTrans && A.num_rows_ == +B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); KALDI_ASSERT(&A != this && &B != this); if (num_rows_ == 0) return; cblas_Xgemm(alpha, transA, A.data_, A.num_rows_, A.num_cols_, A.stride_, - transB, B.data_, B.stride_, beta, data_, num_rows_, num_cols_, stride_); + transB, B.data_, B.stride_, beta, data_, num_rows_, num_cols_, +stride_); } @@ -199,36 +196,38 @@ void MatrixBase::SetMatMatDivMat(const MatrixBase& A, id = od * (o / i); /// o / i is either zero or "scale". } else { id = od; /// Just imagine the scale was 1.0. This is somehow true in - /// expectation; anyway, this case should basically never happen so it doesn't + /// expectation; anyway, this case should basically never happen so it +doesn't /// really matter. 
} (*this)(r, c) = id; } } } +*/ +// template +// void MatrixBase::CopyLowerToUpper() { +// KALDI_ASSERT(num_rows_ == num_cols_); +// Real *data = data_; +// MatrixIndexT num_rows = num_rows_, stride = stride_; +// for (int32 i = 0; i < num_rows; i++) +// for (int32 j = 0; j < i; j++) +// data[j * stride + i ] = data[i * stride + j]; +//} -template -void MatrixBase::CopyLowerToUpper() { - KALDI_ASSERT(num_rows_ == num_cols_); - Real *data = data_; - MatrixIndexT num_rows = num_rows_, stride = stride_; - for (int32 i = 0; i < num_rows; i++) - for (int32 j = 0; j < i; j++) - data[j * stride + i ] = data[i * stride + j]; -} +// template +// void MatrixBase::CopyUpperToLower() { +// KALDI_ASSERT(num_rows_ == num_cols_); +// Real *data = data_; +// MatrixIndexT num_rows = num_rows_, stride = stride_; +// for (int32 i = 0; i < num_rows; i++) +// for (int32 j = 0; j < i; j++) +// data[i * stride + j] = data[j * stride + i]; +//} -template -void MatrixBase::CopyUpperToLower() { - KALDI_ASSERT(num_rows_ == num_cols_); - Real *data = data_; - MatrixIndexT num_rows = num_rows_, stride = stride_; - for (int32 i = 0; i < num_rows; i++) - for (int32 j = 0; j < i; j++) - data[i * stride + j] = data[j * stride + i]; -} - +/* template void MatrixBase::SymAddMat2(const Real alpha, const MatrixBase &A, @@ -270,10 +269,14 @@ void MatrixBase::AddMatSmat(const Real alpha, const MatrixBase &B, MatrixTransposeType transB, const Real beta) { - KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kTrans && transB == kNoTrans && A.num_rows_ == B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kNoTrans && transB == kTrans && A.num_cols_ == B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) - || (transA == kTrans && transB == kTrans && A.num_rows_ == B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); + 
KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == +B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kTrans && transB == kNoTrans && A.num_rows_ == +B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kNoTrans && transB == kTrans && A.num_cols_ == +B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) + || (transA == kTrans && transB == kTrans && A.num_rows_ == +B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); KALDI_ASSERT(&A != this && &B != this); // We iterate over the columns of B. @@ -308,10 +311,14 @@ void MatrixBase::AddSmatMat(const Real alpha, const MatrixBase &B, MatrixTransposeType transB, const Real beta) { - KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kTrans && transB == kNoTrans && A.num_rows_ == B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kNoTrans && transB == kTrans && A.num_cols_ == B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) - || (transA == kTrans && transB == kTrans && A.num_rows_ == B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); + KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == +B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kTrans && transB == kNoTrans && A.num_rows_ == +B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kNoTrans && transB == kTrans && A.num_cols_ == +B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) + || (transA == kTrans && transB == kTrans && A.num_rows_ == +B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); KALDI_ASSERT(&A != this && &B != this); MatrixIndexT Astride = A.stride_, Bstride = B.stride_, stride = this->stride_, @@ -349,7 +356,8 @@ void 
MatrixBase::AddSpSp(const Real alpha, const SpMatrix &A_in, // fully (to save work, we used the matrix constructor from SpMatrix). // CblasLeft means A is on the left: C <-- alpha A B + beta C if (sz == 0) return; - cblas_Xsymm(alpha, sz, A.data_, A.stride_, B.data_, B.stride_, beta, data_, stride_); + cblas_Xsymm(alpha, sz, A.data_, A.stride_, B.data_, B.stride_, beta, data_, +stride_); } template @@ -359,13 +367,15 @@ void MatrixBase::AddMat(const Real alpha, const MatrixBase& A, if (transA == kNoTrans) { Scale(alpha + 1.0); } else { - KALDI_ASSERT(num_rows_ == num_cols_ && "AddMat: adding to self (transposed): not symmetric."); + KALDI_ASSERT(num_rows_ == num_cols_ && "AddMat: adding to self +(transposed): not symmetric."); Real *data = data_; if (alpha == 1.0) { // common case-- handle separately. for (MatrixIndexT row = 0; row < num_rows_; row++) { for (MatrixIndexT col = 0; col < row; col++) { Real *lower = data + (row * stride_) + col, *upper = data + (col - * stride_) + row; + * +stride_) + row; Real sum = *lower + *upper; *lower = *upper = sum; } @@ -375,7 +385,8 @@ void MatrixBase::AddMat(const Real alpha, const MatrixBase& A, for (MatrixIndexT row = 0; row < num_rows_; row++) { for (MatrixIndexT col = 0; col < row; col++) { Real *lower = data + (row * stride_) + col, *upper = data + (col - * stride_) + row; + * +stride_) + row; Real lower_tmp = *lower; *lower += alpha * *upper; *upper += alpha * lower_tmp; @@ -397,7 +408,8 @@ void MatrixBase::AddMat(const Real alpha, const MatrixBase& A, } else { KALDI_ASSERT(A.num_cols_ == num_rows_ && A.num_rows_ == num_cols_); if (num_rows_ == 0) return; - for (MatrixIndexT row = 0; row < num_rows_; row++, adata++, data += stride) + for (MatrixIndexT row = 0; row < num_rows_; row++, adata++, data += +stride) cblas_Xaxpy(num_cols_, alpha, adata, aStride, data, 1); } } @@ -510,7 +522,8 @@ void MatrixBase::AddMatSmat(Real alpha, const MatrixBase &A, Real alpha_B_kj = alpha * p.second; Real *this_col_j = this->Data() + j; 
// Add to entire 'j'th column of *this at once using cblas_Xaxpy. - // pass stride to write a colmun as matrices are stored in row major order. + // pass stride to write a colmun as matrices are stored in row major +order. cblas_Xaxpy(this_num_rows, alpha_B_kj, a_col_k, A.stride_, this_col_j, this->stride_); //for (MatrixIndexT i = 0; i < this_num_rows; ++i) @@ -536,10 +549,11 @@ void MatrixBase::AddMatSmat(Real alpha, const MatrixBase &A, Real alpha_B_jk = alpha * p.second; const Real *a_col_k = A.Data() + k; // Add to entire 'j'th column of *this at once using cblas_Xaxpy. - // pass stride to write a column as matrices are stored in row major order. + // pass stride to write a column as matrices are stored in row major +order. cblas_Xaxpy(this_num_rows, alpha_B_jk, a_col_k, A.stride_, this_col_j, this->stride_); - //for (MatrixIndexT i = 0; i < this_num_rows; ++i) + //for (MatrixIndexT i = 0; i < this_num_rows; ++i) // this_col_j[i*this->stride_] += alpha_B_jk * a_col_k[i*A.stride_]; } } @@ -593,7 +607,8 @@ void MatrixBase::AddDiagVecMat( Real *data = data_; const Real *Mdata = M.Data(), *vdata = v.Data(); if (num_rows_ == 0) return; - for (MatrixIndexT i = 0; i < num_rows; i++, data += stride, Mdata += M_row_stride, vdata++) + for (MatrixIndexT i = 0; i < num_rows; i++, data += stride, Mdata += +M_row_stride, vdata++) cblas_Xaxpy(num_cols, alpha * *vdata, Mdata, M_col_stride, data, 1); } @@ -627,7 +642,8 @@ void MatrixBase::AddMatDiagVec( if (num_rows_ == 0) return; for (MatrixIndexT i = 0; i < num_rows; i++){ for(MatrixIndexT j = 0; j < num_cols; j ++ ){ - data[i*stride + j] += alpha * vdata[j] * Mdata[i*M_row_stride + j*M_col_stride]; + data[i*stride + j] += alpha * vdata[j] * Mdata[i*M_row_stride + +j*M_col_stride]; } } } @@ -662,8 +678,10 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, KALDI_ASSERT(s != NULL && U_in != this && V_in != this); Matrix tmpU, tmpV; - if (U_in == NULL) tmpU.Resize(this->num_rows_, 1); // work-space if U_in empty. 
- if (V_in == NULL) tmpV.Resize(1, this->num_cols_); // work-space if V_in empty. + if (U_in == NULL) tmpU.Resize(this->num_rows_, 1); // work-space if U_in +empty. + if (V_in == NULL) tmpV.Resize(1, this->num_cols_); // work-space if V_in +empty. /// Impementation notes: /// Lapack works in column-order, therefore the dimensions of *this are @@ -697,8 +715,10 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, KaldiBlasInt result; // query for work space - char *u_job = const_cast(U_in ? "s" : "N"); // "s" == skinny, "N" == "none." - char *v_job = const_cast(V_in ? "s" : "N"); // "s" == skinny, "N" == "none." + char *u_job = const_cast(U_in ? "s" : "N"); // "s" == skinny, "N" == +"none." + char *v_job = const_cast(V_in ? "s" : "N"); // "s" == skinny, "N" == +"none." clapack_Xgesvd(v_job, u_job, &M, &N, data_, &LDA, s->Data(), @@ -707,7 +727,8 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, &work_query, &l_work, &result); - KALDI_ASSERT(result >= 0 && "Call to CLAPACK dgesvd_ called with wrong arguments"); + KALDI_ASSERT(result >= 0 && "Call to CLAPACK dgesvd_ called with wrong +arguments"); l_work = static_cast(work_query); Real *p_work; @@ -725,7 +746,8 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, p_work, &l_work, &result); - KALDI_ASSERT(result >= 0 && "Call to CLAPACK dgesvd_ called with wrong arguments"); + KALDI_ASSERT(result >= 0 && "Call to CLAPACK dgesvd_ called with wrong +arguments"); if (result != 0) { KALDI_WARN << "CLAPACK sgesvd_ : some weird convergence not satisfied"; @@ -734,170 +756,170 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, } #endif - +*/ // Copy constructor. Copies data to newly allocated memory. 
-template -Matrix::Matrix (const MatrixBase & M, - MatrixTransposeType trans/*=kNoTrans*/) +template +Matrix::Matrix(const MatrixBase &M, + MatrixTransposeType trans /*=kNoTrans*/) : MatrixBase() { - if (trans == kNoTrans) { - Resize(M.num_rows_, M.num_cols_); - this->CopyFromMat(M); - } else { - Resize(M.num_cols_, M.num_rows_); - this->CopyFromMat(M, kTrans); - } + if (trans == kNoTrans) { + Resize(M.num_rows_, M.num_cols_); + this->CopyFromMat(M); + } else { + Resize(M.num_cols_, M.num_rows_); + this->CopyFromMat(M, kTrans); + } } // Copy constructor. Copies data to newly allocated memory. -template -Matrix::Matrix (const Matrix & M): - MatrixBase() { - Resize(M.num_rows_, M.num_cols_); - this->CopyFromMat(M); +template +Matrix::Matrix(const Matrix &M) : MatrixBase() { + Resize(M.num_rows_, M.num_cols_); + this->CopyFromMat(M); } /// Copy constructor from another type. -template -template -Matrix::Matrix(const MatrixBase & M, - MatrixTransposeType trans) : MatrixBase() { - if (trans == kNoTrans) { - Resize(M.NumRows(), M.NumCols()); - this->CopyFromMat(M); - } else { - Resize(M.NumCols(), M.NumRows()); - this->CopyFromMat(M, kTrans); - } +template +template +Matrix::Matrix(const MatrixBase &M, MatrixTransposeType trans) + : MatrixBase() { + if (trans == kNoTrans) { + Resize(M.NumRows(), M.NumCols()); + this->CopyFromMat(M); + } else { + Resize(M.NumCols(), M.NumRows()); + this->CopyFromMat(M, kTrans); + } } // Instantiate this constructor for float->double and double->float. 
-template -Matrix::Matrix(const MatrixBase & M, - MatrixTransposeType trans); -template -Matrix::Matrix(const MatrixBase & M, - MatrixTransposeType trans); +template Matrix::Matrix(const MatrixBase &M, + MatrixTransposeType trans); +template Matrix::Matrix(const MatrixBase &M, + MatrixTransposeType trans); -template +template inline void Matrix::Init(const MatrixIndexT rows, const MatrixIndexT cols, const MatrixStrideType stride_type) { - if (rows * cols == 0) { - KALDI_ASSERT(rows == 0 && cols == 0); - this->num_rows_ = 0; - this->num_cols_ = 0; - this->stride_ = 0; - this->data_ = NULL; - return; - } - KALDI_ASSERT(rows > 0 && cols > 0); - MatrixIndexT skip, stride; - size_t size; - void *data; // aligned memory block - void *temp; // memory block to be really freed - - // compute the size of skip and real cols - skip = ((16 / sizeof(Real)) - cols % (16 / sizeof(Real))) - % (16 / sizeof(Real)); - stride = cols + skip; - size = static_cast(rows) * static_cast(stride) - * sizeof(Real); - - // allocate the memory and set the right dimensions and parameters - if (NULL != (data = KALDI_MEMALIGN(16, size, &temp))) { - MatrixBase::data_ = static_cast (data); - MatrixBase::num_rows_ = rows; - MatrixBase::num_cols_ = cols; - MatrixBase::stride_ = (stride_type == kDefaultStride ? 
stride : cols); - } else { - throw std::bad_alloc(); - } + if (rows * cols == 0) { + KALDI_ASSERT(rows == 0 && cols == 0); + this->num_rows_ = 0; + this->num_cols_ = 0; + this->stride_ = 0; + this->data_ = NULL; + return; + } + KALDI_ASSERT(rows > 0 && cols > 0); + MatrixIndexT skip, stride; + size_t size; + void *data; // aligned memory block + void *temp; // memory block to be really freed + + // compute the size of skip and real cols + skip = ((16 / sizeof(Real)) - cols % (16 / sizeof(Real))) % + (16 / sizeof(Real)); + stride = cols + skip; + size = + static_cast(rows) * static_cast(stride) * sizeof(Real); + + // allocate the memory and set the right dimensions and parameters + if (NULL != (data = KALDI_MEMALIGN(16, size, &temp))) { + MatrixBase::data_ = static_cast(data); + MatrixBase::num_rows_ = rows; + MatrixBase::num_cols_ = cols; + MatrixBase::stride_ = + (stride_type == kDefaultStride ? stride : cols); + } else { + throw std::bad_alloc(); + } } -template +template void Matrix::Resize(const MatrixIndexT rows, const MatrixIndexT cols, MatrixResizeType resize_type, MatrixStrideType stride_type) { - // the next block uses recursion to handle what we have to do if - // resize_type == kCopyData. - if (resize_type == kCopyData) { - if (this->data_ == NULL || rows == 0) resize_type = kSetZero; // nothing to copy. - else if (rows == this->num_rows_ && cols == this->num_cols_ && - (stride_type == kDefaultStride || this->stride_ == this->num_cols_)) { return; } // nothing to do. - else { - // set tmp to a matrix of the desired size; if new matrix - // is bigger in some dimension, zero it. - MatrixResizeType new_resize_type = - (rows > this->num_rows_ || cols > this->num_cols_) ? kSetZero : kUndefined; - Matrix tmp(rows, cols, new_resize_type, stride_type); - MatrixIndexT rows_min = std::min(rows, this->num_rows_), - cols_min = std::min(cols, this->num_cols_); - tmp.Range(0, rows_min, 0, cols_min). 
- CopyFromMat(this->Range(0, rows_min, 0, cols_min)); - tmp.Swap(this); - // and now let tmp go out of scope, deleting what was in *this. - return; + // the next block uses recursion to handle what we have to do if + // resize_type == kCopyData. + if (resize_type == kCopyData) { + if (this->data_ == NULL || rows == 0) + resize_type = kSetZero; // nothing to copy. + else if (rows == this->num_rows_ && cols == this->num_cols_ && + (stride_type == kDefaultStride || + this->stride_ == this->num_cols_)) { + return; + } // nothing to do. + else { + // set tmp to a matrix of the desired size; if new matrix + // is bigger in some dimension, zero it. + MatrixResizeType new_resize_type = + (rows > this->num_rows_ || cols > this->num_cols_) ? kSetZero + : kUndefined; + Matrix tmp(rows, cols, new_resize_type, stride_type); + MatrixIndexT rows_min = std::min(rows, this->num_rows_), + cols_min = std::min(cols, this->num_cols_); + tmp.Range(0, rows_min, 0, cols_min) + .CopyFromMat(this->Range(0, rows_min, 0, cols_min)); + tmp.Swap(this); + // and now let tmp go out of scope, deleting what was in *this. + return; + } } - } - // At this point, resize_type == kSetZero or kUndefined. + // At this point, resize_type == kSetZero or kUndefined. 
- if (MatrixBase::data_ != NULL) { - if (rows == MatrixBase::num_rows_ - && cols == MatrixBase::num_cols_) { - if (resize_type == kSetZero) - this->SetZero(); - return; + if (MatrixBase::data_ != NULL) { + if (rows == MatrixBase::num_rows_ && + cols == MatrixBase::num_cols_) { + if (resize_type == kSetZero) this->SetZero(); + return; + } else + Destroy(); } - else - Destroy(); - } - Init(rows, cols, stride_type); - if (resize_type == kSetZero) MatrixBase::SetZero(); + Init(rows, cols, stride_type); + if (resize_type == kSetZero) MatrixBase::SetZero(); } -template -template +template +template void MatrixBase::CopyFromMat(const MatrixBase &M, MatrixTransposeType Trans) { - if (sizeof(Real) == sizeof(OtherReal) && - static_cast(M.Data()) == - static_cast(this->Data())) { - // CopyFromMat called on same data. Nothing to do (except sanity checks). - KALDI_ASSERT(Trans == kNoTrans && M.NumRows() == NumRows() && - M.NumCols() == NumCols() && M.Stride() == Stride()); - return; - } - if (Trans == kNoTrans) { - KALDI_ASSERT(num_rows_ == M.NumRows() && num_cols_ == M.NumCols()); - for (MatrixIndexT i = 0; i < num_rows_; i++) - (*this).Row(i).CopyFromVec(M.Row(i)); - } else { - KALDI_ASSERT(num_cols_ == M.NumRows() && num_rows_ == M.NumCols()); - int32 this_stride = stride_, other_stride = M.Stride(); - Real *this_data = data_; - const OtherReal *other_data = M.Data(); - for (MatrixIndexT i = 0; i < num_rows_; i++) - for (MatrixIndexT j = 0; j < num_cols_; j++) - this_data[i * this_stride + j] = other_data[j * other_stride + i]; - } + if (sizeof(Real) == sizeof(OtherReal) && + static_cast(M.Data()) == + static_cast(this->Data())) { + // CopyFromMat called on same data. Nothing to do (except sanity + // checks). 
+ KALDI_ASSERT(Trans == kNoTrans && M.NumRows() == NumRows() && + M.NumCols() == NumCols() && M.Stride() == Stride()); + return; + } + if (Trans == kNoTrans) { + KALDI_ASSERT(num_rows_ == M.NumRows() && num_cols_ == M.NumCols()); + for (MatrixIndexT i = 0; i < num_rows_; i++) + (*this).Row(i).CopyFromVec(M.Row(i)); + } else { + KALDI_ASSERT(num_cols_ == M.NumRows() && num_rows_ == M.NumCols()); + int32 this_stride = stride_, other_stride = M.Stride(); + Real *this_data = data_; + const OtherReal *other_data = M.Data(); + for (MatrixIndexT i = 0; i < num_rows_; i++) + for (MatrixIndexT j = 0; j < num_cols_; j++) + this_data[i * this_stride + j] = + other_data[j * other_stride + i]; + } } // template instantiations. -template -void MatrixBase::CopyFromMat(const MatrixBase & M, - MatrixTransposeType Trans); -template -void MatrixBase::CopyFromMat(const MatrixBase & M, - MatrixTransposeType Trans); -template -void MatrixBase::CopyFromMat(const MatrixBase & M, - MatrixTransposeType Trans); -template -void MatrixBase::CopyFromMat(const MatrixBase & M, - MatrixTransposeType Trans); - +template void MatrixBase::CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans); +template void MatrixBase::CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans); +template void MatrixBase::CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans); +template void MatrixBase::CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans); + +/* // Specialize the template for CopyFromSp for float, float. template<> template<> @@ -992,103 +1014,100 @@ template void MatrixBase::CopyFromTp(const TpMatrix & M, MatrixTransposeType trans); - -template +*/ +template void MatrixBase::CopyRowsFromVec(const VectorBase &rv) { - if (rv.Dim() == num_rows_*num_cols_) { - if (stride_ == num_cols_) { - // one big copy operation. 
- const Real *rv_data = rv.Data(); - std::memcpy(data_, rv_data, sizeof(Real)*num_rows_*num_cols_); - } else { - const Real *rv_data = rv.Data(); - for (MatrixIndexT r = 0; r < num_rows_; r++) { - Real *row_data = RowData(r); - for (MatrixIndexT c = 0; c < num_cols_; c++) { - row_data[c] = rv_data[c]; + if (rv.Dim() == num_rows_ * num_cols_) { + if (stride_ == num_cols_) { + // one big copy operation. + const Real *rv_data = rv.Data(); + std::memcpy(data_, rv_data, sizeof(Real) * num_rows_ * num_cols_); + } else { + const Real *rv_data = rv.Data(); + for (MatrixIndexT r = 0; r < num_rows_; r++) { + Real *row_data = RowData(r); + for (MatrixIndexT c = 0; c < num_cols_; c++) { + row_data[c] = rv_data[c]; + } + rv_data += num_cols_; + } } - rv_data += num_cols_; - } + } else if (rv.Dim() == num_cols_) { + const Real *rv_data = rv.Data(); + for (MatrixIndexT r = 0; r < num_rows_; r++) + std::memcpy(RowData(r), rv_data, sizeof(Real) * num_cols_); + } else { + KALDI_ERR << "Wrong sized arguments"; } - } else if (rv.Dim() == num_cols_) { - const Real *rv_data = rv.Data(); - for (MatrixIndexT r = 0; r < num_rows_; r++) - std::memcpy(RowData(r), rv_data, sizeof(Real)*num_cols_); - } else { - KALDI_ERR << "Wrong sized arguments"; - } } -template -template +template +template void MatrixBase::CopyRowsFromVec(const VectorBase &rv) { - if (rv.Dim() == num_rows_*num_cols_) { - const OtherReal *rv_data = rv.Data(); - for (MatrixIndexT r = 0; r < num_rows_; r++) { - Real *row_data = RowData(r); - for (MatrixIndexT c = 0; c < num_cols_; c++) { - row_data[c] = static_cast(rv_data[c]); - } - rv_data += num_cols_; + if (rv.Dim() == num_rows_ * num_cols_) { + const OtherReal *rv_data = rv.Data(); + for (MatrixIndexT r = 0; r < num_rows_; r++) { + Real *row_data = RowData(r); + for (MatrixIndexT c = 0; c < num_cols_; c++) { + row_data[c] = static_cast(rv_data[c]); + } + rv_data += num_cols_; + } + } else if (rv.Dim() == num_cols_) { + const OtherReal *rv_data = rv.Data(); + Real 
*first_row_data = RowData(0); + for (MatrixIndexT c = 0; c < num_cols_; c++) + first_row_data[c] = rv_data[c]; + for (MatrixIndexT r = 1; r < num_rows_; r++) + std::memcpy(RowData(r), first_row_data, sizeof(Real) * num_cols_); + } else { + KALDI_ERR << "Wrong sized arguments."; } - } else if (rv.Dim() == num_cols_) { - const OtherReal *rv_data = rv.Data(); - Real *first_row_data = RowData(0); - for (MatrixIndexT c = 0; c < num_cols_; c++) - first_row_data[c] = rv_data[c]; - for (MatrixIndexT r = 1; r < num_rows_; r++) - std::memcpy(RowData(r), first_row_data, sizeof(Real)*num_cols_); - } else { - KALDI_ERR << "Wrong sized arguments."; - } } -template -void MatrixBase::CopyRowsFromVec(const VectorBase &rv); -template -void MatrixBase::CopyRowsFromVec(const VectorBase &rv); +template void MatrixBase::CopyRowsFromVec(const VectorBase &rv); +template void MatrixBase::CopyRowsFromVec(const VectorBase &rv); -template +template void MatrixBase::CopyColsFromVec(const VectorBase &rv) { - if (rv.Dim() == num_rows_*num_cols_) { - const Real *v_inc_data = rv.Data(); - Real *m_inc_data = data_; + if (rv.Dim() == num_rows_ * num_cols_) { + const Real *v_inc_data = rv.Data(); + Real *m_inc_data = data_; - for (MatrixIndexT c = 0; c < num_cols_; c++) { - for (MatrixIndexT r = 0; r < num_rows_; r++) { - m_inc_data[r * stride_] = v_inc_data[r]; - } - v_inc_data += num_rows_; - m_inc_data ++; - } - } else if (rv.Dim() == num_rows_) { - const Real *v_inc_data = rv.Data(); - Real *m_inc_data = data_; - for (MatrixIndexT r = 0; r < num_rows_; r++) { - Real value = *(v_inc_data++); - for (MatrixIndexT c = 0; c < num_cols_; c++) - m_inc_data[c] = value; - m_inc_data += stride_; + for (MatrixIndexT c = 0; c < num_cols_; c++) { + for (MatrixIndexT r = 0; r < num_rows_; r++) { + m_inc_data[r * stride_] = v_inc_data[r]; + } + v_inc_data += num_rows_; + m_inc_data++; + } + } else if (rv.Dim() == num_rows_) { + const Real *v_inc_data = rv.Data(); + Real *m_inc_data = data_; + for (MatrixIndexT 
r = 0; r < num_rows_; r++) { + Real value = *(v_inc_data++); + for (MatrixIndexT c = 0; c < num_cols_; c++) m_inc_data[c] = value; + m_inc_data += stride_; + } + } else { + KALDI_ERR << "Wrong size of arguments."; } - } else { - KALDI_ERR << "Wrong size of arguments."; - } } +template +void MatrixBase::CopyRowFromVec(const VectorBase &rv, + const MatrixIndexT row) { + KALDI_ASSERT(rv.Dim() == num_cols_ && + static_cast(row) < + static_cast(num_rows_)); -template -void MatrixBase::CopyRowFromVec(const VectorBase &rv, const MatrixIndexT row) { - KALDI_ASSERT(rv.Dim() == num_cols_ && - static_cast(row) < - static_cast(num_rows_)); - - const Real *rv_data = rv.Data(); - Real *row_data = RowData(row); + const Real *rv_data = rv.Data(); + Real *row_data = RowData(row); - std::memcpy(row_data, rv_data, num_cols_ * sizeof(Real)); + std::memcpy(row_data, rv_data, num_cols_ * sizeof(Real)); } - +/* template void MatrixBase::CopyDiagFromVec(const VectorBase &rv) { KALDI_ASSERT(rv.Dim() == std::min(num_cols_, num_rows_)); @@ -1096,46 +1115,46 @@ void MatrixBase::CopyDiagFromVec(const VectorBase &rv) { Real *my_data = this->Data(); for (; rv_data != rv_end; rv_data++, my_data += (this->stride_+1)) *my_data = *rv_data; -} +}*/ -template +template void MatrixBase::CopyColFromVec(const VectorBase &rv, const MatrixIndexT col) { - KALDI_ASSERT(rv.Dim() == num_rows_ && - static_cast(col) < - static_cast(num_cols_)); + KALDI_ASSERT(rv.Dim() == num_rows_ && + static_cast(col) < + static_cast(num_cols_)); - const Real *rv_data = rv.Data(); - Real *col_data = data_ + col; + const Real *rv_data = rv.Data(); + Real *col_data = data_ + col; - for (MatrixIndexT r = 0; r < num_rows_; r++) - col_data[r * stride_] = rv_data[r]; + for (MatrixIndexT r = 0; r < num_rows_; r++) + col_data[r * stride_] = rv_data[r]; } - -template +template void Matrix::RemoveRow(MatrixIndexT i) { - KALDI_ASSERT(static_cast(i) < - static_cast(MatrixBase::num_rows_) - && "Access out of matrix"); - for (MatrixIndexT j 
= i + 1; j < MatrixBase::num_rows_; j++) - MatrixBase::Row(j-1).CopyFromVec( MatrixBase::Row(j)); - MatrixBase::num_rows_--; + KALDI_ASSERT( + static_cast(i) < + static_cast(MatrixBase::num_rows_) && + "Access out of matrix"); + for (MatrixIndexT j = i + 1; j < MatrixBase::num_rows_; j++) + MatrixBase::Row(j - 1).CopyFromVec(MatrixBase::Row(j)); + MatrixBase::num_rows_--; } -template +template void Matrix::Destroy() { - // we need to free the data block if it was defined - if (NULL != MatrixBase::data_) - KALDI_MEMALIGN_FREE( MatrixBase::data_); - MatrixBase::data_ = NULL; - MatrixBase::num_rows_ = MatrixBase::num_cols_ - = MatrixBase::stride_ = 0; + // we need to free the data block if it was defined + if (NULL != MatrixBase::data_) + KALDI_MEMALIGN_FREE(MatrixBase::data_); + MatrixBase::data_ = NULL; + MatrixBase::num_rows_ = MatrixBase::num_cols_ = + MatrixBase::stride_ = 0; } - +/* template void MatrixBase::MulElements(const MatrixBase &a) { KALDI_ASSERT(a.NumRows() == num_rows_ && a.NumCols() == num_cols_); @@ -1255,7 +1274,8 @@ template void MatrixBase::GroupPnormDeriv(const MatrixBase &input, const MatrixBase &output, Real power) { - KALDI_ASSERT(input.NumCols() == this->NumCols() && input.NumRows() == this->NumRows()); + KALDI_ASSERT(input.NumCols() == this->NumCols() && input.NumRows() == +this->NumRows()); KALDI_ASSERT(this->NumCols() % output.NumCols() == 0 && this->NumRows() == output.NumRows()); @@ -1325,25 +1345,27 @@ void MatrixBase::MulColsVec(const VectorBase &scale) { } } } +*/ -template +template void MatrixBase::SetZero() { - if (num_cols_ == stride_) - memset(data_, 0, sizeof(Real)*num_rows_*num_cols_); - else - for (MatrixIndexT row = 0; row < num_rows_; row++) - memset(data_ + row*stride_, 0, sizeof(Real)*num_cols_); + if (num_cols_ == stride_) + memset(data_, 0, sizeof(Real) * num_rows_ * num_cols_); + else + for (MatrixIndexT row = 0; row < num_rows_; row++) + memset(data_ + row * stride_, 0, sizeof(Real) * num_cols_); } -template +template 
void MatrixBase::Set(Real value) { - for (MatrixIndexT row = 0; row < num_rows_; row++) { - for (MatrixIndexT col = 0; col < num_cols_; col++) { - (*this)(row, col) = value; + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < num_cols_; col++) { + (*this)(row, col) = value; + } } - } } +/* template void MatrixBase::SetUnit() { SetZero(); @@ -1360,7 +1382,8 @@ void MatrixBase::SetRandn() { for (MatrixIndexT col = 0; col < nc; col += 2) { kaldi::RandGauss2(row_data + col, row_data + col + 1, &rstate); } - if (nc != num_cols_) row_data[nc] = static_cast(kaldi::RandGauss(&rstate)); + if (nc != num_cols_) row_data[nc] = +static_cast(kaldi::RandGauss(&rstate)); } } @@ -1374,305 +1397,307 @@ void MatrixBase::SetRandUniform() { } } } +*/ -template +template void MatrixBase::Write(std::ostream &os, bool binary) const { - if (!os.good()) { - KALDI_ERR << "Failed to write matrix to stream: stream not good"; - } - if (binary) { // Use separate binary and text formats, - // since in binary mode we need to know if it's float or double. - std::string my_token = (sizeof(Real) == 4 ? "FM" : "DM"); - - WriteToken(os, binary, my_token); - { - int32 rows = this->num_rows_; // make the size 32-bit on disk. - int32 cols = this->num_cols_; - KALDI_ASSERT(this->num_rows_ == (MatrixIndexT) rows); - KALDI_ASSERT(this->num_cols_ == (MatrixIndexT) cols); - WriteBasicType(os, binary, rows); - WriteBasicType(os, binary, cols); - } - if (Stride() == NumCols()) - os.write(reinterpret_cast (Data()), sizeof(Real) - * static_cast(num_rows_) * static_cast(num_cols_)); - else - for (MatrixIndexT i = 0; i < num_rows_; i++) - os.write(reinterpret_cast (RowData(i)), sizeof(Real) - * num_cols_); if (!os.good()) { - KALDI_ERR << "Failed to write matrix to stream"; - } - } else { // text mode. 
- if (num_cols_ == 0) { - os << " [ ]\n"; - } else { - os << " ["; - for (MatrixIndexT i = 0; i < num_rows_; i++) { - os << "\n "; - for (MatrixIndexT j = 0; j < num_cols_; j++) - os << (*this)(i, j) << " "; - } - os << "]\n"; + KALDI_ERR << "Failed to write matrix to stream: stream not good"; + } + if (binary) { // Use separate binary and text formats, + // since in binary mode we need to know if it's float or double. + std::string my_token = (sizeof(Real) == 4 ? "FM" : "DM"); + + WriteToken(os, binary, my_token); + { + int32 rows = this->num_rows_; // make the size 32-bit on disk. + int32 cols = this->num_cols_; + KALDI_ASSERT(this->num_rows_ == (MatrixIndexT)rows); + KALDI_ASSERT(this->num_cols_ == (MatrixIndexT)cols); + WriteBasicType(os, binary, rows); + WriteBasicType(os, binary, cols); + } + if (Stride() == NumCols()) + os.write(reinterpret_cast(Data()), + sizeof(Real) * static_cast(num_rows_) * + static_cast(num_cols_)); + else + for (MatrixIndexT i = 0; i < num_rows_; i++) + os.write(reinterpret_cast(RowData(i)), + sizeof(Real) * num_cols_); + if (!os.good()) { + KALDI_ERR << "Failed to write matrix to stream"; + } + } else { // text mode. + if (num_cols_ == 0) { + os << " [ ]\n"; + } else { + os << " ["; + for (MatrixIndexT i = 0; i < num_rows_; i++) { + os << "\n "; + for (MatrixIndexT j = 0; j < num_cols_; j++) + os << (*this)(i, j) << " "; + } + os << "]\n"; + } } - } -} - - -template -void MatrixBase::Read(std::istream & is, bool binary, bool add) { - if (add) { - Matrix tmp(num_rows_, num_cols_); - tmp.Read(is, binary, false); // read without adding. - if (tmp.num_rows_ != this->num_rows_ || tmp.num_cols_ != this->num_cols_) - KALDI_ERR << "MatrixBase::Read, size mismatch " - << this->num_rows_ << ", " << this->num_cols_ - << " vs. " << tmp.num_rows_ << ", " << tmp.num_cols_; - this->AddMat(1.0, tmp); - return; - } - // now assume add == false. 
- - // In order to avoid rewriting this, we just declare a Matrix and - // use it to read the data, then copy. - Matrix tmp; - tmp.Read(is, binary, false); - if (tmp.NumRows() != NumRows() || tmp.NumCols() != NumCols()) { - KALDI_ERR << "MatrixBase::Read, size mismatch " - << NumRows() << " x " << NumCols() << " versus " - << tmp.NumRows() << " x " << tmp.NumCols(); - } - CopyFromMat(tmp); } -template -void Matrix::Read(std::istream & is, bool binary, bool add) { - if (add) { +template +void MatrixBase::Read(std::istream &is, bool binary) { + // In order to avoid rewriting this, we just declare a Matrix and + // use it to read the data, then copy. Matrix tmp; - tmp.Read(is, binary, false); // read without adding. - if (this->num_rows_ == 0) this->Resize(tmp.num_rows_, tmp.num_cols_); - else { - if (this->num_rows_ != tmp.num_rows_ || this->num_cols_ != tmp.num_cols_) { - if (tmp.num_rows_ == 0) return; // do nothing in this case. - else KALDI_ERR << "Matrix::Read, size mismatch " - << this->num_rows_ << ", " << this->num_cols_ - << " vs. " << tmp.num_rows_ << ", " << tmp.num_cols_; - } + tmp.Read(is, binary); + if (tmp.NumRows() != NumRows() || tmp.NumCols() != NumCols()) { + KALDI_ERR << "MatrixBase::Read, size mismatch " << NumRows() + << " x " << NumCols() << " versus " << tmp.NumRows() << " x " + << tmp.NumCols(); } - this->AddMat(1.0, tmp); - return; - } + CopyFromMat(tmp); +} - // now assume add == false. - MatrixIndexT pos_at_start = is.tellg(); - std::ostringstream specific_error; - if (binary) { // Read in binary mode. - int peekval = Peek(is, binary); - if (peekval == 'C') { - // This code enables us to read CompressedMatrix as a regular matrix. - CompressedMatrix compressed_mat; - compressed_mat.Read(is, binary); // at this point, add == false. - this->Resize(compressed_mat.NumRows(), compressed_mat.NumCols()); - compressed_mat.CopyToMat(this); - return; - } - const char *my_token = (sizeof(Real) == 4 ? 
"FM" : "DM"); - char other_token_start = (sizeof(Real) == 4 ? 'D' : 'F'); - if (peekval == other_token_start) { // need to instantiate the other type to read it. - typedef typename OtherReal::Real OtherType; // if Real == float, OtherType == double, and vice versa. - Matrix other(this->num_rows_, this->num_cols_); - other.Read(is, binary, false); // add is false at this point anyway. - this->Resize(other.NumRows(), other.NumCols()); - this->CopyFromMat(other); - return; - } - std::string token; - ReadToken(is, binary, &token); - if (token != my_token) { - if (token.length() > 20) token = token.substr(0, 17) + "..."; - specific_error << ": Expected token " << my_token << ", got " << token; - goto bad; - } - int32 rows, cols; - ReadBasicType(is, binary, &rows); // throws on error. - ReadBasicType(is, binary, &cols); // throws on error. - if ((MatrixIndexT)rows != this->num_rows_ || (MatrixIndexT)cols != this->num_cols_) { - this->Resize(rows, cols); - } - if (this->Stride() == this->NumCols() && rows*cols!=0) { - is.read(reinterpret_cast(this->Data()), - sizeof(Real)*rows*cols); - if (is.fail()) goto bad; - } else { - for (MatrixIndexT i = 0; i < (MatrixIndexT)rows; i++) { - is.read(reinterpret_cast(this->RowData(i)), sizeof(Real)*cols); - if (is.fail()) goto bad; - } - } - if (is.eof()) return; - if (is.fail()) goto bad; - return; - } else { // Text mode. - std::string str; - is >> str; // get a token - if (is.fail()) { specific_error << ": Expected \"[\", got EOF"; goto bad; } - // if ((str.compare("DM") == 0) || (str.compare("FM") == 0)) { // Back compatibility. - // is >> str; // get #rows - // is >> str; // get #cols - // is >> str; // get "[" - // } - if (str == "[]") { Resize(0, 0); return; } // Be tolerant of variants. - else if (str != "[") { - if (str.length() > 20) str = str.substr(0, 17) + "..."; - specific_error << ": Expected \"[\", got \"" << str << '"'; - goto bad; - } - // At this point, we have read "[". 
- std::vector* > data; - std::vector *cur_row = new std::vector; - while (1) { - int i = is.peek(); - if (i == -1) { specific_error << "Got EOF while reading matrix data"; goto cleanup; } - else if (static_cast(i) == ']') { // Finished reading matrix. - is.get(); // eat the "]". - i = is.peek(); - if (static_cast(i) == '\r') { - is.get(); - is.get(); // get \r\n (must eat what we wrote) - } else if (static_cast(i) == '\n') { is.get(); } // get \n (must eat what we wrote) - if (is.fail()) { - KALDI_WARN << "After end of matrix data, read error."; - // we got the data we needed, so just warn for this error. +template +void Matrix::Read(std::istream &is, bool binary) { + // now assume add == false. + MatrixIndexT pos_at_start = is.tellg(); + std::ostringstream specific_error; + + if (binary) { // Read in binary mode. + int peekval = Peek(is, binary); + if (peekval == 'C') { + // This code enables us to read CompressedMatrix as a regular + // matrix. + // CompressedMatrix compressed_mat; + // compressed_mat.Read(is, binary); // at this point, add == false. + // this->Resize(compressed_mat.NumRows(), compressed_mat.NumCols()); + // compressed_mat.CopyToMat(this); + return; } - // Now process the data. - if (!cur_row->empty()) data.push_back(cur_row); - else delete(cur_row); - cur_row = NULL; - if (data.empty()) { this->Resize(0, 0); return; } - else { - int32 num_rows = data.size(), num_cols = data[0]->size(); - this->Resize(num_rows, num_cols); - for (int32 i = 0; i < num_rows; i++) { - if (static_cast(data[i]->size()) != num_cols) { - specific_error << "Matrix has inconsistent #cols: " << num_cols - << " vs." << data[i]->size() << " (processing row" - << i << ")"; - goto cleanup; + const char *my_token = (sizeof(Real) == 4 ? "FM" : "DM"); + char other_token_start = (sizeof(Real) == 4 ? 'D' : 'F'); + if (peekval == other_token_start) { // need to instantiate the other + // type to read it. 
+ typedef typename OtherReal::Real OtherType; // if Real == + // float, + // OtherType == + // double, and + // vice versa. + Matrix other(this->num_rows_, this->num_cols_); + other.Read(is, binary); // add is false at this point anyway. + this->Resize(other.NumRows(), other.NumCols()); + this->CopyFromMat(other); + return; + } + std::string token; + ReadToken(is, binary, &token); + if (token != my_token) { + if (token.length() > 20) token = token.substr(0, 17) + "..."; + specific_error << ": Expected token " << my_token << ", got " + << token; + goto bad; + } + int32 rows, cols; + ReadBasicType(is, binary, &rows); // throws on error. + ReadBasicType(is, binary, &cols); // throws on error. + if ((MatrixIndexT)rows != this->num_rows_ || + (MatrixIndexT)cols != this->num_cols_) { + this->Resize(rows, cols); + } + if (this->Stride() == this->NumCols() && rows * cols != 0) { + is.read(reinterpret_cast(this->Data()), + sizeof(Real) * rows * cols); + if (is.fail()) goto bad; + } else { + for (MatrixIndexT i = 0; i < (MatrixIndexT)rows; i++) { + is.read(reinterpret_cast(this->RowData(i)), + sizeof(Real) * cols); + if (is.fail()) goto bad; } - for (int32 j = 0; j < num_cols; j++) - (*this)(i, j) = (*(data[i]))[j]; - delete data[i]; - data[i] = NULL; - } } + if (is.eof()) return; + if (is.fail()) goto bad; return; - } else if (static_cast(i) == '\n' || static_cast(i) == ';') { - // End of matrix row. - is.get(); - if (cur_row->size() != 0) { - data.push_back(cur_row); - cur_row = new std::vector; - cur_row->reserve(data.back()->size()); - } - } else if ( (i >= '0' && i <= '9') || i == '-' ) { // A number... - Real r; - is >> r; + } else { // Text mode. + std::string str; + is >> str; // get a token if (is.fail()) { - specific_error << "Stream failure/EOF while reading matrix data."; - goto cleanup; + specific_error << ": Expected \"[\", got EOF"; + goto bad; } - cur_row->push_back(r); - } else if (isspace(i)) { - is.get(); // eat the space and do nothing. 
- } else { // NaN or inf or error. - std::string str; - is >> str; - if (!KALDI_STRCASECMP(str.c_str(), "inf") || - !KALDI_STRCASECMP(str.c_str(), "infinity")) { - cur_row->push_back(std::numeric_limits::infinity()); - KALDI_WARN << "Reading infinite value into matrix."; - } else if (!KALDI_STRCASECMP(str.c_str(), "nan")) { - cur_row->push_back(std::numeric_limits::quiet_NaN()); - KALDI_WARN << "Reading NaN value into matrix."; - } else { - if (str.length() > 20) str = str.substr(0, 17) + "..."; - specific_error << "Expecting numeric matrix data, got " << str; - goto cleanup; + // if ((str.compare("DM") == 0) || (str.compare("FM") == 0)) { // Back + // compatibility. + // is >> str; // get #rows + // is >> str; // get #cols + // is >> str; // get "[" + // } + if (str == "[]") { + Resize(0, 0); + return; + } // Be tolerant of variants. + else if (str != "[") { + if (str.length() > 20) str = str.substr(0, 17) + "..."; + specific_error << ": Expected \"[\", got \"" << str << '"'; + goto bad; + } + // At this point, we have read "[". + std::vector *> data; + std::vector *cur_row = new std::vector; + while (1) { + int i = is.peek(); + if (i == -1) { + specific_error << "Got EOF while reading matrix data"; + goto cleanup; + } else if (static_cast(i) == + ']') { // Finished reading matrix. + is.get(); // eat the "]". + i = is.peek(); + if (static_cast(i) == '\r') { + is.get(); + is.get(); // get \r\n (must eat what we wrote) + } else if (static_cast(i) == '\n') { + is.get(); + } // get \n (must eat what we wrote) + if (is.fail()) { + KALDI_WARN << "After end of matrix data, read error."; + // we got the data we needed, so just warn for this error. + } + // Now process the data. 
+ if (!cur_row->empty()) + data.push_back(cur_row); + else + delete (cur_row); + cur_row = NULL; + if (data.empty()) { + this->Resize(0, 0); + return; + } else { + int32 num_rows = data.size(), num_cols = data[0]->size(); + this->Resize(num_rows, num_cols); + for (int32 i = 0; i < num_rows; i++) { + if (static_cast(data[i]->size()) != num_cols) { + specific_error + << "Matrix has inconsistent #cols: " << num_cols + << " vs." << data[i]->size() + << " (processing row" << i << ")"; + goto cleanup; + } + for (int32 j = 0; j < num_cols; j++) + (*this)(i, j) = (*(data[i]))[j]; + delete data[i]; + data[i] = NULL; + } + } + return; + } else if (static_cast(i) == '\n' || + static_cast(i) == ';') { + // End of matrix row. + is.get(); + if (cur_row->size() != 0) { + data.push_back(cur_row); + cur_row = new std::vector; + cur_row->reserve(data.back()->size()); + } + } else if ((i >= '0' && i <= '9') || i == '-') { // A number... + Real r; + is >> r; + if (is.fail()) { + specific_error + << "Stream failure/EOF while reading matrix data."; + goto cleanup; + } + cur_row->push_back(r); + } else if (isspace(i)) { + is.get(); // eat the space and do nothing. + } else { // NaN or inf or error. + std::string str; + is >> str; + if (!KALDI_STRCASECMP(str.c_str(), "inf") || + !KALDI_STRCASECMP(str.c_str(), "infinity")) { + cur_row->push_back(std::numeric_limits::infinity()); + KALDI_WARN << "Reading infinite value into matrix."; + } else if (!KALDI_STRCASECMP(str.c_str(), "nan")) { + cur_row->push_back(std::numeric_limits::quiet_NaN()); + KALDI_WARN << "Reading NaN value into matrix."; + } else { + if (str.length() > 20) str = str.substr(0, 17) + "..."; + specific_error << "Expecting numeric matrix data, got " + << str; + goto cleanup; + } + } } - } - } // Note, we never leave the while () loop before this // line (we return from it.) - cleanup: // We only reach here in case of error in the while loop above. 
- if(cur_row != NULL) - delete cur_row; - for (size_t i = 0; i < data.size(); i++) - if(data[i] != NULL) - delete data[i]; - // and then go on to "bad" below, where we print error. - } + cleanup: // We only reach here in case of error in the while loop above. + if (cur_row != NULL) delete cur_row; + for (size_t i = 0; i < data.size(); i++) + if (data[i] != NULL) delete data[i]; + // and then go on to "bad" below, where we print error. + } bad: - KALDI_ERR << "Failed to read matrix from stream. " << specific_error.str() - << " File position at start is " - << pos_at_start << ", currently " << is.tellg(); + KALDI_ERR << "Failed to read matrix from stream. " << specific_error.str() + << " File position at start is " << pos_at_start << ", currently " + << is.tellg(); } // Constructor... note that this is not const-safe as it would // be quite complicated to implement a "const SubMatrix" class that // would not allow its contents to be changed. -template +template SubMatrix::SubMatrix(const MatrixBase &M, const MatrixIndexT ro, const MatrixIndexT r, const MatrixIndexT co, const MatrixIndexT c) { - if (r == 0 || c == 0) { - // we support the empty sub-matrix as a special case. - KALDI_ASSERT(c == 0 && r == 0); - this->data_ = NULL; - this->num_cols_ = 0; - this->num_rows_ = 0; - this->stride_ = 0; - return; - } - KALDI_ASSERT(static_cast(ro) < - static_cast(M.num_rows_) && - static_cast(co) < - static_cast(M.num_cols_) && - static_cast(r) <= - static_cast(M.num_rows_ - ro) && - static_cast(c) <= - static_cast(M.num_cols_ - co)); - // point to the beginning of window - MatrixBase::num_rows_ = r; - MatrixBase::num_cols_ = c; - MatrixBase::stride_ = M.Stride(); - MatrixBase::data_ = M.Data_workaround() + - static_cast(co) + - static_cast(ro) * static_cast(M.Stride()); + if (r == 0 || c == 0) { + // we support the empty sub-matrix as a special case. 
+ KALDI_ASSERT(c == 0 && r == 0); + this->data_ = NULL; + this->num_cols_ = 0; + this->num_rows_ = 0; + this->stride_ = 0; + return; + } + KALDI_ASSERT(static_cast(ro) < + static_cast(M.num_rows_) && + static_cast(co) < + static_cast(M.num_cols_) && + static_cast(r) <= + static_cast(M.num_rows_ - ro) && + static_cast(c) <= + static_cast(M.num_cols_ - co)); + // point to the begining of window + MatrixBase::num_rows_ = r; + MatrixBase::num_cols_ = c; + MatrixBase::stride_ = M.Stride(); + MatrixBase::data_ = + M.Data_workaround() + static_cast(co) + + static_cast(ro) * static_cast(M.Stride()); } -template +template SubMatrix::SubMatrix(Real *data, MatrixIndexT num_rows, MatrixIndexT num_cols, - MatrixIndexT stride): - MatrixBase(data, num_cols, num_rows, stride) { // caution: reversed order! - if (data == NULL) { - KALDI_ASSERT(num_rows * num_cols == 0); - this->num_rows_ = 0; - this->num_cols_ = 0; - this->stride_ = 0; - } else { - KALDI_ASSERT(this->stride_ >= this->num_cols_); - } + MatrixIndexT stride) + : MatrixBase( + data, num_cols, num_rows, stride) { // caution: reversed order! + if (data == NULL) { + KALDI_ASSERT(num_rows * num_cols == 0); + this->num_rows_ = 0; + this->num_cols_ = 0; + this->stride_ = 0; + } else { + KALDI_ASSERT(this->stride_ >= this->num_cols_); + } } - +/* template void MatrixBase::Add(const Real alpha) { Real *data = data_; @@ -1697,9 +1722,11 @@ Real MatrixBase::Cond() const { KALDI_ASSERT(num_rows_ > 0&&num_cols_ > 0); Vector singular_values(std::min(num_rows_, num_cols_)); Svd(&singular_values); // Get singular values... - Real min = singular_values(0), max = singular_values(0); // both absolute values... + Real min = singular_values(0), max = singular_values(0); // both absolute +values... 
for (MatrixIndexT i = 1;i < singular_values.Dim();i++) { - min = std::min((Real)std::abs(singular_values(i)), min); max = std::max((Real)std::abs(singular_values(i)), max); + min = std::min((Real)std::abs(singular_values(i)), min); max = +std::max((Real)std::abs(singular_values(i)), max); } if (min > 0) return max/min; else return std::numeric_limits::infinity(); @@ -1709,7 +1736,8 @@ template Real MatrixBase::Trace(bool check_square) const { KALDI_ASSERT(!check_square || num_rows_ == num_cols_); Real ans = 0.0; - for (MatrixIndexT r = 0;r < std::min(num_rows_, num_cols_);r++) ans += data_ [r + stride_*r]; + for (MatrixIndexT r = 0;r < std::min(num_rows_, num_cols_);r++) ans += data_ +[r + stride_*r]; return ans; } @@ -1739,22 +1767,29 @@ Real MatrixBase::Min() const { template void MatrixBase::AddMatMatMat(Real alpha, - const MatrixBase &A, MatrixTransposeType transA, - const MatrixBase &B, MatrixTransposeType transB, - const MatrixBase &C, MatrixTransposeType transC, + const MatrixBase &A, +MatrixTransposeType transA, + const MatrixBase &B, +MatrixTransposeType transB, + const MatrixBase &C, +MatrixTransposeType transC, Real beta) { - // Note on time taken with different orders of computation. Assume not transposed in this / - // discussion. Firstly, normalize expressions using A.NumCols == B.NumRows and B.NumCols == C.NumRows, prefer + // Note on time taken with different orders of computation. Assume not +transposed in this / + // discussion. Firstly, normalize expressions using A.NumCols == B.NumRows and +B.NumCols == C.NumRows, prefer // rows where there is a choice. // time taken for (AB) is: A.NumRows*B.NumRows*C.Rows // time taken for (AB)C is A.NumRows*C.NumRows*C.Cols - // so this order is A.NumRows*B.NumRows*C.NumRows + A.NumRows*C.NumRows*C.NumCols. + // so this order is A.NumRows*B.NumRows*C.NumRows + +A.NumRows*C.NumRows*C.NumCols. 
// time taken for (BC) is: B.NumRows*C.NumRows*C.Cols // time taken for A(BC) is: A.NumRows*B.NumRows*C.Cols // so this order is B.NumRows*C.NumRows*C.NumCols + A.NumRows*B.NumRows*C.Cols - MatrixIndexT ARows = A.num_rows_, ACols = A.num_cols_, BRows = B.num_rows_, BCols = B.num_cols_, + MatrixIndexT ARows = A.num_rows_, ACols = A.num_cols_, BRows = B.num_rows_, +BCols = B.num_cols_, CRows = C.num_rows_, CCols = C.num_cols_; if (transA == kTrans) std::swap(ARows, ACols); if (transB == kTrans) std::swap(BRows, BCols); @@ -1778,58 +1813,71 @@ void MatrixBase::AddMatMatMat(Real alpha, template -void MatrixBase::DestructiveSvd(VectorBase *s, MatrixBase *U, MatrixBase *Vt) { +void MatrixBase::DestructiveSvd(VectorBase *s, MatrixBase *U, +MatrixBase *Vt) { // Svd, *this = U*diag(s)*Vt. // With (*this).num_rows_ == m, (*this).num_cols_ == n, - // Support only skinny Svd with m>=n (NumRows>=NumCols), and zero sizes for U and Vt mean + // Support only skinny Svd with m>=n (NumRows>=NumCols), and zero sizes for U +and Vt mean // we do not want that output. We expect that s.Dim() == m, // U is either 0 by 0 or m by n, and rv is either 0 by 0 or n by n. // Throws exception on error. - KALDI_ASSERT(num_rows_>=num_cols_ && "Svd requires that #rows by >= #cols."); // For compatibility with JAMA code. + KALDI_ASSERT(num_rows_>=num_cols_ && "Svd requires that #rows by >= #cols."); +// For compatibility with JAMA code. KALDI_ASSERT(s->Dim() == num_cols_); // s should be the smaller dim. - KALDI_ASSERT(U == NULL || (U->num_rows_ == num_rows_&&U->num_cols_ == num_cols_)); - KALDI_ASSERT(Vt == NULL || (Vt->num_rows_ == num_cols_&&Vt->num_cols_ == num_cols_)); + KALDI_ASSERT(U == NULL || (U->num_rows_ == num_rows_&&U->num_cols_ == +num_cols_)); + KALDI_ASSERT(Vt == NULL || (Vt->num_rows_ == num_cols_&&Vt->num_cols_ == +num_cols_)); Real prescale = 1.0; - if ( std::abs((*this)(0, 0) ) < 1.0e-30) { // Very tiny value... can cause problems in Svd. 
+ if ( std::abs((*this)(0, 0) ) < 1.0e-30) { // Very tiny value... can cause +problems in Svd. Real max_elem = LargestAbsElem(); if (max_elem != 0) { prescale = 1.0 / max_elem; - if (std::abs(prescale) == std::numeric_limits::infinity()) { prescale = 1.0e+40; } + if (std::abs(prescale) == std::numeric_limits::infinity()) { +prescale = 1.0e+40; } (*this).Scale(prescale); } } #if !defined(HAVE_ATLAS) && !defined(USE_KALDI_SVD) - // "S" == skinny Svd (only one we support because of compatibility with Jama one which is only skinny), + // "S" == skinny Svd (only one we support because of compatibility with Jama +one which is only skinny), // "N"== no eigenvectors wanted. LapackGesvd(s, U, Vt); #else /* if (num_rows_ > 1 && num_cols_ > 1 && (*this)(0, 0) == (*this)(1, 1) - && Max() == Min() && (*this)(0, 0) != 0.0) { // special case that JamaSvd sometimes crashes on. - KALDI_WARN << "Jama SVD crashes on this type of matrix, perturbing it to prevent crash."; + && Max() == Min() && (*this)(0, 0) != 0.0) { // special case that JamaSvd +sometimes crashes on. + KALDI_WARN << "Jama SVD crashes on this type of matrix, perturbing it to +prevent crash."; for(int32 i = 0; i < NumRows(); i++) (*this)(i, i) *= 1.00001; }*/ - bool ans = JamaSvd(s, U, Vt); - if (Vt != NULL) Vt->Transpose(); // possibly to do: change this and also the transpose inside the JamaSvd routine. note, Vt is square. - if (!ans) { - KALDI_ERR << "Error doing Svd"; // This one will be caught. - } -#endif - if (prescale != 1.0) s->Scale(1.0/prescale); -} - -template -void MatrixBase::Svd(VectorBase *s, MatrixBase *U, MatrixBase *Vt) const { +// bool ans = JamaSvd(s, U, Vt); +// if (Vt != NULL) Vt->Transpose(); // possibly to do: change this and also the +// transpose inside the JamaSvd routine. note, Vt is square. +// if (!ans) { +// KALDI_ERR << "Error doing Svd"; // This one will be caught. 
+//} +//#endif +// if (prescale != 1.0) s->Scale(1.0/prescale); +//} +/* +template +void MatrixBase::Svd(VectorBase *s, MatrixBase *U, +MatrixBase *Vt) const { try { if (num_rows_ >= num_cols_) { Matrix tmp(*this); tmp.DestructiveSvd(s, U, Vt); } else { Matrix tmp(*this, kTrans); // transpose of *this. - // rVt will have different dim so cannot transpose in-place --> use a temp matrix. + // rVt will have different dim so cannot transpose in-place --> use a temp +matrix. Matrix Vt_Trans(Vt ? Vt->num_cols_ : 0, Vt ? Vt->num_rows_ : 0); // U will be transpose tmp.DestructiveSvd(s, Vt ? &Vt_Trans : NULL, U); @@ -1838,7 +1886,8 @@ void MatrixBase::Svd(VectorBase *s, MatrixBase *U, MatrixBase< } } catch (...) { KALDI_ERR << "Error doing Svd (did not converge), first part of matrix is\n" - << SubMatrix(*this, 0, std::min((MatrixIndexT)10, num_rows_), + << SubMatrix(*this, 0, std::min((MatrixIndexT)10, +num_rows_), 0, std::min((MatrixIndexT)10, num_cols_)) << ", min and max are: " << Min() << ", " << Max(); } @@ -1851,7 +1900,8 @@ bool MatrixBase::IsSymmetric(Real cutoff) const { Real bad_sum = 0.0, good_sum = 0.0; for (MatrixIndexT i = 0;i < R;i++) { for (MatrixIndexT j = 0;j < i;j++) { - Real a = (*this)(i, j), b = (*this)(j, i), avg = 0.5*(a+b), diff = 0.5*(a-b); + Real a = (*this)(i, j), b = (*this)(j, i), avg = 0.5*(a+b), diff = +0.5*(a-b); good_sum += std::abs(avg); bad_sum += std::abs(diff); } good_sum += std::abs((*this)(i, i)); @@ -1892,7 +1942,8 @@ bool MatrixBase::IsUnit(Real cutoff) const { Real bad_max = 0.0; for (MatrixIndexT i = 0; i < R;i++) for (MatrixIndexT j = 0; j < C;j++) - bad_max = std::max(bad_max, static_cast(std::abs( (*this)(i, j) - (i == j?1.0:0.0)))); + bad_max = std::max(bad_max, static_cast(std::abs( (*this)(i, j) - (i +== j?1.0:0.0)))); return (bad_max <= cutoff); } @@ -1912,7 +1963,8 @@ Real MatrixBase::FrobeniusNorm() const{ } template -bool MatrixBase::ApproxEqual(const MatrixBase &other, float tol) const { +bool 
MatrixBase::ApproxEqual(const MatrixBase &other, float tol) +const { if (num_rows_ != other.num_rows_ || num_cols_ != other.num_cols_) KALDI_ERR << "ApproxEqual: size mismatch."; Matrix tmp(*this); @@ -1985,27 +2037,35 @@ void MatrixBase::OrthogonalizeRows() { } -// Uses Svd to compute the eigenvalue decomposition of a symmetric positive semidefinite +// Uses Svd to compute the eigenvalue decomposition of a symmetric positive +semidefinite // matrix: -// (*this) = rU * diag(rs) * rU^T, with rU an orthogonal matrix so rU^{-1} = rU^T. -// Does this by computing svd (*this) = U diag(rs) V^T ... answer is just U diag(rs) U^T. -// Throws exception if this failed to within supplied precision (typically because *this was not +// (*this) = rU * diag(rs) * rU^T, with rU an orthogonal matrix so rU^{-1} = +rU^T. +// Does this by computing svd (*this) = U diag(rs) V^T ... answer is just U +diag(rs) U^T. +// Throws exception if this failed to within supplied precision (typically +because *this was not // symmetric positive definite). template -void MatrixBase::SymPosSemiDefEig(VectorBase *rs, MatrixBase *rU, Real check_thresh) // e.g. check_thresh = 0.001 +void MatrixBase::SymPosSemiDefEig(VectorBase *rs, MatrixBase +*rU, Real check_thresh) // e.g. check_thresh = 0.001 { const MatrixIndexT D = num_rows_; KALDI_ASSERT(num_rows_ == num_cols_); - KALDI_ASSERT(IsSymmetric() && "SymPosSemiDefEig: expecting input to be symmetrical."); + KALDI_ASSERT(IsSymmetric() && "SymPosSemiDefEig: expecting input to be +symmetrical."); KALDI_ASSERT(rU->num_rows_ == D && rU->num_cols_ == D && rs->Dim() == D); Matrix Vt(D, D); Svd(rs, rU, &Vt); - // First just zero any singular values if the column of U and V do not have +ve dot product-- - // this may mean we have small negative eigenvalues, and if we zero them the result will be closer to correct. 
+ // First just zero any singular values if the column of U and V do not have ++ve dot product-- + // this may mean we have small negative eigenvalues, and if we zero them the +result will be closer to correct. for (MatrixIndexT i = 0;i < D;i++) { Real sum = 0.0; for (MatrixIndexT j = 0;j < D;j++) sum += (*rU)(j, i) * Vt(i, j); @@ -2024,9 +2084,12 @@ void MatrixBase::SymPosSemiDefEig(VectorBase *rs, MatrixBase * if (!(old_norm == 0 && new_norm == 0)) { float diff_norm = tmpThisFull.FrobeniusNorm(); - if (std::abs(new_norm-old_norm) > old_norm*check_thresh || diff_norm > old_norm*check_thresh) { - KALDI_WARN << "SymPosSemiDefEig seems to have failed " << diff_norm << " !<< " - << check_thresh << "*" << old_norm << ", maybe matrix was not " + if (std::abs(new_norm-old_norm) > old_norm*check_thresh || diff_norm > +old_norm*check_thresh) { + KALDI_WARN << "SymPosSemiDefEig seems to have failed " << diff_norm << " +!<< " + << check_thresh << "*" << old_norm << ", maybe matrix was not +" << "positive semi definite. Continuing anyway."; } } @@ -2038,7 +2101,8 @@ template Real MatrixBase::LogDet(Real *det_sign) const { Real log_det; Matrix tmp(*this); - tmp.Invert(&log_det, det_sign, false); // false== output not needed (saves some computation). + tmp.Invert(&log_det, det_sign, false); // false== output not needed (saves +some computation). 
return log_det; } @@ -2052,29 +2116,29 @@ void MatrixBase::InvertDouble(Real *log_det, Real *det_sign, if (log_det) *log_det = log_det_tmp; if (det_sign) *det_sign = det_sign_tmp; } +*/ -template -void MatrixBase::CopyFromMat(const CompressedMatrix &mat) { - mat.CopyToMat(this); -} - -template -Matrix::Matrix(const CompressedMatrix &M): MatrixBase() { - Resize(M.NumRows(), M.NumCols(), kUndefined); - M.CopyToMat(this); -} +// template +// void MatrixBase::CopyFromMat(const CompressedMatrix &mat) { +// mat.CopyToMat(this); +//} +// template +// Matrix::Matrix(const CompressedMatrix &M): MatrixBase() { +// Resize(M.NumRows(), M.NumCols(), kUndefined); +// M.CopyToMat(this); +//} -template +template void MatrixBase::InvertElements() { - for (MatrixIndexT r = 0; r < num_rows_; r++) { - for (MatrixIndexT c = 0; c < num_cols_; c++) { - (*this)(r, c) = static_cast(1.0 / (*this)(r, c)); + for (MatrixIndexT r = 0; r < num_rows_; r++) { + for (MatrixIndexT c = 0; c < num_cols_; c++) { + (*this)(r, c) = static_cast(1.0 / (*this)(r, c)); + } } - } } - +/* template void MatrixBase::Transpose() { KALDI_ASSERT(num_rows_ == num_cols_); @@ -2139,7 +2203,8 @@ void MatrixBase::Pow(const MatrixBase &src, Real power) { } template -void MatrixBase::PowAbs(const MatrixBase &src, Real power, bool include_sign) { +void MatrixBase::PowAbs(const MatrixBase &src, Real power, bool +include_sign) { KALDI_ASSERT(SameDim(*this, src)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; Real *row_data = data_; @@ -2148,9 +2213,9 @@ void MatrixBase::PowAbs(const MatrixBase &src, Real power, bool incl row++,row_data += stride_, src_row_data += src.stride_) { for (MatrixIndexT col = 0; col < num_cols; col ++) { if (include_sign == true && src_row_data[col] < 0) { - row_data[col] = -pow(std::abs(src_row_data[col]), power); + row_data[col] = -pow(std::abs(src_row_data[col]), power); } else { - row_data[col] = pow(std::abs(src_row_data[col]), power); + row_data[col] = 
pow(std::abs(src_row_data[col]), power); } } } @@ -2165,7 +2230,8 @@ void MatrixBase::Floor(const MatrixBase &src, Real floor_val) { for (MatrixIndexT row = 0; row < num_rows; row++,row_data += stride_, src_row_data += src.stride_) { for (MatrixIndexT col = 0; col < num_cols; col++) - row_data[col] = (src_row_data[col] < floor_val ? floor_val : src_row_data[col]); + row_data[col] = (src_row_data[col] < floor_val ? floor_val : +src_row_data[col]); } } @@ -2178,7 +2244,8 @@ void MatrixBase::Ceiling(const MatrixBase &src, Real ceiling_val) { for (MatrixIndexT row = 0; row < num_rows; row++,row_data += stride_, src_row_data += src.stride_) { for (MatrixIndexT col = 0; col < num_cols; col++) - row_data[col] = (src_row_data[col] > ceiling_val ? ceiling_val : src_row_data[col]); + row_data[col] = (src_row_data[col] > ceiling_val ? ceiling_val : +src_row_data[col]); } } @@ -2204,12 +2271,14 @@ void MatrixBase::ExpSpecial(const MatrixBase &src) { for (MatrixIndexT row = 0; row < num_rows; row++,row_data += stride_, src_row_data += src.stride_) { for (MatrixIndexT col = 0; col < num_cols; col++) - row_data[col] = (src_row_data[col] < Real(0) ? kaldi::Exp(src_row_data[col]) : (src_row_data[col] + Real(1))); + row_data[col] = (src_row_data[col] < Real(0) ? 
+kaldi::Exp(src_row_data[col]) : (src_row_data[col] + Real(1))); } } template -void MatrixBase::ExpLimited(const MatrixBase &src, Real lower_limit, Real upper_limit) { +void MatrixBase::ExpLimited(const MatrixBase &src, Real lower_limit, +Real upper_limit) { KALDI_ASSERT(SameDim(*this, src)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; Real *row_data = data_; @@ -2219,11 +2288,11 @@ void MatrixBase::ExpLimited(const MatrixBase &src, Real lower_limit, for (MatrixIndexT col = 0; col < num_cols; col++) { const Real x = src_row_data[col]; if (!(x >= lower_limit)) - row_data[col] = kaldi::Exp(lower_limit); + row_data[col] = kaldi::Exp(lower_limit); else if (x > upper_limit) - row_data[col] = kaldi::Exp(upper_limit); + row_data[col] = kaldi::Exp(upper_limit); else - row_data[col] = kaldi::Exp(x); + row_data[col] = kaldi::Exp(x); } } } @@ -2250,15 +2319,15 @@ bool MatrixBase::Power(Real power) { (*this).AddMatMat(1.0, tmp, kNoTrans, P, kNoTrans, 0.0); return true; } - -template +*/ +template void Matrix::Swap(Matrix *other) { - std::swap(this->data_, other->data_); - std::swap(this->num_cols_, other->num_cols_); - std::swap(this->num_rows_, other->num_rows_); - std::swap(this->stride_, other->stride_); + std::swap(this->data_, other->data_); + std::swap(this->num_cols_, other->num_cols_); + std::swap(this->num_rows_, other->num_rows_); + std::swap(this->stride_, other->stride_); } - +/* // Repeating this comment that appeared in the header: // Eigenvalue Decomposition of a square NxN matrix into the form (*this) = P D // P^{-1}. Be careful: the relationship of D to the eigenvalues we output is @@ -2269,12 +2338,14 @@ void Matrix::Swap(Matrix *other) { // be block diagonal, with 2x2 blocks corresponding to any such pairs. If a // pair is lambda +- i*mu, D will have a corresponding 2x2 block // [lambda, mu; -mu, lambda]. 
-// Note that if the input matrix (*this) is non-invertible, P may not be invertible +// Note that if the input matrix (*this) is non-invertible, P may not be +invertible // so in this case instead of the equation (*this) = P D P^{-1} holding, we have // instead (*this) P = P D. // // By making the pointer arguments non-NULL or NULL, the user can choose to take -// not to take the eigenvalues directly, and/or the matrix D which is block-diagonal +// not to take the eigenvalues directly, and/or the matrix D which is +block-diagonal // with 2x2 blocks. template void MatrixBase::Eig(MatrixBase *P, @@ -2298,7 +2369,7 @@ void MatrixBase::Eig(MatrixBase *P, // INT_32 mVersion; // INT_32 mSampSize; // }; - +/* template bool ReadHtk(std::istream &is, Matrix *M_ptr, HtkHeader *header_ptr) { @@ -2400,7 +2471,8 @@ template bool ReadHtk(std::istream &is, Matrix *M, HtkHeader *header_ptr); template -bool WriteHtk(std::ostream &os, const MatrixBase &M, HtkHeader htk_hdr) // header may be derived from a previous call to ReadHtk. Must be in binary mode. +bool WriteHtk(std::ostream &os, const MatrixBase &M, HtkHeader htk_hdr) // +header may be derived from a previous call to ReadHtk. Must be in binary mode. 
{ KALDI_ASSERT(M.NumRows() == static_cast(htk_hdr.mNSamples)); KALDI_ASSERT(M.NumCols() == static_cast(htk_hdr.mSampleSize) / @@ -2502,12 +2574,14 @@ template Real TraceMatMatMat(const MatrixBase &A, MatrixTransposeType transA, const MatrixBase &B, MatrixTransposeType transB, const MatrixBase &C, MatrixTransposeType transC) { - MatrixIndexT ARows = A.NumRows(), ACols = A.NumCols(), BRows = B.NumRows(), BCols = B.NumCols(), + MatrixIndexT ARows = A.NumRows(), ACols = A.NumCols(), BRows = B.NumRows(), +BCols = B.NumCols(), CRows = C.NumRows(), CCols = C.NumCols(); if (transA == kTrans) std::swap(ARows, ACols); if (transB == kTrans) std::swap(BRows, BCols); if (transC == kTrans) std::swap(CRows, CCols); - KALDI_ASSERT( CCols == ARows && ACols == BRows && BCols == CRows && "TraceMatMatMat: args have mismatched dimensions."); + KALDI_ASSERT( CCols == ARows && ACols == BRows && BCols == CRows && +"TraceMatMatMat: args have mismatched dimensions."); if (ARows*BCols < std::min(BRows*CCols, CRows*ACols)) { Matrix AB(ARows, BCols); AB.AddMatMat(1.0, A, transA, B, transB, 0.0); // AB = A * B. 
@@ -2539,13 +2613,16 @@ Real TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, const MatrixBase &B, MatrixTransposeType transB, const MatrixBase &C, MatrixTransposeType transC, const MatrixBase &D, MatrixTransposeType transD) { - MatrixIndexT ARows = A.NumRows(), ACols = A.NumCols(), BRows = B.NumRows(), BCols = B.NumCols(), - CRows = C.NumRows(), CCols = C.NumCols(), DRows = D.NumRows(), DCols = D.NumCols(); + MatrixIndexT ARows = A.NumRows(), ACols = A.NumCols(), BRows = B.NumRows(), +BCols = B.NumCols(), + CRows = C.NumRows(), CCols = C.NumCols(), DRows = D.NumRows(), DCols = +D.NumCols(); if (transA == kTrans) std::swap(ARows, ACols); if (transB == kTrans) std::swap(BRows, BCols); if (transC == kTrans) std::swap(CRows, CCols); if (transD == kTrans) std::swap(DRows, DCols); - KALDI_ASSERT( DCols == ARows && ACols == BRows && BCols == CRows && CCols == DRows && "TraceMatMatMat: args have mismatched dimensions."); + KALDI_ASSERT( DCols == ARows && ACols == BRows && BCols == CRows && CCols == +DRows && "TraceMatMatMat: args have mismatched dimensions."); if (ARows*BCols < std::min(BRows*CCols, std::min(CRows*DCols, DRows*ACols))) { Matrix AB(ARows, BCols); AB.AddMatMat(1.0, A, transA, B, transB, 0.0); // AB = A * B. 
@@ -2572,13 +2649,18 @@ float TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, const MatrixBase &D, MatrixTransposeType transD); template -double TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, - const MatrixBase &B, MatrixTransposeType transB, - const MatrixBase &C, MatrixTransposeType transC, - const MatrixBase &D, MatrixTransposeType transD); +double TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType +transA, + const MatrixBase &B, MatrixTransposeType +transB, + const MatrixBase &C, MatrixTransposeType +transC, + const MatrixBase &D, MatrixTransposeType +transD); template void SortSvd(VectorBase *s, MatrixBase *U, - MatrixBase *Vt, bool sort_on_absolute_value) { + MatrixBase *Vt, bool +sort_on_absolute_value) { /// Makes sure the Svd is sorted (from greatest to least absolute value). MatrixIndexT num_singval = s->Dim(); KALDI_ASSERT(U == NULL || U->NumCols() == num_singval); @@ -2620,7 +2702,8 @@ void SortSvd(VectorBase *s, MatrixBase *U, MatrixBase *Vt, bool); template -void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase &im, +void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase +&im, MatrixBase *D) { MatrixIndexT n = re.Dim(); KALDI_ASSERT(im.Dim() == n && D->NumRows() == n && D->NumCols() == n); @@ -2634,7 +2717,8 @@ void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase & } else { // First of a complex pair KALDI_ASSERT(j+1 < n && ApproxEqual(im(j+1), -im(j)) && ApproxEqual(re(j+1), re(j))); - /// if (im(j) < 0.0) KALDI_WARN << "Negative first im part of pair"; // TEMP + /// if (im(j) < 0.0) KALDI_WARN << "Negative first im part of pair"; // +TEMP Real lambda = re(j), mu = im(j); // create 2x2 block [lambda, mu; -mu, lambda] (*D)(j, j) = lambda; @@ -2647,10 +2731,12 @@ void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase & } template -void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase &im, +void CreateEigenvalueMatrix(const VectorBase &re, 
const VectorBase +&im, MatrixBase *D); template -void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase &im, +void CreateEigenvalueMatrix(const VectorBase &re, const +VectorBase &im, MatrixBase *D); @@ -2691,7 +2777,8 @@ bool AttemptComplexPower(double *x_re, double *x_im, double power); template Real TraceMatMat(const MatrixBase &A, const MatrixBase &B, - MatrixTransposeType trans) { // tr(A B), equivalent to sum of each element of A times same element in B' + MatrixTransposeType trans) { // tr(A B), equivalent to sum of +each element of A times same element in B' MatrixIndexT aStride = A.stride_, bStride = B.stride_; if (trans == kNoTrans) { KALDI_ASSERT(A.NumRows() == B.NumCols() && A.NumCols() == B.NumRows()); @@ -2821,33 +2908,36 @@ void MatrixBase::GroupMax(const MatrixBase &src) { } } } - -template +*/ +template void MatrixBase::CopyCols(const MatrixBase &src, const MatrixIndexT *indices) { - KALDI_ASSERT(NumRows() == src.NumRows()); - MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, - this_stride = stride_, src_stride = src.stride_; - Real *this_data = this->data_; - const Real *src_data = src.data_; + KALDI_ASSERT(NumRows() == src.NumRows()); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, + this_stride = stride_, src_stride = src.stride_; + Real *this_data = this->data_; + const Real *src_data = src.data_; #ifdef KALDI_PARANOID - MatrixIndexT src_cols = src.NumCols(); - for (MatrixIndexT i = 0; i < num_cols; i++) - KALDI_ASSERT(indices[i] >= -1 && indices[i] < src_cols); + MatrixIndexT src_cols = src.NumCols(); + for (MatrixIndexT i = 0; i < num_cols; i++) + KALDI_ASSERT(indices[i] >= -1 && indices[i] < src_cols); #endif - // For the sake of memory locality we do this row by row, rather - // than doing it column-wise using cublas_Xcopy - for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride, src_data += src_stride) { - const MatrixIndexT *index_ptr = &(indices[0]); - for (MatrixIndexT c = 0; c < num_cols; c++, 
index_ptr++) { - if (*index_ptr < 0) this_data[c] = 0; - else this_data[c] = src_data[*index_ptr]; + // For the sake of memory locality we do this row by row, rather + // than doing it column-wise using cublas_Xcopy + for (MatrixIndexT r = 0; r < num_rows; + r++, this_data += this_stride, src_data += src_stride) { + const MatrixIndexT *index_ptr = &(indices[0]); + for (MatrixIndexT c = 0; c < num_cols; c++, index_ptr++) { + if (*index_ptr < 0) + this_data[c] = 0; + else + this_data[c] = src_data[*index_ptr]; + } } - } } - +/* template void MatrixBase::AddCols(const MatrixBase &src, const MatrixIndexT *indices) { @@ -2864,15 +2954,17 @@ void MatrixBase::AddCols(const MatrixBase &src, // For the sake of memory locality we do this row by row, rather // than doing it column-wise using cublas_Xcopy - for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride, src_data += src_stride) { + for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride, src_data ++= src_stride) { const MatrixIndexT *index_ptr = &(indices[0]); for (MatrixIndexT c = 0; c < num_cols; c++, index_ptr++) { if (*index_ptr >= 0) this_data[c] += src_data[*index_ptr]; } } -} +}*/ +/* template void MatrixBase::CopyRows(const MatrixBase &src, const MatrixIndexT *indices) { @@ -2995,7 +3087,8 @@ void MatrixBase::DiffSigmoid(const MatrixBase &value, const MatrixBase &diff) { KALDI_ASSERT(SameDim(*this, value) && SameDim(*this, diff)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, - stride = stride_, value_stride = value.stride_, diff_stride = diff.stride_; + stride = stride_, value_stride = value.stride_, diff_stride = +diff.stride_; Real *data = data_; const Real *value_data = value.data_, *diff_data = diff.data_; for (MatrixIndexT r = 0; r < num_rows; r++) { @@ -3012,7 +3105,8 @@ void MatrixBase::DiffTanh(const MatrixBase &value, const MatrixBase &diff) { KALDI_ASSERT(SameDim(*this, value) && SameDim(*this, diff)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, - 
stride = stride_, value_stride = value.stride_, diff_stride = diff.stride_; + stride = stride_, value_stride = value.stride_, diff_stride = +diff.stride_; Real *data = data_; const Real *value_data = value.data_, *diff_data = diff.data_; for (MatrixIndexT r = 0; r < num_rows; r++) { @@ -3022,12 +3116,13 @@ void MatrixBase::DiffTanh(const MatrixBase &value, value_data += value_stride; diff_data += diff_stride; } -} - +}*/ +/* template template -void MatrixBase::AddVecToRows(const Real alpha, const VectorBase &v) { +void MatrixBase::AddVecToRows(const Real alpha, const +VectorBase &v) { const MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, stride = stride_; KALDI_ASSERT(v.Dim() == num_cols); @@ -3058,7 +3153,8 @@ template void MatrixBase::AddVecToRows(const double alpha, template template -void MatrixBase::AddVecToCols(const Real alpha, const VectorBase &v) { +void MatrixBase::AddVecToCols(const Real alpha, const +VectorBase &v) { const MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, stride = stride_; KALDI_ASSERT(v.Dim() == num_rows); @@ -3087,11 +3183,11 @@ template void MatrixBase::AddVecToCols(const double alpha, const VectorBase &v); template void MatrixBase::AddVecToCols(const double alpha, const VectorBase &v); - -//Explicit instantiation of the classes -//Apparently, it seems to be necessary that the instantiation -//happens at the end of the file. Otherwise, not all the member -//functions will get instantiated. +*/ +// Explicit instantiation of the classes +// Apparently, it seems to be necessary that the instantiation +// happens at the end of the file. Otherwise, not all the member +// functions will get instantiated. 
template class Matrix; template class Matrix; @@ -3100,4 +3196,4 @@ template class MatrixBase; template class SubMatrix; template class SubMatrix; -} // namespace kaldi +} // namespace kaldi diff --git a/runtime/engine/common/matrix/kaldi-matrix.h b/runtime/engine/common/matrix/kaldi-matrix.h new file mode 100644 index 00000000..d614f36f --- /dev/null +++ b/runtime/engine/common/matrix/kaldi-matrix.h @@ -0,0 +1,906 @@ +// matrix/kaldi-matrix.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Lukas Burget; +// Saarland University; Petr Schwarz; Yanmin Qian; +// Karel Vesely; Go Vivace Inc.; Haihua Xu +// 2017 Shiyin Kang +// 2019 Yiwen Shao + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_MATRIX_H_ +#define KALDI_MATRIX_KALDI_MATRIX_H_ 1 + +#include + +#include "matrix/matrix-common.h" + +namespace kaldi { + +/// @{ \addtogroup matrix_funcs_scalar + +/// \addtogroup matrix_group +/// @{ + +/// Base class which provides matrix operations not involving resizing +/// or allocation. Classes Matrix and SubMatrix inherit from it and take care +/// of allocation and resizing. +template +class MatrixBase { + public: + // so this child can access protected members of other instances. 
+ friend class Matrix; + friend class SubMatrix; + // friend declarations for CUDA matrices (see ../cudamatrix/) + + /// Returns number of rows (or zero for empty matrix). + inline MatrixIndexT NumRows() const { return num_rows_; } + + /// Returns number of columns (or zero for empty matrix). + inline MatrixIndexT NumCols() const { return num_cols_; } + + /// Stride (distance in memory between each row). Will be >= NumCols. + inline MatrixIndexT Stride() const { return stride_; } + + /// Returns size in bytes of the data held by the matrix. + size_t SizeInBytes() const { + return static_cast(num_rows_) * static_cast(stride_) * + sizeof(Real); + } + + /// Gives pointer to raw data (const). + inline const Real *Data() const { return data_; } + + /// Gives pointer to raw data (non-const). + inline Real *Data() { return data_; } + + /// Returns pointer to data for one row (non-const) + inline Real *RowData(MatrixIndexT i) { + KALDI_ASSERT(static_cast(i) < + static_cast(num_rows_)); + return data_ + i * stride_; + } + + /// Returns pointer to data for one row (const) + inline const Real *RowData(MatrixIndexT i) const { + KALDI_ASSERT(static_cast(i) < + static_cast(num_rows_)); + return data_ + i * stride_; + } + + /// Indexing operator, non-const + /// (only checks sizes if compiled with -DKALDI_PARANOID) + inline Real &operator()(MatrixIndexT r, MatrixIndexT c) { + KALDI_PARANOID_ASSERT( + static_cast(r) < + static_cast(num_rows_) && + static_cast(c) < + static_cast(num_cols_)); + return *(data_ + r * stride_ + c); + } + /// Indexing operator, provided for ease of debugging (gdb doesn't work + /// with parenthesis operator). 
+ Real &Index(MatrixIndexT r, MatrixIndexT c) { return (*this)(r, c); } + + /// Indexing operator, const + /// (only checks sizes if compiled with -DKALDI_PARANOID) + inline const Real operator()(MatrixIndexT r, MatrixIndexT c) const { + KALDI_PARANOID_ASSERT( + static_cast(r) < + static_cast(num_rows_) && + static_cast(c) < + static_cast(num_cols_)); + return *(data_ + r * stride_ + c); + } + + /* Basic setting-to-special values functions. */ + + /// Sets matrix to zero. + void SetZero(); + /// Sets all elements to a specific value. + void Set(Real); + /// Sets to zero, except ones along diagonal [for non-square matrices too] + + /// Copy given matrix. (no resize is done). + template + void CopyFromMat(const MatrixBase &M, + MatrixTransposeType trans = kNoTrans); + + /// Copy from compressed matrix. + // void CopyFromMat(const CompressedMatrix &M); + + /// Copy given tpmatrix. (no resize is done). + // template + // void CopyFromTp(const TpMatrix &M, + // MatrixTransposeType trans = kNoTrans); + + /// Copy from CUDA matrix. Implemented in ../cudamatrix/cu-matrix.h + // template + // void CopyFromMat(const CuMatrixBase &M, + // MatrixTransposeType trans = kNoTrans); + + /// This function has two modes of operation. If v.Dim() == NumRows() * + /// NumCols(), then treats the vector as a row-by-row concatenation of a + /// matrix and copies to *this. + /// if v.Dim() == NumCols(), it sets each row of *this to a copy of v. + void CopyRowsFromVec(const VectorBase &v); + + /// This version of CopyRowsFromVec is implemented in + /// ../cudamatrix/cu-vector.cc + // void CopyRowsFromVec(const CuVectorBase &v); + + template + void CopyRowsFromVec(const VectorBase &v); + + /// Copies vector into matrix, column-by-column. + /// Note that rv.Dim() must either equal NumRows()*NumCols() or NumRows(); + /// this has two modes of operation. + void CopyColsFromVec(const VectorBase &v); + + /// Copy vector into specific column of matrix. 
+ void CopyColFromVec(const VectorBase &v, const MatrixIndexT col); + /// Copy vector into specific row of matrix. + void CopyRowFromVec(const VectorBase &v, const MatrixIndexT row); + /// Copy vector into diagonal of matrix. + void CopyDiagFromVec(const VectorBase &v); + + /* Accessing of sub-parts of the matrix. */ + + /// Return specific row of matrix [const]. + inline const SubVector Row(MatrixIndexT i) const { + KALDI_ASSERT(static_cast(i) < + static_cast(num_rows_)); + return SubVector(data_ + (i * stride_), NumCols()); + } + + /// Return specific row of matrix. + inline SubVector Row(MatrixIndexT i) { + KALDI_ASSERT(static_cast(i) < + static_cast(num_rows_)); + return SubVector(data_ + (i * stride_), NumCols()); + } + + /// Return a sub-part of matrix. + inline SubMatrix Range(const MatrixIndexT row_offset, + const MatrixIndexT num_rows, + const MatrixIndexT col_offset, + const MatrixIndexT num_cols) const { + return SubMatrix( + *this, row_offset, num_rows, col_offset, num_cols); + } + inline SubMatrix RowRange(const MatrixIndexT row_offset, + const MatrixIndexT num_rows) const { + return SubMatrix(*this, row_offset, num_rows, 0, num_cols_); + } + inline SubMatrix ColRange(const MatrixIndexT col_offset, + const MatrixIndexT num_cols) const { + return SubMatrix(*this, 0, num_rows_, col_offset, num_cols); + } + + /* + /// Returns sum of all elements in matrix. + Real Sum() const; + /// Returns trace of matrix. + Real Trace(bool check_square = true) const; + // If check_square = true, will crash if matrix is not square. + + /// Returns maximum element of matrix. + Real Max() const; + /// Returns minimum element of matrix. + Real Min() const; + + /// Element by element multiplication with a given matrix. + void MulElements(const MatrixBase &A); + + /// Divide each element by the corresponding element of a given matrix. + void DivElements(const MatrixBase &A); + + /// Multiply each element with a scalar value. 
+ void Scale(Real alpha); + + /// Set, element-by-element, *this = max(*this, A) + void Max(const MatrixBase &A); + /// Set, element-by-element, *this = min(*this, A) + void Min(const MatrixBase &A); + + /// Equivalent to (*this) = (*this) * diag(scale). Scaling + /// each column by a scalar taken from that dimension of the vector. + void MulColsVec(const VectorBase &scale); + + /// Equivalent to (*this) = diag(scale) * (*this). Scaling + /// each row by a scalar taken from that dimension of the vector. + void MulRowsVec(const VectorBase &scale); + + /// Divide each row into src.NumCols() equal groups, and then scale i'th + row's + /// j'th group of elements by src(i, j). Requires src.NumRows() == + /// this->NumRows() and this->NumCols() % src.NumCols() == 0. + void MulRowsGroupMat(const MatrixBase &src); + + /// Returns logdet of matrix. + Real LogDet(Real *det_sign = NULL) const; + + /// matrix inverse. + /// if inverse_needed = false, will fill matrix with garbage. + /// (only useful if logdet wanted). + void Invert(Real *log_det = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + /// matrix inverse [double]. + /// if inverse_needed = false, will fill matrix with garbage + /// (only useful if logdet wanted). + /// Does inversion in double precision even if matrix was not double. + void InvertDouble(Real *LogDet = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + */ + /// Inverts all the elements of the matrix + void InvertElements(); + /* + /// Transpose the matrix. This one is only + /// applicable to square matrices (the one in the + /// Matrix child class works also for non-square. + void Transpose(); + + */ + /// Copies column r from column indices[r] of src. + /// As a special case, if indexes[i] == -1, sets column i to zero. 
+ /// all elements of "indices" must be in [-1, src.NumCols()-1], + /// and src.NumRows() must equal this.NumRows() + void CopyCols(const MatrixBase &src, const MatrixIndexT *indices); + + /// Copies row r from row indices[r] of src (does nothing + /// As a special case, if indexes[i] == -1, sets row i to zero. + /// all elements of "indices" must be in [-1, src.NumRows()-1], + /// and src.NumCols() must equal this.NumCols() + void CopyRows(const MatrixBase &src, const MatrixIndexT *indices); + + /// Add column indices[r] of src to column r. + /// As a special case, if indexes[i] == -1, skip column i + /// indices.size() must equal this->NumCols(), + /// all elements of "reorder" must be in [-1, src.NumCols()-1], + /// and src.NumRows() must equal this.NumRows() + // void AddCols(const MatrixBase &src, + // const MatrixIndexT *indices); + + /// Copies row r of this matrix from an array of floats at the location + /// given + /// by src[r]. If any src[r] is NULL then this.Row(r) will be set to zero. + /// Note: we are using "pointer to const pointer to const object" for "src", + /// because we may create "src" by calling Data() of const CuArray + void CopyRows(const Real *const *src); + + /// Copies row r of this matrix to the array of floats at the location given + /// by dst[r]. If dst[r] is NULL, does not copy anywhere. Requires that + /// none + /// of the memory regions pointed to by the pointers in "dst" overlap (e.g. + /// none of the pointers should be the same). + void CopyToRows(Real *const *dst) const; + + /// Does for each row r, this.Row(r) += alpha * src.row(indexes[r]). + /// If indexes[r] < 0, does not add anything. all elements of "indexes" must + /// be in [-1, src.NumRows()-1], and src.NumCols() must equal + /// this.NumCols(). 
+ // void AddRows(Real alpha, + // const MatrixBase &src, + // const MatrixIndexT *indexes); + + /// Does for each row r, this.Row(r) += alpha * src[r], treating src[r] as + /// the + /// beginning of a region of memory representing a vector of floats, of the + /// same length as this.NumCols(). If src[r] is NULL, does not add anything. + // void AddRows(Real alpha, const Real *const *src); + + /// For each row r of this matrix, adds it (times alpha) to the array of + /// floats at the location given by dst[r]. If dst[r] is NULL, does not do + /// anything for that row. Requires that none of the memory regions pointed + /// to by the pointers in "dst" overlap (e.g. none of the pointers should be + /// the same). + // void AddToRows(Real alpha, Real *const *dst) const; + + /// For each row i of *this, adds this->Row(i) to + /// dst->Row(indexes(i)) if indexes(i) >= 0, else do nothing. + /// Requires that all the indexes[i] that are >= 0 + /// be distinct, otherwise the behavior is undefined. + // void AddToRows(Real alpha, + // const MatrixIndexT *indexes, + // MatrixBase *dst) const; + /* + inline void ApplyPow(Real power) { + this -> Pow(*this, power); + } + + + inline void ApplyPowAbs(Real power, bool include_sign=false) { + this -> PowAbs(*this, power, include_sign); + } + + inline void ApplyHeaviside() { + this -> Heaviside(*this); + } + + inline void ApplyFloor(Real floor_val) { + this -> Floor(*this, floor_val); + } + + inline void ApplyCeiling(Real ceiling_val) { + this -> Ceiling(*this, ceiling_val); + } + + inline void ApplyExp() { + this -> Exp(*this); + } + + inline void ApplyExpSpecial() { + this -> ExpSpecial(*this); + } + + inline void ApplyExpLimited(Real lower_limit, Real upper_limit) { + this -> ExpLimited(*this, lower_limit, upper_limit); + } + + inline void ApplyLog() { + this -> Log(*this); + } + */ + /// Eigenvalue Decomposition of a square NxN matrix into the form (*this) = + /// P D + /// P^{-1}. 
Be careful: the relationship of D to the eigenvalues we output + /// is + /// slightly complicated, due to the need for P to be real. In the + /// symmetric + /// case D is diagonal and real, but in + /// the non-symmetric case there may be complex-conjugate pairs of + /// eigenvalues. + /// In this case, for the equation (*this) = P D P^{-1} to hold, D must + /// actually + /// be block diagonal, with 2x2 blocks corresponding to any such pairs. If + /// a + /// pair is lambda +- i*mu, D will have a corresponding 2x2 block + /// [lambda, mu; -mu, lambda]. + /// Note that if the input matrix (*this) is non-invertible, P may not be + /// invertible + /// so in this case instead of the equation (*this) = P D P^{-1} holding, we + /// have + /// instead (*this) P = P D. + /// + /// The non-member function CreateEigenvalueMatrix creates D from eigs_real + /// and eigs_imag. + // void Eig(MatrixBase *P, + // VectorBase *eigs_real, + // VectorBase *eigs_imag) const; + + /// The Power method attempts to take the matrix to a power using a method + /// that + /// works in general for fractional and negative powers. The input matrix + /// must + /// be invertible and have reasonable condition (or we don't guarantee the + /// results. The method is based on the eigenvalue decomposition. It will + /// return false and leave the matrix unchanged, if at entry the matrix had + /// real negative eigenvalues (or if it had zero eigenvalues and the power + /// was + /// negative). + // bool Power(Real pow); + + /** Singular value decomposition + Major limitations: + For nonsquare matrices, we assume m>=n (NumRows >= NumCols), and we + return + the "skinny" Svd, i.e. the matrix in the middle is diagonal, and the + one on the left is rectangular. + + In Svd, *this = U*diag(S)*Vt. + Null pointers for U and/or Vt at input mean we do not want that output. + We + expect that S.Dim() == m, U is either NULL or m by n, + and v is either NULL or n by n. 
+ The singular values are not sorted (use SortSvd for that). */ + // void DestructiveSvd(VectorBase *s, MatrixBase *U, + // MatrixBase *Vt); // Destroys calling matrix. + + /// Compute SVD (*this) = U diag(s) Vt. Note that the V in the call is + /// already + /// transposed; the normal formulation is U diag(s) V^T. + /// Null pointers for U or V mean we don't want that output (this saves + /// compute). The singular values are not sorted (use SortSvd for that). + // void Svd(VectorBase *s, MatrixBase *U, + // MatrixBase *Vt) const; + /// Compute SVD but only retain the singular values. + // void Svd(VectorBase *s) const { Svd(s, NULL, NULL); } + + + /// Returns smallest singular value. + // Real MinSingularValue() const { + // Vector tmp(std::min(NumRows(), NumCols())); + // Svd(&tmp); + // return tmp.Min(); + //} + + // void TestUninitialized() const; // This function is designed so that if + // any element + // if the matrix is uninitialized memory, valgrind will complain. + + /// Returns condition number by computing Svd. Works even if cols > rows. + /// Returns infinity if all singular values are zero. + /* + Real Cond() const; + + /// Returns true if matrix is Symmetric. + bool IsSymmetric(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if matrix is Diagonal. + bool IsDiagonal(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if the matrix is all zeros, except for ones on diagonal. + (it + /// does not have to be square). More specifically, this function returns + /// false if for any i, j, (*this)(i, j) differs by more than cutoff from + the + /// expression (i == j ? 1 : 0). + bool IsUnit(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if matrix is all zeros. + bool IsZero(Real cutoff = 1.0e-05) const; // replace magic number + + /// Frobenius norm, which is the sqrt of sum of square elements. Same as + Schatten 2-norm, + /// or just "2-norm". 
+ Real FrobeniusNorm() const; + + /// Returns true if ((*this)-other).FrobeniusNorm() + /// <= tol * (*this).FrobeniusNorm(). + bool ApproxEqual(const MatrixBase &other, float tol = 0.01) const; + + /// Tests for exact equality. It's usually preferable to use ApproxEqual. + bool Equal(const MatrixBase &other) const; + + /// largest absolute value. + Real LargestAbsElem() const; // largest absolute value. + + /// Returns log(sum(exp())) without exp overflow + /// If prune > 0.0, it uses a pruning beam, discarding + /// terms less than (max - prune). Note: in future + /// we may change this so that if prune = 0.0, it takes + /// the max, so use -1 if you don't want to prune. + Real LogSumExp(Real prune = -1.0) const; + + /// Apply soft-max to the collection of all elements of the + /// matrix and return normalizer (log sum of exponentials). + Real ApplySoftMax(); + + /// Set each element to the sigmoid of the corresponding element of "src". + void Sigmoid(const MatrixBase &src); + + /// Sets each element to the Heaviside step function (x > 0 ? 1 : 0) of the + /// corresponding element in "src". Note: in general you can make different + /// choices for x = 0, but for now please leave it as it (i.e. returning + zero) + /// because it affects the RectifiedLinearComponent in the neural net code. + void Heaviside(const MatrixBase &src); + + void Exp(const MatrixBase &src); + + void Pow(const MatrixBase &src, Real power); + + void Log(const MatrixBase &src); + + /// Apply power to the absolute value of each element. + /// If include_sign is true, the result will be multiplied with + /// the sign of the input value. + /// If the power is negative and the input to the power is zero, + /// The output will be set zero. If include_sign is true, it will + /// multiply the result by the sign of the input. 
+ void PowAbs(const MatrixBase &src, Real power, bool + include_sign=false); + + void Floor(const MatrixBase &src, Real floor_val); + + void Ceiling(const MatrixBase &src, Real ceiling_val); + + /// For each element x of the matrix, set it to + /// (x < 0 ? exp(x) : x + 1). This function is used + /// in our RNNLM training. + void ExpSpecial(const MatrixBase &src); + + /// This is equivalent to running: + /// Floor(src, lower_limit); + /// Ceiling(src, upper_limit); + /// Exp(src) + void ExpLimited(const MatrixBase &src, Real lower_limit, Real + upper_limit); + + /// Set each element to y = log(1 + exp(x)) + void SoftHinge(const MatrixBase &src); + + /// Apply the function y(i) = (sum_{j = i*G}^{(i+1)*G-1} x_j^(power))^(1 / + p). + /// Requires src.NumRows() == this->NumRows() and src.NumCols() % + this->NumCols() == 0. + void GroupPnorm(const MatrixBase &src, Real power); + + /// Calculate derivatives for the GroupPnorm function above... + /// if "input" is the input to the GroupPnorm function above (i.e. the "src" + variable), + /// and "output" is the result of the computation (i.e. the "this" of that + function + /// call), and *this has the same dimension as "input", then it sets each + element + /// of *this to the derivative d(output-elem)/d(input-elem) for each element + of "input", where + /// "output-elem" is whichever element of output depends on that input + element. + void GroupPnormDeriv(const MatrixBase &input, const MatrixBase + &output, + Real power); + + /// Apply the function y(i) = (max_{j = i*G}^{(i+1)*G-1} x_j + /// Requires src.NumRows() == this->NumRows() and src.NumCols() % + this->NumCols() == 0. + void GroupMax(const MatrixBase &src); + + /// Calculate derivatives for the GroupMax function above, where + /// "input" is the input to the GroupMax function above (i.e. the "src" + variable), + /// and "output" is the result of the computation (i.e. the "this" of that + function + /// call), and *this must have the same dimension as "input". 
Each element + /// of *this will be set to 1 if the corresponding input equals the output + of + /// the group, and 0 otherwise. The equals the function derivative where it + is + /// defined (it's not defined where multiple inputs in the group are equal + to the output). + void GroupMaxDeriv(const MatrixBase &input, const MatrixBase + &output); + + /// Set each element to the tanh of the corresponding element of "src". + void Tanh(const MatrixBase &src); + + // Function used in backpropagating derivatives of the sigmoid function: + // element-by-element, set *this = diff * value * (1.0 - value). + void DiffSigmoid(const MatrixBase &value, + const MatrixBase &diff); + + // Function used in backpropagating derivatives of the tanh function: + // element-by-element, set *this = diff * (1.0 - value^2). + void DiffTanh(const MatrixBase &value, + const MatrixBase &diff); + */ + /** Uses Svd to compute the eigenvalue decomposition of a symmetric positive + * semi-definite matrix: (*this) = rP * diag(rS) * rP^T, with rP an + * orthogonal matrix so rP^{-1} = rP^T. Throws exception if input was not + * positive semi-definite (check_thresh controls how stringent the check is; + * set it to 2 to ensure it won't ever complain, but it will zero out + * negative + * dimensions in your matrix. + * + * Caution: if you want the eigenvalues, it may make more sense to convert + * to + * SpMatrix and use Eig() function there, which uses eigenvalue + * decomposition + * directly rather than SVD. + */ + + /// stream read. + /// Use instead of stream<<*this, if you want to add to existing contents. + // Will throw exception on failure. + void Read(std::istream &in, bool binary); + /// write to stream. + void Write(std::ostream &out, bool binary) const; + + // Below is internal methods for Svd, user does not have to know about this. + protected: + /// Initializer, callable only from child. 
+ explicit MatrixBase(Real *data, + MatrixIndexT cols, + MatrixIndexT rows, + MatrixIndexT stride) + : data_(data), num_cols_(cols), num_rows_(rows), stride_(stride) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } + + /// Initializer, callable only from child. + /// Empty initializer, for un-initialized matrix. + explicit MatrixBase() : data_(NULL) { KALDI_ASSERT_IS_FLOATING_TYPE(Real); } + + // Make sure pointers to MatrixBase cannot be deleted. + ~MatrixBase() {} + + /// A workaround that allows SubMatrix to get a pointer to non-const data + /// for const Matrix. Unfortunately C++ does not allow us to declare a + /// "public const" inheritance or anything like that, so it would require + /// a lot of work to make the SubMatrix class totally const-correct-- + /// we would have to override many of the Matrix functions. + inline Real *Data_workaround() const { return data_; } + + /// data memory area + Real *data_; + + /// these attributes store the real matrix size as it is stored in memory + /// including memalignment + MatrixIndexT num_cols_; /// < Number of columns + MatrixIndexT num_rows_; /// < Number of rows + /** True number of columns for the internal matrix. This number may differ + * from num_cols_ as memory alignment might be used. */ + MatrixIndexT stride_; + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(MatrixBase); +}; + +/// A class for storing matrices. +template +class Matrix : public MatrixBase { + public: + /// Empty constructor. + Matrix(); + + /// Basic constructor. + Matrix(const MatrixIndexT r, + const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero, + MatrixStrideType stride_type = kDefaultStride) + : MatrixBase() { + Resize(r, c, resize_type, stride_type); + } + + /// Swaps the contents of *this and *other. Shallow swap. + void Swap(Matrix *other); + + /// Constructor from any MatrixBase. Can also copy with transpose. + /// Allocates new memory. 
+ explicit Matrix(const MatrixBase &M, + MatrixTransposeType trans = kNoTrans); + + /// Same as above, but need to avoid default copy constructor. + Matrix(const Matrix &M); // (cannot make explicit) + + /// Copy constructor: as above, but from another type. + template + explicit Matrix(const MatrixBase &M, + MatrixTransposeType trans = kNoTrans); + + /// Copy constructor taking TpMatrix... + // template + // explicit Matrix(const TpMatrix & M, + // MatrixTransposeType trans = kNoTrans) : MatrixBase() { + // if (trans == kNoTrans) { + // Resize(M.NumRows(), M.NumCols(), kUndefined); + // this->CopyFromTp(M); + //} else { + // Resize(M.NumCols(), M.NumRows(), kUndefined); + // this->CopyFromTp(M, kTrans); + //} + //} + + /// read from stream. + // Unlike one in base, allows resizing. + void Read(std::istream &in, bool binary); + + /// Remove a specified row. + void RemoveRow(MatrixIndexT i); + + /// Transpose the matrix. Works for non-square + /// matrices as well as square ones. + // void Transpose(); + + /// Distructor to free matrices. + ~Matrix() { Destroy(); } + + /// Sets matrix to a specified size (zero is OK as long as both r and c are + /// zero). The value of the new data depends on resize_type: + /// -if kSetZero, the new data will be zero + /// -if kUndefined, the new data will be undefined + /// -if kCopyData, the new data will be the same as the old data in any + /// shared positions, and zero elsewhere. + /// + /// You can set stride_type to kStrideEqualNumCols to force the stride + /// to equal the number of columns; by default it is set so that the stride + /// in bytes is a multiple of 16. + /// + /// This function takes time proportional to the number of data elements. + void Resize(const MatrixIndexT r, + const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero, + MatrixStrideType stride_type = kDefaultStride); + + /// Assignment operator that takes MatrixBase. 
+ Matrix &operator=(const MatrixBase &other) { + if (MatrixBase::NumRows() != other.NumRows() || + MatrixBase::NumCols() != other.NumCols()) + Resize(other.NumRows(), other.NumCols(), kUndefined); + MatrixBase::CopyFromMat(other); + return *this; + } + + /// Assignment operator. Needed for inclusion in std::vector. + Matrix &operator=(const Matrix &other) { + if (MatrixBase::NumRows() != other.NumRows() || + MatrixBase::NumCols() != other.NumCols()) + Resize(other.NumRows(), other.NumCols(), kUndefined); + MatrixBase::CopyFromMat(other); + return *this; + } + + + private: + /// Deallocates memory and sets to empty matrix (dimension 0, 0). + void Destroy(); + + /// Init assumes the current class contents are invalid (i.e. junk or have + /// already been freed), and it sets the matrix to newly allocated memory + /// with + /// the specified number of rows and columns. r == c == 0 is acceptable. + /// The data + /// memory contents will be undefined. + void Init(const MatrixIndexT r, + const MatrixIndexT c, + const MatrixStrideType stride_type); +}; +/// @} end "addtogroup matrix_group" + +/// \addtogroup matrix_funcs_io +/// @{ + +/// A structure containing the HTK header. +/// [TODO: change the style of the variables to Kaldi-compliant] + +template +class SubMatrix : public MatrixBase { + public: + // Initialize a SubMatrix from part of a matrix; this is + // a bit like A(b:c, d:e) in Matlab. + // This initializer is against the proper semantics of "const", since + // SubMatrix can change its contents. It would be hard to implement + // a "const-safe" version of this class. + SubMatrix(const MatrixBase &T, + const MatrixIndexT ro, // row offset, 0 < ro < NumRows() + const MatrixIndexT r, // number of rows, r > 0 + const MatrixIndexT co, // column offset, 0 < co < NumCols() + const MatrixIndexT c); // number of columns, c > 0 + + // This initializer is mostly intended for use in CuMatrix and related + // classes. Be careful! 
+ SubMatrix(Real *data, + MatrixIndexT num_rows, + MatrixIndexT num_cols, + MatrixIndexT stride); + + ~SubMatrix() {} + + /// This type of constructor is needed for Range() to work [in Matrix base + /// class]. Cannot make it explicit. + SubMatrix(const SubMatrix &other) + : MatrixBase( + other.data_, other.num_cols_, other.num_rows_, other.stride_) {} + + private: + /// Disallow assignment. + SubMatrix &operator=(const SubMatrix &other); +}; + +/// @} End of "addtogroup matrix_funcs_io". + +/// \addtogroup matrix_funcs_scalar +/// @{ + +// Some declarations. These are traces of products. + +/************************ +template +bool ApproxEqual(const MatrixBase &A, + const MatrixBase &B, Real tol = 0.01) { + return A.ApproxEqual(B, tol); +} + +template +inline void AssertEqual(const MatrixBase &A, const MatrixBase &B, + float tol = 0.01) { + KALDI_ASSERT(A.ApproxEqual(B, tol)); +} + +/// Returns trace of matrix. +template +double TraceMat(const MatrixBase &A) { return A.Trace(); } + + +/// Returns tr(A B C) +template +Real TraceMatMatMat(const MatrixBase &A, MatrixTransposeType transA, + const MatrixBase &B, MatrixTransposeType transB, + const MatrixBase &C, MatrixTransposeType transC); + +/// Returns tr(A B C D) +template +Real TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, + const MatrixBase &B, MatrixTransposeType transB, + const MatrixBase &C, MatrixTransposeType transC, + const MatrixBase &D, MatrixTransposeType transD); + +/// @} end "addtogroup matrix_funcs_scalar" + + +/// \addtogroup matrix_funcs_misc +/// @{ + + +/// Function to ensure that SVD is sorted. This function is made as generic as +/// possible, to be applicable to other types of problems. s->Dim() should be +/// the same as U->NumCols(), and we sort s from greatest to least absolute +/// value (if sort_on_absolute_value == true) or greatest to least value +/// otherwise, moving the columns of U, if it exists, and the rows of Vt, if it +/// exists, around in the same way. 
Note: the "absolute value" part won't +matter +/// if this is an actual SVD, since singular values are non-negative. +template void SortSvd(VectorBase *s, MatrixBase *U, + MatrixBase* Vt = NULL, + bool sort_on_absolute_value = true); + +/// Creates the eigenvalue matrix D that is part of the decomposition used +Matrix::Eig. +/// D will be block-diagonal with blocks of size 1 (for real eigenvalues) or 2x2 +/// for complex pairs. If a complex pair is lambda +- i*mu, D will have a +corresponding +/// 2x2 block [lambda, mu; -mu, lambda]. +/// This function will throw if any complex eigenvalues are not in complex +conjugate +/// pairs (or the members of such pairs are not consecutively numbered). +template +void CreateEigenvalueMatrix(const VectorBase &real, const VectorBase +&imag, + MatrixBase *D); + +/// The following function is used in Matrix::Power, and separately tested, so +we +/// declare it here mainly for the testing code to see. It takes a complex +value to +/// a power using a method that will work for noninteger powers (but will fail +if the +/// complex value is real and negative). +template +bool AttemptComplexPower(Real *x_re, Real *x_im, Real power); + +**********/ + +/// @} end of addtogroup matrix_funcs_misc + +/// \addtogroup matrix_funcs_io +/// @{ +template +std::ostream &operator<<(std::ostream &Out, const MatrixBase &M); + +template +std::istream &operator>>(std::istream &In, MatrixBase &M); + +// The Matrix read allows resizing, so we override the MatrixBase one. +template +std::istream &operator>>(std::istream &In, Matrix &M); + +template +bool SameDim(const MatrixBase &M, const MatrixBase &N) { + return (M.NumRows() == N.NumRows() && M.NumCols() == N.NumCols()); +} + +/// @} end of \addtogroup matrix_funcs_io + + +} // namespace kaldi + + +// we need to include the implementation and some +// template specializations. 
+#include "matrix/kaldi-matrix-inl.h" + + +#endif // KALDI_MATRIX_KALDI_MATRIX_H_ diff --git a/speechx/speechx/kaldi/matrix/kaldi-vector-inl.h b/runtime/engine/common/matrix/kaldi-vector-inl.h similarity index 63% rename from speechx/speechx/kaldi/matrix/kaldi-vector-inl.h rename to runtime/engine/common/matrix/kaldi-vector-inl.h index c3a4f52f..b3075e59 100644 --- a/speechx/speechx/kaldi/matrix/kaldi-vector-inl.h +++ b/runtime/engine/common/matrix/kaldi-vector-inl.h @@ -26,32 +26,33 @@ namespace kaldi { -template -std::ostream & operator << (std::ostream &os, const VectorBase &rv) { - rv.Write(os, false); - return os; +template +std::ostream &operator<<(std::ostream &os, const VectorBase &rv) { + rv.Write(os, false); + return os; } -template -std::istream &operator >> (std::istream &is, VectorBase &rv) { - rv.Read(is, false); - return is; +template +std::istream &operator>>(std::istream &is, VectorBase &rv) { + rv.Read(is, false); + return is; } -template -std::istream &operator >> (std::istream &is, Vector &rv) { - rv.Read(is, false); - return is; +template +std::istream &operator>>(std::istream &is, Vector &rv) { + rv.Read(is, false); + return is; } -template<> -template<> -void VectorBase::AddVec(const float alpha, const VectorBase &rv); +// template<> +// template<> +// void VectorBase::AddVec(const float alpha, const VectorBase +// &rv); -template<> -template<> -void VectorBase::AddVec(const double alpha, - const VectorBase &rv); +// template<> +// template<> +// void VectorBase::AddVec(const double alpha, +// const VectorBase &rv); } // namespace kaldi diff --git a/runtime/engine/common/matrix/kaldi-vector.cc b/runtime/engine/common/matrix/kaldi-vector.cc new file mode 100644 index 00000000..3ab9a7ff --- /dev/null +++ b/runtime/engine/common/matrix/kaldi-vector.cc @@ -0,0 +1,1239 @@ +// matrix/kaldi-vector.cc + +// Copyright 2009-2011 Microsoft Corporation; Lukas Burget; +// Saarland University; Go Vivace Inc.; Ariya Rastrow; +// Petr Schwarz; Yanmin Qian; 
Jan Silovsky; +// Haihua Xu; Wei Shi +// 2015 Guoguo Chen +// 2017 Daniel Galvez +// 2019 Yiwen Shao + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "matrix/kaldi-vector.h" + +#include +#include + +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +template +inline void Vector::Init(const MatrixIndexT dim) { + KALDI_ASSERT(dim >= 0); + if (dim == 0) { + this->dim_ = 0; + this->data_ = NULL; + return; + } + MatrixIndexT size; + void *data; + void *free_data; + + size = dim * sizeof(Real); + + if ((data = KALDI_MEMALIGN(16, size, &free_data)) != NULL) { + this->data_ = static_cast(data); + this->dim_ = dim; + } else { + throw std::bad_alloc(); + } +} + + +template +void Vector::Resize(const MatrixIndexT dim, + MatrixResizeType resize_type) { + // the next block uses recursion to handle what we have to do if + // resize_type == kCopyData. + if (resize_type == kCopyData) { + if (this->data_ == NULL || dim == 0) + resize_type = kSetZero; // nothing to copy. + else if (this->dim_ == dim) { + return; + } // nothing to do. + else { + // set tmp to a vector of the desired size. 
+ Vector tmp(dim, kUndefined); + if (dim > this->dim_) { + memcpy(tmp.data_, this->data_, sizeof(Real) * this->dim_); + memset(tmp.data_ + this->dim_, + 0, + sizeof(Real) * (dim - this->dim_)); + } else { + memcpy(tmp.data_, this->data_, sizeof(Real) * dim); + } + tmp.Swap(this); + // and now let tmp go out of scope, deleting what was in *this. + return; + } + } + // At this point, resize_type == kSetZero or kUndefined. + + if (this->data_ != NULL) { + if (this->dim_ == dim) { + if (resize_type == kSetZero) this->SetZero(); + return; + } else { + Destroy(); + } + } + Init(dim); + if (resize_type == kSetZero) this->SetZero(); +} + + +/// Copy data from another vector +template +void VectorBase::CopyFromVec(const VectorBase &v) { + KALDI_ASSERT(Dim() == v.Dim()); + if (data_ != v.data_) { + std::memcpy(this->data_, v.data_, dim_ * sizeof(Real)); + } +} + +/* +template +template +void VectorBase::CopyFromPacked(const PackedMatrix& M) { + SubVector v(M); + this->CopyFromVec(v); +} +// instantiate the template. +template void VectorBase::CopyFromPacked(const PackedMatrix +&other); +template void VectorBase::CopyFromPacked(const PackedMatrix +&other); +template void VectorBase::CopyFromPacked(const PackedMatrix +&other); +template void VectorBase::CopyFromPacked(const PackedMatrix +&other); + +/// Load data into the vector +template +void VectorBase::CopyFromPtr(const Real *data, MatrixIndexT sz) { + KALDI_ASSERT(dim_ == sz); + std::memcpy(this->data_, data, Dim() * sizeof(Real)); +}*/ + +template +template +void VectorBase::CopyFromVec(const VectorBase &other) { + KALDI_ASSERT(dim_ == other.Dim()); + Real *__restrict__ ptr = data_; + const OtherReal *__restrict__ other_ptr = other.Data(); + for (MatrixIndexT i = 0; i < dim_; i++) ptr[i] = other_ptr[i]; +} + +template void VectorBase::CopyFromVec(const VectorBase &other); +template void VectorBase::CopyFromVec(const VectorBase &other); + +// Remove element from the vector. 
The vector is not reallocated +template +void Vector::RemoveElement(MatrixIndexT i) { + KALDI_ASSERT(i < this->dim_ && "Access out of vector"); + for (MatrixIndexT j = i + 1; j < this->dim_; j++) + this->data_[j - 1] = this->data_[j]; + this->dim_--; +} + + +/// Deallocates memory and sets object to empty vector. +template +void Vector::Destroy() { + /// we need to free the data block if it was defined + if (this->data_ != NULL) KALDI_MEMALIGN_FREE(this->data_); + this->data_ = NULL; + this->dim_ = 0; +} + +template +void VectorBase::SetZero() { + std::memset(data_, 0, dim_ * sizeof(Real)); +} + +template +bool VectorBase::IsZero(Real cutoff) const { + Real abs_max = 0.0; + for (MatrixIndexT i = 0; i < Dim(); i++) + abs_max = std::max(std::abs(data_[i]), abs_max); + return (abs_max <= cutoff); +} + +/* +template +void VectorBase::SetRandn() { + kaldi::RandomState rstate; + MatrixIndexT last = (Dim() % 2 == 1) ? Dim() - 1 : Dim(); + for (MatrixIndexT i = 0; i < last; i += 2) { + kaldi::RandGauss2(data_ + i, data_ + i + 1, &rstate); + } + if (Dim() != last) data_[last] = static_cast(kaldi::RandGauss(&rstate)); +} + +template +void VectorBase::SetRandUniform() { + kaldi::RandomState rstate; + for (MatrixIndexT i = 0; i < Dim(); i++) { + *(data_+i) = RandUniform(&rstate); + } +} + +template +MatrixIndexT VectorBase::RandCategorical() const { + kaldi::RandomState rstate; + Real sum = this->Sum(); + KALDI_ASSERT(this->Min() >= 0.0 && sum > 0.0); + Real r = RandUniform(&rstate) * sum; + Real *data = this->data_; + MatrixIndexT dim = this->dim_; + Real running_sum = 0.0; + for (MatrixIndexT i = 0; i < dim; i++) { + running_sum += data[i]; + if (r < running_sum) return i; + } + return dim_ - 1; // Should only happen if RandUniform() + // returns exactly 1, or due to roundoff. +}*/ + +template +void VectorBase::Set(Real f) { + // Why not use memset here? + // The basic unit of memset is a byte. + // If f != 0 and sizeof(Real) > 1, then we cannot use memset. 
+ if (f == 0) { + this->SetZero(); // calls std::memset + } else { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = f; + } + } +} + +template +void VectorBase::CopyRowsFromMat(const MatrixBase &mat) { + KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows()); + + Real *inc_data = data_; + const MatrixIndexT cols = mat.NumCols(), rows = mat.NumRows(); + + if (mat.Stride() == mat.NumCols()) { + memcpy(inc_data, mat.Data(), cols * rows * sizeof(Real)); + } else { + for (MatrixIndexT i = 0; i < rows; i++) { + // copy the data to the propper position + memcpy(inc_data, mat.RowData(i), cols * sizeof(Real)); + // set new copy position + inc_data += cols; + } + } +} + +template +template +void VectorBase::CopyRowsFromMat(const MatrixBase &mat) { + KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows()); + Real *vec_data = data_; + const MatrixIndexT cols = mat.NumCols(), rows = mat.NumRows(); + + for (MatrixIndexT i = 0; i < rows; i++) { + const OtherReal *mat_row = mat.RowData(i); + for (MatrixIndexT j = 0; j < cols; j++) { + vec_data[j] = static_cast(mat_row[j]); + } + vec_data += cols; + } +} + +template void VectorBase::CopyRowsFromMat(const MatrixBase &mat); +template void VectorBase::CopyRowsFromMat(const MatrixBase &mat); + + +template +void VectorBase::CopyColsFromMat(const MatrixBase &mat) { + KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows()); + + Real *inc_data = data_; + const MatrixIndexT cols = mat.NumCols(), rows = mat.NumRows(), + stride = mat.Stride(); + const Real *mat_inc_data = mat.Data(); + + for (MatrixIndexT i = 0; i < cols; i++) { + for (MatrixIndexT j = 0; j < rows; j++) { + inc_data[j] = mat_inc_data[j * stride]; + } + mat_inc_data++; + inc_data += rows; + } +} + +template +void VectorBase::CopyRowFromMat(const MatrixBase &mat, + MatrixIndexT row) { + KALDI_ASSERT(row < mat.NumRows()); + KALDI_ASSERT(dim_ == mat.NumCols()); + const Real *mat_row = mat.RowData(row); + memcpy(data_, mat_row, sizeof(Real) * dim_); +} + +template +template +void 
VectorBase::CopyRowFromMat(const MatrixBase &mat, + MatrixIndexT row) { + KALDI_ASSERT(row < mat.NumRows()); + KALDI_ASSERT(dim_ == mat.NumCols()); + const OtherReal *mat_row = mat.RowData(row); + for (MatrixIndexT i = 0; i < dim_; i++) + data_[i] = static_cast(mat_row[i]); +} + +template void VectorBase::CopyRowFromMat(const MatrixBase &mat, + MatrixIndexT row); +template void VectorBase::CopyRowFromMat(const MatrixBase &mat, + MatrixIndexT row); + +/* +template +template +void VectorBase::CopyRowFromSp(const SpMatrix &sp, MatrixIndexT +row) { + KALDI_ASSERT(row < sp.NumRows()); + KALDI_ASSERT(dim_ == sp.NumCols()); + + const OtherReal *sp_data = sp.Data(); + + sp_data += (row*(row+1)) / 2; // takes us to beginning of this row. + MatrixIndexT i; + for (i = 0; i < row; i++) // copy consecutive elements. + data_[i] = static_cast(*(sp_data++)); + for(; i < dim_; ++i, sp_data += i) + data_[i] = static_cast(*sp_data); +} + +template +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT +row); +template +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT +row); +template +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT +row); +template +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT +row); + +// takes absolute value of the elements to a power. +// Throws exception if could not (but only for power != 1 and power != 2). +template +void VectorBase::ApplyPowAbs(Real power, bool include_sign) { + if (power == 1.0) + for (MatrixIndexT i = 0; i < dim_; i++) + data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * std::abs(data_[i]); + if (power == 2.0) { + for (MatrixIndexT i = 0; i < dim_; i++) + data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * data_[i] * data_[i]; + } else if (power == 0.5) { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = (include_sign && data_[i] < 0 ? 
-1 : 1) * +std::sqrt(std::abs(data_[i])); + } + } else if (power < 0.0) { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = (data_[i] == 0.0 ? 0.0 : pow(std::abs(data_[i]), power)); + data_[i] *= (include_sign && data_[i] < 0 ? -1 : 1); + if (data_[i] == HUGE_VAL) { // HUGE_VAL is what errno returns on error. + KALDI_ERR << "Could not raise element " << i << "to power " + << power << ": returned value = " << data_[i]; + } + } + } else { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * +pow(std::abs(data_[i]), power); + if (data_[i] == HUGE_VAL) { // HUGE_VAL is what errno returns on error. + KALDI_ERR << "Could not raise element " << i << "to power " + << power << ": returned value = " << data_[i]; + } + } + } +} + +// Computes the p-th norm. Throws exception if could not. +template +Real VectorBase::Norm(Real p) const { + KALDI_ASSERT(p >= 0.0); + Real sum = 0.0; + if (p == 0.0) { + for (MatrixIndexT i = 0; i < dim_; i++) + if (data_[i] != 0.0) sum += 1.0; + return sum; + } else if (p == 1.0) { + for (MatrixIndexT i = 0; i < dim_; i++) + sum += std::abs(data_[i]); + return sum; + } else if (p == 2.0) { + for (MatrixIndexT i = 0; i < dim_; i++) + sum += data_[i] * data_[i]; + return std::sqrt(sum); + } else if (p == std::numeric_limits::infinity()){ + for (MatrixIndexT i = 0; i < dim_; i++) + sum = std::max(sum, std::abs(data_[i])); + return sum; + } else { + Real tmp; + bool ok = true; + for (MatrixIndexT i = 0; i < dim_; i++) { + tmp = pow(std::abs(data_[i]), p); + if (tmp == HUGE_VAL) // HUGE_VAL is what pow returns on error. + ok = false; + sum += tmp; + } + tmp = pow(sum, static_cast(1.0/p)); + KALDI_ASSERT(tmp != HUGE_VAL); // should not happen here. + if (ok) { + return tmp; + } else { + Real maximum = this->Max(), minimum = this->Min(), + max_abs = std::max(maximum, -minimum); + KALDI_ASSERT(max_abs > 0); // Or should not have reached here. 
+ Vector tmp(*this); + tmp.Scale(1.0 / max_abs); + return tmp.Norm(p) * max_abs; + } + } +} + +template +bool VectorBase::ApproxEqual(const VectorBase &other, float tol) +const { + if (dim_ != other.dim_) KALDI_ERR << "ApproxEqual: size mismatch " + << dim_ << " vs. " << other.dim_; + KALDI_ASSERT(tol >= 0.0); + if (tol != 0.0) { + Vector tmp(*this); + tmp.AddVec(-1.0, other); + return (tmp.Norm(2.0) <= static_cast(tol) * this->Norm(2.0)); + } else { // Test for exact equality. + const Real *data = data_; + const Real *other_data = other.data_; + for (MatrixIndexT dim = dim_, i = 0; i < dim; i++) + if (data[i] != other_data[i]) return false; + return true; + } +} + +template +Real VectorBase::Max() const { + Real ans = - std::numeric_limits::infinity(); + const Real *data = data_; + MatrixIndexT i, dim = dim_; + for (i = 0; i + 4 <= dim; i += 4) { + Real a1 = data[i], a2 = data[i+1], a3 = data[i+2], a4 = data[i+3]; + if (a1 > ans || a2 > ans || a3 > ans || a4 > ans) { + Real b1 = (a1 > a2 ? a1 : a2), b2 = (a3 > a4 ? 
a3 : a4); + if (b1 > ans) ans = b1; + if (b2 > ans) ans = b2; + } + } + for (; i < dim; i++) + if (data[i] > ans) ans = data[i]; + return ans; +} + +template +Real VectorBase::Max(MatrixIndexT *index_out) const { + if (dim_ == 0) KALDI_ERR << "Empty vector"; + Real ans = - std::numeric_limits::infinity(); + MatrixIndexT index = 0; + const Real *data = data_; + MatrixIndexT i, dim = dim_; + for (i = 0; i + 4 <= dim; i += 4) { + Real a1 = data[i], a2 = data[i+1], a3 = data[i+2], a4 = data[i+3]; + if (a1 > ans || a2 > ans || a3 > ans || a4 > ans) { + if (a1 > ans) { ans = a1; index = i; } + if (a2 > ans) { ans = a2; index = i + 1; } + if (a3 > ans) { ans = a3; index = i + 2; } + if (a4 > ans) { ans = a4; index = i + 3; } + } + } + for (; i < dim; i++) + if (data[i] > ans) { ans = data[i]; index = i; } + *index_out = index; + return ans; +} + +template +Real VectorBase::Min() const { + Real ans = std::numeric_limits::infinity(); + const Real *data = data_; + MatrixIndexT i, dim = dim_; + for (i = 0; i + 4 <= dim; i += 4) { + Real a1 = data[i], a2 = data[i+1], a3 = data[i+2], a4 = data[i+3]; + if (a1 < ans || a2 < ans || a3 < ans || a4 < ans) { + Real b1 = (a1 < a2 ? a1 : a2), b2 = (a3 < a4 ? 
a3 : a4); + if (b1 < ans) ans = b1; + if (b2 < ans) ans = b2; + } + } + for (; i < dim; i++) + if (data[i] < ans) ans = data[i]; + return ans; +} + +template +Real VectorBase::Min(MatrixIndexT *index_out) const { + if (dim_ == 0) KALDI_ERR << "Empty vector"; + Real ans = std::numeric_limits::infinity(); + MatrixIndexT index = 0; + const Real *data = data_; + MatrixIndexT i, dim = dim_; + for (i = 0; i + 4 <= dim; i += 4) { + Real a1 = data[i], a2 = data[i+1], a3 = data[i+2], a4 = data[i+3]; + if (a1 < ans || a2 < ans || a3 < ans || a4 < ans) { + if (a1 < ans) { ans = a1; index = i; } + if (a2 < ans) { ans = a2; index = i + 1; } + if (a3 < ans) { ans = a3; index = i + 2; } + if (a4 < ans) { ans = a4; index = i + 3; } + } + } + for (; i < dim; i++) + if (data[i] < ans) { ans = data[i]; index = i; } + *index_out = index; + return ans; +}*/ + + +template +template +void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col) { + KALDI_ASSERT(col < mat.NumCols()); + KALDI_ASSERT(dim_ == mat.NumRows()); + for (MatrixIndexT i = 0; i < dim_; i++) data_[i] = mat(i, col); + // can't do this very efficiently so don't really bother. could improve this + // though. +} +// instantiate the template above. +template void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col); +template void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col); +template void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col); +template void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col); + +// template +// void VectorBase::CopyDiagFromMat(const MatrixBase &M) { +// KALDI_ASSERT(dim_ == std::min(M.NumRows(), M.NumCols())); +// cblas_Xcopy(dim_, M.Data(), M.Stride() + 1, data_, 1); +//} + +// template +// void VectorBase::CopyDiagFromPacked(const PackedMatrix &M) { +// KALDI_ASSERT(dim_ == M.NumCols()); +// for (MatrixIndexT i = 0; i < dim_; i++) +// data_[i] = M(i, i); +//// could make this more efficient. 
+//} + +// template +// Real VectorBase::Sum() const { +//// Do a dot-product with a size-1 array with a stride of 0 to +//// implement sum. This allows us to access SIMD operations in a +//// cross-platform way via your BLAS library. +// Real one(1); +// return cblas_Xdot(dim_, data_, 1, &one, 0); +//} + +// template +// Real VectorBase::SumLog() const { +// double sum_log = 0.0; +// double prod = 1.0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// prod *= data_[i]; +//// Possible future work (arnab): change these magic values to pre-defined +//// constants +// if (prod < 1.0e-10 || prod > 1.0e+10) { +// sum_log += Log(prod); +// prod = 1.0; +//} +//} +// if (prod != 1.0) sum_log += Log(prod); +// return sum_log; +//} + +// template +// void VectorBase::AddRowSumMat(Real alpha, const MatrixBase &M, +// Real beta) { +// KALDI_ASSERT(dim_ == M.NumCols()); +// MatrixIndexT num_rows = M.NumRows(), stride = M.Stride(), dim = dim_; +// Real *data = data_; + +//// implement the function according to a dimension cutoff for computation +/// efficiency +// if (num_rows <= 64) { +// cblas_Xscal(dim, beta, data, 1); +// const Real *m_data = M.Data(); +// for (MatrixIndexT i = 0; i < num_rows; i++, m_data += stride) +// cblas_Xaxpy(dim, alpha, m_data, 1, data, 1); + +//} else { +// Vector ones(M.NumRows()); +// ones.Set(1.0); +// this->AddMatVec(alpha, M, kTrans, ones, beta); +//} +//} + +// template +// void VectorBase::AddColSumMat(Real alpha, const MatrixBase &M, +// Real beta) { +// KALDI_ASSERT(dim_ == M.NumRows()); +// MatrixIndexT num_cols = M.NumCols(); + +//// implement the function according to a dimension cutoff for computation +/// efficiency +// if (num_cols <= 64) { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// double sum = 0.0; +// const Real *src = M.RowData(i); +// for (MatrixIndexT j = 0; j < num_cols; j++) +// sum += src[j]; +// data_[i] = alpha * sum + beta * data_[i]; +//} +//} else { +// Vector ones(M.NumCols()); +// ones.Set(1.0); +// 
this->AddMatVec(alpha, M, kNoTrans, ones, beta); +//} +//} + +// template +// Real VectorBase::LogSumExp(Real prune) const { +// Real sum; +// if (sizeof(sum) == 8) sum = kLogZeroDouble; +// else sum = kLogZeroFloat; +// Real max_elem = Max(), cutoff; +// if (sizeof(Real) == 4) cutoff = max_elem + kMinLogDiffFloat; +// else cutoff = max_elem + kMinLogDiffDouble; +// if (prune > 0.0 && max_elem - prune > cutoff) // explicit pruning... +// cutoff = max_elem - prune; + +// double sum_relto_max_elem = 0.0; + +// for (MatrixIndexT i = 0; i < dim_; i++) { +// BaseFloat f = data_[i]; +// if (f >= cutoff) +// sum_relto_max_elem += Exp(f - max_elem); +//} +// return max_elem + Log(sum_relto_max_elem); +//} + +// template +// void VectorBase::InvertElements() { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = static_cast(1 / data_[i]); +//} +//} + +// template +// void VectorBase::ApplyLog() { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// if (data_[i] < 0.0) +// KALDI_ERR << "Trying to take log of a negative number."; +// data_[i] = Log(data_[i]); +//} +//} + +// template +// void VectorBase::ApplyLogAndCopy(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.Dim()); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = Log(v(i)); +//} +//} + +// template +// void VectorBase::ApplyExp() { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = Exp(data_[i]); +//} +//} + +// template +// void VectorBase::ApplyAbs() { +// for (MatrixIndexT i = 0; i < dim_; i++) { data_[i] = std::abs(data_[i]); } +//} + +// template +// void VectorBase::Floor(const VectorBase &v, Real floor_val, +// MatrixIndexT *floored_count) { +// KALDI_ASSERT(dim_ == v.dim_); +// if (floored_count == nullptr) { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = std::max(v.data_[i], floor_val); +//} +//} else { +// MatrixIndexT num_floored = 0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// if (v.data_[i] < floor_val) { +// data_[i] = floor_val; +// num_floored++; +//} 
else { +// data_[i] = v.data_[i]; +//} +//} +//*floored_count = num_floored; +//} +//} + +// template +// void VectorBase::Ceiling(const VectorBase &v, Real ceil_val, +// MatrixIndexT *ceiled_count) { +// KALDI_ASSERT(dim_ == v.dim_); +// if (ceiled_count == nullptr) { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = std::min(v.data_[i], ceil_val); +//} +//} else { +// MatrixIndexT num_changed = 0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// if (v.data_[i] > ceil_val) { +// data_[i] = ceil_val; +// num_changed++; +//} else { +// data_[i] = v.data_[i]; +//} +//} +//*ceiled_count = num_changed; +//} +//} + +// template +// MatrixIndexT VectorBase::ApplyFloor(const VectorBase &floor_vec) +// { +// KALDI_ASSERT(floor_vec.Dim() == dim_); +// MatrixIndexT num_floored = 0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// if (data_[i] < floor_vec(i)) { +// data_[i] = floor_vec(i); +// num_floored++; +//} +//} +// return num_floored; +//} + +// template +// Real VectorBase::ApplySoftMax() { +// Real max = this->Max(), sum = 0.0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// sum += (data_[i] = Exp(data_[i] - max)); +//} +// this->Scale(1.0 / sum); +// return max + Log(sum); +//} + +// template +// Real VectorBase::ApplyLogSoftMax() { +// Real max = this->Max(), sum = 0.0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// sum += Exp((data_[i] -= max)); +//} +// sum = Log(sum); +// this->Add(-1.0 * sum); +// return max + sum; +//} + +//#ifdef HAVE_MKL +// template<> +// void VectorBase::Tanh(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// vsTanh(dim_, src.data_, data_); +//} +// template<> +// void VectorBase::Tanh(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// vdTanh(dim_, src.data_, data_); +//} +//#else +// template +// void VectorBase::Tanh(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// Real x = src.data_[i]; +// if (x > 0.0) { +// Real inv_expx = 
Exp(-x); +// x = -1.0 + 2.0 / (1.0 + inv_expx * inv_expx); +//} else { +// Real expx = Exp(x); +// x = 1.0 - 2.0 / (1.0 + expx * expx); +//} +// data_[i] = x; +//} +//} +//#endif + +//#ifdef HAVE_MKL +//// Implementing sigmoid based on tanh. +// template<> +// void VectorBase::Sigmoid(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// this->CopyFromVec(src); +// this->Scale(0.5); +// vsTanh(dim_, data_, data_); +// this->Add(1.0); +// this->Scale(0.5); +//} +// template<> +// void VectorBase::Sigmoid(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// this->CopyFromVec(src); +// this->Scale(0.5); +// vdTanh(dim_, data_, data_); +// this->Add(1.0); +// this->Scale(0.5); +//} +//#else +// template +// void VectorBase::Sigmoid(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// Real x = src.data_[i]; +//// We aim to avoid floating-point overflow here. +// if (x > 0.0) { +// x = 1.0 / (1.0 + Exp(-x)); +//} else { +// Real ex = Exp(x); +// x = ex / (ex + 1.0); +//} +// data_[i] = x; +//} +//} +//#endif + + +// template +// void VectorBase::Add(Real c) { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] += c; +//} +//} + +// template +// void VectorBase::Scale(Real alpha) { +// cblas_Xscal(dim_, alpha, data_, 1); +//} + +// template +// void VectorBase::MulElements(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] *= v.data_[i]; +//} +//} + +// template // Set each element to y = (x == orig ? changed : +// x). 
+// void VectorBase::ReplaceValue(Real orig, Real changed) { +// Real *data = data_; +// for (MatrixIndexT i = 0; i < dim_; i++) +// if (data[i] == orig) data[i] = changed; +//} + + +// template +// template +// void VectorBase::MulElements(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.Dim()); +// const OtherReal *other_ptr = v.Data(); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] *= other_ptr[i]; +//} +//} +//// instantiate template. +// template +// void VectorBase::MulElements(const VectorBase &v); +// template +// void VectorBase::MulElements(const VectorBase &v); + + +// template +// void VectorBase::AddVecVec(Real alpha, const VectorBase &v, +// const VectorBase &r, Real beta) { +// KALDI_ASSERT(v.data_ != this->data_ && r.data_ != this->data_); +//// We pretend that v is a band-diagonal matrix. +// KALDI_ASSERT(dim_ == v.dim_ && dim_ == r.dim_); +// cblas_Xgbmv(kNoTrans, dim_, dim_, 0, 0, alpha, v.data_, 1, +// r.data_, 1, beta, this->data_, 1); +//} + + +// template +// void VectorBase::DivElements(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] /= v.data_[i]; +//} +//} + +// template +// template +// void VectorBase::DivElements(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.Dim()); +// const OtherReal *other_ptr = v.Data(); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] /= other_ptr[i]; +//} +//} +//// instantiate template. 
+// template +// void VectorBase::DivElements(const VectorBase &v); +// template +// void VectorBase::DivElements(const VectorBase &v); + +// template +// void VectorBase::AddVecDivVec(Real alpha, const VectorBase &v, +// const VectorBase &rr, Real beta) { +// KALDI_ASSERT((dim_ == v.dim_ && dim_ == rr.dim_)); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = alpha * v.data_[i]/rr.data_[i] + beta * data_[i] ; +//} +//} + +// template +// template +// void VectorBase::AddVec(const Real alpha, const VectorBase +// &v) { +// KALDI_ASSERT(dim_ == v.dim_); +//// remove __restrict__ if it causes compilation problems. +// Real *__restrict__ data = data_; +// OtherReal *__restrict__ other_data = v.data_; +// MatrixIndexT dim = dim_; +// if (alpha != 1.0) +// for (MatrixIndexT i = 0; i < dim; i++) +// data[i] += alpha * other_data[i]; +// else +// for (MatrixIndexT i = 0; i < dim; i++) +// data[i] += other_data[i]; +//} + +// template +// void VectorBase::AddVec(const float alpha, const VectorBase +// &v); +// template +// void VectorBase::AddVec(const double alpha, const VectorBase +// &v); + +// template +// template +// void VectorBase::AddVec2(const Real alpha, const VectorBase +// &v) { +// KALDI_ASSERT(dim_ == v.dim_); +//// remove __restrict__ if it causes compilation problems. 
+// Real *__restrict__ data = data_; +// OtherReal *__restrict__ other_data = v.data_; +// MatrixIndexT dim = dim_; +// if (alpha != 1.0) +// for (MatrixIndexT i = 0; i < dim; i++) +// data[i] += alpha * other_data[i] * other_data[i]; +// else +// for (MatrixIndexT i = 0; i < dim; i++) +// data[i] += other_data[i] * other_data[i]; +//} + +// template +// void VectorBase::AddVec2(const float alpha, const VectorBase +// &v); +// template +// void VectorBase::AddVec2(const double alpha, const VectorBase +// &v); + + +template +void VectorBase::Read(std::istream &is, bool binary) { + // In order to avoid rewriting this, we just declare a Vector and + // use it to read the data, then copy. + Vector tmp; + tmp.Read(is, binary); + if (tmp.Dim() != Dim()) + KALDI_ERR << "VectorBase::Read, size mismatch " << Dim() + << " vs. " << tmp.Dim(); + CopyFromVec(tmp); +} + + +template +void Vector::Read(std::istream &is, bool binary) { + std::ostringstream specific_error; + MatrixIndexT pos_at_start = is.tellg(); + + if (binary) { + int peekval = Peek(is, binary); + const char *my_token = (sizeof(Real) == 4 ? "FV" : "DV"); + char other_token_start = (sizeof(Real) == 4 ? 'D' : 'F'); + if (peekval == other_token_start) { // need to instantiate the other + // type to read it. + typedef typename OtherReal::Real OtherType; // if Real == + // float, + // OtherType == + // double, and + // vice versa. + Vector other(this->Dim()); + other.Read(is, binary); // add is false at this point. + if (this->Dim() != other.Dim()) this->Resize(other.Dim()); + this->CopyFromVec(other); + return; + } + std::string token; + ReadToken(is, binary, &token); + if (token != my_token) { + if (token.length() > 20) token = token.substr(0, 17) + "..."; + specific_error << ": Expected token " << my_token << ", got " + << token; + goto bad; + } + int32 size; + ReadBasicType(is, binary, &size); // throws on error. 
+ if ((MatrixIndexT)size != this->Dim()) this->Resize(size); + if (size > 0) + is.read(reinterpret_cast(this->data_), sizeof(Real) * size); + if (is.fail()) { + specific_error + << "Error reading vector data (binary mode); truncated " + "stream? (size = " + << size << ")"; + goto bad; + } + return; + } else { // Text mode reading; format is " [ 1.1 2.0 3.4 ]\n" + std::string s; + is >> s; + // if ((s.compare("DV") == 0) || (s.compare("FV") == 0)) { // Back + // compatibility. + // is >> s; // get dimension + // is >> s; // get "[" + // } + if (is.fail()) { + specific_error << "EOF while trying to read vector."; + goto bad; + } + if (s.compare("[]") == 0) { + Resize(0); + return; + } // tolerate this variant. + if (s.compare("[")) { + if (s.length() > 20) s = s.substr(0, 17) + "..."; + specific_error << "Expected \"[\" but got " << s; + goto bad; + } + std::vector data; + while (1) { + int i = is.peek(); + if (i == '-' || (i >= '0' && i <= '9')) { // common cases first. + Real r; + is >> r; + if (is.fail()) { + specific_error << "Failed to read number."; + goto bad; + } + if (!std::isspace(is.peek()) && is.peek() != ']') { + specific_error << "Expected whitespace after number."; + goto bad; + } + data.push_back(r); + // But don't eat whitespace... we want to check that it's not + // newlines + // which would be valid only for a matrix. + } else if (i == ' ' || i == '\t') { + is.get(); + } else if (i == ']') { + is.get(); // eat the ']' + this->Resize(data.size()); + for (size_t j = 0; j < data.size(); j++) + this->data_[j] = data[j]; + i = is.peek(); + if (static_cast(i) == '\r') { + is.get(); + is.get(); // get \r\n (must eat what we wrote) + } else if (static_cast(i) == '\n') { + is.get(); + } // get \n (must eat what we wrote) + if (is.fail()) { + KALDI_WARN << "After end of vector data, read error."; + // we got the data we needed, so just warn for this error. + } + return; // success. 
+ } else if (i == -1) { + specific_error << "EOF while reading vector data."; + goto bad; + } else if (i == '\n' || i == '\r') { + specific_error << "Newline found while reading vector (maybe " + "it's a matrix?)"; + goto bad; + } else { + is >> s; // read string. + if (!KALDI_STRCASECMP(s.c_str(), "inf") || + !KALDI_STRCASECMP(s.c_str(), "infinity")) { + data.push_back(std::numeric_limits::infinity()); + KALDI_WARN << "Reading infinite value into vector."; + } else if (!KALDI_STRCASECMP(s.c_str(), "nan")) { + data.push_back(std::numeric_limits::quiet_NaN()); + KALDI_WARN << "Reading NaN value into vector."; + } else { + if (s.length() > 20) s = s.substr(0, 17) + "..."; + specific_error << "Expecting numeric vector data, got " + << s; + goto bad; + } + } + } + } +// we never reach this line (the while loop returns directly). +bad: + KALDI_ERR << "Failed to read vector from stream. " << specific_error.str() + << " File position at start is " << pos_at_start << ", currently " + << is.tellg(); +} + + +template +void VectorBase::Write(std::ostream &os, bool binary) const { + if (!os.good()) { + KALDI_ERR << "Failed to write vector to stream: stream not good"; + } + if (binary) { + std::string my_token = (sizeof(Real) == 4 ? "FV" : "DV"); + WriteToken(os, binary, my_token); + + int32 size = Dim(); // make the size 32-bit on disk. + KALDI_ASSERT(Dim() == (MatrixIndexT)size); + WriteBasicType(os, binary, size); + os.write(reinterpret_cast(Data()), sizeof(Real) * size); + } else { + os << " [ "; + for (MatrixIndexT i = 0; i < Dim(); i++) os << (*this)(i) << " "; + os << "]\n"; + } + if (!os.good()) KALDI_ERR << "Failed to write vector to stream"; +} + + +// template +// void VectorBase::AddVec2(const Real alpha, const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) +// data_[i] += alpha * v.data_[i] * v.data_[i]; +//} + +//// this <-- beta*this + alpha*M*v. 
+// template +// void VectorBase::AddTpVec(const Real alpha, const TpMatrix &M, +// const MatrixTransposeType trans, +// const VectorBase &v, +// const Real beta) { +// KALDI_ASSERT(dim_ == v.dim_ && dim_ == M.NumRows()); +// if (beta == 0.0) { +// if (&v != this) CopyFromVec(v); +// MulTp(M, trans); +// if (alpha != 1.0) Scale(alpha); +//} else { +// Vector tmp(v); +// tmp.MulTp(M, trans); +// if (beta != 1.0) Scale(beta); // *this <-- beta * *this +// AddVec(alpha, tmp); // *this += alpha * M * v +//} +//} + +// template +// Real VecMatVec(const VectorBase &v1, const MatrixBase &M, +// const VectorBase &v2) { +// KALDI_ASSERT(v1.Dim() == M.NumRows() && v2.Dim() == M.NumCols()); +// Vector vtmp(M.NumRows()); +// vtmp.AddMatVec(1.0, M, kNoTrans, v2, 0.0); +// return VecVec(v1, vtmp); +//} + +// template +// float VecMatVec(const VectorBase &v1, const MatrixBase &M, +// const VectorBase &v2); +// template +// double VecMatVec(const VectorBase &v1, const MatrixBase &M, +// const VectorBase &v2); + +template +void Vector::Swap(Vector *other) { + std::swap(this->data_, other->data_); + std::swap(this->dim_, other->dim_); +} + + +// template +// void VectorBase::AddDiagMat2( +// Real alpha, const MatrixBase &M, +// MatrixTransposeType trans, Real beta) { +// if (trans == kNoTrans) { +// KALDI_ASSERT(this->dim_ == M.NumRows()); +// MatrixIndexT rows = this->dim_, cols = M.NumCols(), +// mat_stride = M.Stride(); +// Real *data = this->data_; +// const Real *mat_data = M.Data(); +// for (MatrixIndexT i = 0; i < rows; i++, mat_data += mat_stride, data++) +//*data = beta * *data + alpha * cblas_Xdot(cols,mat_data,1,mat_data,1); +//} else { +// KALDI_ASSERT(this->dim_ == M.NumCols()); +// MatrixIndexT rows = M.NumRows(), cols = this->dim_, +// mat_stride = M.Stride(); +// Real *data = this->data_; +// const Real *mat_data = M.Data(); +// for (MatrixIndexT i = 0; i < cols; i++, mat_data++, data++) +//*data = beta * *data + alpha * cblas_Xdot(rows, mat_data, mat_stride, +// 
mat_data, mat_stride); +//} +//} + +// template +// void VectorBase::AddDiagMatMat( +// Real alpha, +// const MatrixBase &M, MatrixTransposeType transM, +// const MatrixBase &N, MatrixTransposeType transN, +// Real beta) { +// MatrixIndexT dim = this->dim_, +// M_col_dim = (transM == kTrans ? M.NumRows() : M.NumCols()), +// N_row_dim = (transN == kTrans ? N.NumCols() : N.NumRows()); +// KALDI_ASSERT(M_col_dim == N_row_dim); // this is the dimension we sum over +// MatrixIndexT M_row_stride = M.Stride(), M_col_stride = 1; +// if (transM == kTrans) std::swap(M_row_stride, M_col_stride); +// MatrixIndexT N_row_stride = N.Stride(), N_col_stride = 1; +// if (transN == kTrans) std::swap(N_row_stride, N_col_stride); + +// Real *data = this->data_; +// const Real *Mdata = M.Data(), *Ndata = N.Data(); +// for (MatrixIndexT i = 0; i < dim; i++, Mdata += M_row_stride, Ndata += +// N_col_stride, data++) { +//*data = beta * *data + alpha * cblas_Xdot(M_col_dim, Mdata, M_col_stride, +// Ndata, N_row_stride); +//} +//} + + +template class Vector; +template class Vector; +template class VectorBase; +template class VectorBase; + +} // namespace kaldi diff --git a/runtime/engine/common/matrix/kaldi-vector.h b/runtime/engine/common/matrix/kaldi-vector.h new file mode 100644 index 00000000..461e026d --- /dev/null +++ b/runtime/engine/common/matrix/kaldi-vector.h @@ -0,0 +1,352 @@ +// matrix/kaldi-vector.h + +// Copyright 2009-2012 Ondrej Glembek; Microsoft Corporation; Lukas Burget; +// Saarland University (Author: Arnab Ghoshal); +// Ariya Rastrow; Petr Schwarz; Yanmin Qian; +// Karel Vesely; Go Vivace Inc.; Arnab Ghoshal +// Wei Shi; +// 2015 Guoguo Chen +// 2017 Daniel Galvez +// 2019 Yiwen Shao + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_VECTOR_H_ +#define KALDI_MATRIX_KALDI_VECTOR_H_ 1 + +#include "matrix/matrix-common.h" + +namespace kaldi { + +/// \addtogroup matrix_group +/// @{ + +/// Provides a vector abstraction class. +/// This class provides a way to work with vectors in kaldi. +/// It encapsulates basic operations and memory optimizations. +template +class VectorBase { + public: + /// Set vector to all zeros. + void SetZero(); + + /// Returns true if matrix is all zeros. + bool IsZero(Real cutoff = 1.0e-06) const; // replace magic number + + /// Set all members of a vector to a specified value. + void Set(Real f); + + /// Returns the dimension of the vector. + inline MatrixIndexT Dim() const { return dim_; } + + /// Returns the size in memory of the vector, in bytes. + inline MatrixIndexT SizeInBytes() const { return (dim_ * sizeof(Real)); } + + /// Returns a pointer to the start of the vector's data. + inline Real *Data() { return data_; } + + /// Returns a pointer to the start of the vector's data (const). + inline const Real *Data() const { return data_; } + + /// Indexing operator (const). + inline Real operator()(MatrixIndexT i) const { + KALDI_PARANOID_ASSERT(static_cast(i) < + static_cast(dim_)); + return *(data_ + i); + } + + /// Indexing operator (non-const). + inline Real &operator()(MatrixIndexT i) { + KALDI_PARANOID_ASSERT(static_cast(i) < + static_cast(dim_)); + return *(data_ + i); + } + + /** @brief Returns a sub-vector of a vector (a range of elements). 
+ * @param o [in] Origin, 0 < o < Dim() + * @param l [in] Length 0 < l < Dim()-o + * @return A SubVector object that aliases the data of the Vector object. + * See @c SubVector class for details */ + SubVector Range(const MatrixIndexT o, const MatrixIndexT l) { + return SubVector(*this, o, l); + } + + /** @brief Returns a const sub-vector of a vector (a range of elements). + * @param o [in] Origin, 0 < o < Dim() + * @param l [in] Length 0 < l < Dim()-o + * @return A SubVector object that aliases the data of the Vector object. + * See @c SubVector class for details */ + const SubVector Range(const MatrixIndexT o, + const MatrixIndexT l) const { + return SubVector(*this, o, l); + } + + /// Copy data from another vector (must match own size). + void CopyFromVec(const VectorBase &v); + + /// Copy data from another vector of different type (double vs. float) + template + void CopyFromVec(const VectorBase &v); + + /// Performs a row stack of the matrix M + void CopyRowsFromMat(const MatrixBase &M); + template + void CopyRowsFromMat(const MatrixBase &M); + + /// Performs a column stack of the matrix M + void CopyColsFromMat(const MatrixBase &M); + + /// Extracts a row of the matrix M. Could also do this with + /// this->Copy(M[row]). + void CopyRowFromMat(const MatrixBase &M, MatrixIndexT row); + /// Extracts a row of the matrix M with type conversion. + template + void CopyRowFromMat(const MatrixBase &M, MatrixIndexT row); + + /// Extracts a column of the matrix M. + template + void CopyColFromMat(const MatrixBase &M, MatrixIndexT col); + + /// Reads from C++ stream (option to add to existing contents). + /// Throws exception on failure + void Read(std::istream &in, bool binary); + + /// Writes to C++ stream (option to write in binary). + void Write(std::ostream &Out, bool binary) const; + + friend class VectorBase; + friend class VectorBase; + + protected: + /// Destructor; does not deallocate memory, this is handled by child + /// classes. 
+ /// This destructor is protected so this object can only be + /// deleted via a child. + ~VectorBase() {} + + /// Empty initializer, corresponds to vector of zero size. + explicit VectorBase() : data_(NULL), dim_(0) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } + + /// data memory area + Real *data_; + /// dimension of vector + MatrixIndexT dim_; + KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase); +}; // class VectorBase + +/** @brief A class representing a vector. + * + * This class provides a way to work with vectors in kaldi. + * It encapsulates basic operations and memory optimizations. */ +template +class Vector : public VectorBase { + public: + /// Constructor that takes no arguments. Initializes to empty. + Vector() : VectorBase() {} + + /// Constructor with specific size. Sets to all-zero by default + /// if set_zero == false, memory contents are undefined. + explicit Vector(const MatrixIndexT s, + MatrixResizeType resize_type = kSetZero) + : VectorBase() { + Resize(s, resize_type); + } + + /// Copy constructor from CUDA vector + /// This is defined in ../cudamatrix/cu-vector.h + // template + // explicit Vector(const CuVectorBase &cu); + + /// Copy constructor. The need for this is controversial. + Vector(const Vector &v) + : VectorBase() { // (cannot be explicit) + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + + /// Copy-constructor from base-class, needed to copy from SubVector. + explicit Vector(const VectorBase &v) : VectorBase() { + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + + /// Type conversion constructor. + template + explicit Vector(const VectorBase &v) : VectorBase() { + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + + // Took this out since it is unsafe : Arnab + // /// Constructor from a pointer and a size; copies the data to a location + // /// it owns. 
+ // Vector(const Real* Data, const MatrixIndexT s): VectorBase() { + // Resize(s); + // CopyFromPtr(Data, s); + // } + + + /// Swaps the contents of *this and *other. Shallow swap. + void Swap(Vector *other); + + /// Destructor. Deallocates memory. + ~Vector() { Destroy(); } + + /// Read function using C++ streams. Can also add to existing contents + /// of matrix. + void Read(std::istream &in, bool binary); + + /// Set vector to a specified size (can be zero). + /// The value of the new data depends on resize_type: + /// -if kSetZero, the new data will be zero + /// -if kUndefined, the new data will be undefined + /// -if kCopyData, the new data will be the same as the old data in any + /// shared positions, and zero elsewhere. + /// This function takes time proportional to the number of data elements. + void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero); + + /// Remove one element and shifts later elements down. + void RemoveElement(MatrixIndexT i); + + /// Assignment operator. + Vector &operator=(const Vector &other) { + Resize(other.Dim(), kUndefined); + this->CopyFromVec(other); + return *this; + } + + /// Assignment operator that takes VectorBase. + Vector &operator=(const VectorBase &other) { + Resize(other.Dim(), kUndefined); + this->CopyFromVec(other); + return *this; + } + + private: + /// Init assumes the current contents of the class are invalid (i.e. junk or + /// has already been freed), and it sets the vector to newly allocated + /// memory + /// with the specified dimension. dim == 0 is acceptable. The memory + /// contents + /// pointed to by data_ will be undefined. + void Init(const MatrixIndexT dim); + + /// Destroy function, called internally. + void Destroy(); +}; + + +/// Represents a non-allocating general vector which can be defined +/// as a sub-vector of higher-level vector [or as the row of a matrix]. +template +class SubVector : public VectorBase { + public: + /// Constructor from a Vector or SubVector. 
+ /// SubVectors are not const-safe and it's very hard to make them + /// so for now we just give up. This function contains const_cast. + SubVector(const VectorBase &t, + const MatrixIndexT origin, + const MatrixIndexT length) + : VectorBase() { + // following assert equiv to origin>=0 && length>=0 && + // origin+length <= rt.dim_ + KALDI_ASSERT(static_cast(origin) + + static_cast(length) <= + static_cast(t.Dim())); + VectorBase::data_ = const_cast(t.Data() + origin); + VectorBase::dim_ = length; + } + + /// This constructor initializes the vector to point at the contents + /// of this packed matrix (SpMatrix or TpMatrix). + // SubVector(const PackedMatrix &M) { + // VectorBase::data_ = const_cast (M.Data()); + // VectorBase::dim_ = (M.NumRows()*(M.NumRows()+1))/2; + //} + + /// Copy constructor + SubVector(const SubVector &other) : VectorBase() { + // this copy constructor needed for Range() to work in base class. + VectorBase::data_ = other.data_; + VectorBase::dim_ = other.dim_; + } + + /// Constructor from a pointer to memory and a length. Keeps a pointer + /// to the data but does not take ownership (will never delete). + /// Caution: this constructor enables you to evade const constraints. + SubVector(const Real *data, MatrixIndexT length) : VectorBase() { + VectorBase::data_ = const_cast(data); + VectorBase::dim_ = length; + } + + /// This operation does not preserve const-ness, so be careful. + SubVector(const MatrixBase &matrix, MatrixIndexT row) { + VectorBase::data_ = const_cast(matrix.RowData(row)); + VectorBase::dim_ = matrix.NumCols(); + } + + ~SubVector() {} ///< Destructor (does nothing; no pointers are owned here). + + private: + /// Disallow assignment operator. + SubVector &operator=(const SubVector &other) {} +}; + +/// @} end of "addtogroup matrix_group" +/// \addtogroup matrix_funcs_io +/// @{ +/// Output to a C++ stream. Non-binary by default (use Write for +/// binary output). 
+template +std::ostream &operator<<(std::ostream &out, const VectorBase &v); + +/// Input from a C++ stream. Will automatically read text or +/// binary data from the stream. +template +std::istream &operator>>(std::istream &in, VectorBase &v); + +/// Input from a C++ stream. Will automatically read text or +/// binary data from the stream. +template +std::istream &operator>>(std::istream &in, Vector &v); +/// @} end of \addtogroup matrix_funcs_io + +/// \addtogroup matrix_funcs_scalar +/// @{ + + +// template +// bool ApproxEqual(const VectorBase &a, +// const VectorBase &b, Real tol = 0.01) { +// return a.ApproxEqual(b, tol); +//} + +// template +// inline void AssertEqual(VectorBase &a, VectorBase &b, +// float tol = 0.01) { +// KALDI_ASSERT(a.ApproxEqual(b, tol)); +//} + + +} // namespace kaldi + +// we need to include the implementation +#include "matrix/kaldi-vector-inl.h" + + +#endif // KALDI_MATRIX_KALDI_VECTOR_H_ diff --git a/speechx/speechx/kaldi/matrix/matrix-common.h b/runtime/engine/common/matrix/matrix-common.h similarity index 50% rename from speechx/speechx/kaldi/matrix/matrix-common.h rename to runtime/engine/common/matrix/matrix-common.h index f7047d71..e915db0a 100644 --- a/speechx/speechx/kaldi/matrix/matrix-common.h +++ b/runtime/engine/common/matrix/matrix-common.h @@ -27,71 +27,58 @@ namespace kaldi { // this enums equal to CblasTrans and CblasNoTrans constants from CBLAS library -// we are writing them as literals because we don't want to include here matrix/kaldi-blas.h, -// which puts many symbols into global scope (like "real") via the header f2c.h +// we are writing them as literals because we don't want to include here +// matrix/kaldi-blas.h, +// which puts many symbols into global scope (like "real") via the header f2c.h typedef enum { - kTrans = 112, // = CblasTrans - kNoTrans = 111 // = CblasNoTrans + kTrans = 112, // = CblasTrans + kNoTrans = 111 // = CblasNoTrans } MatrixTransposeType; -typedef enum { - kSetZero, - kUndefined, - 
kCopyData -} MatrixResizeType; +typedef enum { kSetZero, kUndefined, kCopyData } MatrixResizeType; typedef enum { - kDefaultStride, - kStrideEqualNumCols, + kDefaultStride, + kStrideEqualNumCols, } MatrixStrideType; typedef enum { - kTakeLower, - kTakeUpper, - kTakeMean, - kTakeMeanAndCheck + kTakeLower, + kTakeUpper, + kTakeMean, + kTakeMeanAndCheck } SpCopyType; -template class VectorBase; -template class Vector; -template class SubVector; -template class MatrixBase; -template class SubMatrix; -template class Matrix; -template class SpMatrix; -template class TpMatrix; -template class PackedMatrix; -template class SparseMatrix; - -// these are classes that won't be defined in this -// directory; they're mostly needed for friend declarations. -template class CuMatrixBase; -template class CuSubMatrix; -template class CuMatrix; -template class CuVectorBase; -template class CuSubVector; -template class CuVector; -template class CuPackedMatrix; -template class CuSpMatrix; -template class CuTpMatrix; -template class CuSparseMatrix; - -class CompressedMatrix; -class GeneralMatrix; +template +class VectorBase; +template +class Vector; +template +class SubVector; +template +class MatrixBase; +template +class SubMatrix; +template +class Matrix; + /// This class provides a way for switching between double and float types. -template class OtherReal { }; // useful in reading+writing routines - // to switch double and float. +template +class OtherReal {}; // useful in reading+writing routines + // to switch double and float. /// A specialized class for switching from float to double. -template<> class OtherReal { - public: - typedef double Real; +template <> +class OtherReal { + public: + typedef double Real; }; /// A specialized class for switching from double to float. 
-template<> class OtherReal { - public: - typedef float Real; +template <> +class OtherReal { + public: + typedef float Real; }; @@ -100,12 +87,10 @@ typedef int32 SignedMatrixIndexT; typedef uint32 UnsignedMatrixIndexT; // If you want to use size_t for the index type, do as follows instead: -//typedef size_t MatrixIndexT; -//typedef ssize_t SignedMatrixIndexT; -//typedef size_t UnsignedMatrixIndexT; - -} - +// typedef size_t MatrixIndexT; +// typedef ssize_t SignedMatrixIndexT; +// typedef size_t UnsignedMatrixIndexT; +} // namespace kaldi #endif // KALDI_MATRIX_MATRIX_COMMON_H_ diff --git a/runtime/engine/common/utils/CMakeLists.txt b/runtime/engine/common/utils/CMakeLists.txt new file mode 100644 index 00000000..14733648 --- /dev/null +++ b/runtime/engine/common/utils/CMakeLists.txt @@ -0,0 +1,28 @@ + + +set(csrc + file_utils.cc + math.cc + strings.cc + audio_process.cc + timer.cc +) + +add_library(utils ${csrc}) + +if(WITH_TESTING) + enable_testing() + + if(ANDROID) + else() # UNIX + link_libraries(gtest_main gmock) + + add_executable(strings_test strings_test.cc) + target_link_libraries(strings_test PUBLIC utils) + add_test( + NAME strings_test + COMMAND strings_test + ) + endif() +endif() + diff --git a/runtime/engine/common/utils/audio_process.cc b/runtime/engine/common/utils/audio_process.cc new file mode 100644 index 00000000..54540b85 --- /dev/null +++ b/runtime/engine/common/utils/audio_process.cc @@ -0,0 +1,83 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "utils/audio_process.h" + +namespace ppspeech{ + +int WaveformFloatNormal(std::vector* waveform) { + int tot_samples = waveform->size(); + for (int i = 0; i < tot_samples; i++) { + (*waveform)[i] = (*waveform)[i] / 32768.0; + } + return 0; +} + +int WaveformNormal(std::vector* waveform, + bool wav_normal, + const std::string& wav_normal_type, + float wav_norm_mul_factor) { + if (wav_normal == false) { + return 0; + } + if (wav_normal_type == "linear") { + float amax = INT32_MIN; + for (int i = 0; i < waveform->size(); ++i) { + float tmp = std::abs((*waveform)[i]); + amax = std::max(amax, tmp); + } + float factor = 1.0 / (amax + 1e-8); + for (int i = 0; i < waveform->size(); ++i) { + (*waveform)[i] = (*waveform)[i] * factor * wav_norm_mul_factor; + } + } else if (wav_normal_type == "gaussian") { + double sum = std::accumulate(waveform->begin(), waveform->end(), 0.0); + double mean = sum / waveform->size(); //均值 + + double accum = 0.0; + std::for_each(waveform->begin(), waveform->end(), [&](const double d) { + accum += (d - mean) * (d - mean); + }); + + double stdev = sqrt(accum / (waveform->size() - 1)); //方差 + stdev = std::max(stdev, 1e-8); + + for (int i = 0; i < waveform->size(); ++i) { + (*waveform)[i] = + wav_norm_mul_factor * ((*waveform)[i] - mean) / stdev; + } + } else { + printf("don't support\n"); + return -1; + } + return 0; +} + +float PowerTodb(float in, float ref_value, float amin, float top_db) { + if (amin <= 0) { + printf("amin must be strictly positive\n"); + return -1; + } + + if (ref_value <= 0) { + printf("ref_value must be strictly positive\n"); + return -1; + } + + float out = 10.0 * log10(std::max(amin, in)); + out -= 10.0 * log10(std::max(ref_value, amin)); + return out; +} + +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/common/utils/audio_process.h 
b/runtime/engine/common/utils/audio_process.h new file mode 100644 index 00000000..164d4c07 --- /dev/null +++ b/runtime/engine/common/utils/audio_process.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +namespace ppspeech{ +int WaveformFloatNormal(std::vector* waveform); +int WaveformNormal(std::vector* waveform, + bool wav_normal, + const std::string& wav_normal_type, + float wav_norm_mul_factor); +float PowerTodb(float in, + float ref_value = 1.0, + float amin = 1e-10, + float top_db = 80.0); +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/common/utils/blank_process_test.cc b/runtime/engine/common/utils/blank_process_test.cc new file mode 100644 index 00000000..75f762ae --- /dev/null +++ b/runtime/engine/common/utils/blank_process_test.cc @@ -0,0 +1,26 @@ +#include "utils/blank_process.h" + +#include +#include + +TEST(BlankProcess, BlankProcessTest) { + std::string test_str = "我 今天 去 了 超市 花了 120 元。"; + std::string out_str = ppspeech::BlankProcess(test_str); + int ret = out_str.compare("我今天去了超市花了120元。"); + EXPECT_EQ(ret, 0); + + test_str = "how are you today"; + out_str = ppspeech::BlankProcess(test_str); + ret = out_str.compare("how are you today"); + EXPECT_EQ(ret, 0); + + test_str = "我 的 paper 在 哪里?"; + out_str = ppspeech::BlankProcess(test_str); + ret = 
out_str.compare("我的paper在哪里?"); + EXPECT_EQ(ret, 0); + + test_str = "我 今天 去 了 超市 花了 120 元。"; + out_str = ppspeech::BlankProcess(test_str); + ret = out_str.compare("我今天去了超市花了120元。"); + EXPECT_EQ(ret, 0); +} \ No newline at end of file diff --git a/speechx/speechx/utils/file_utils.cc b/runtime/engine/common/utils/file_utils.cc similarity index 61% rename from speechx/speechx/utils/file_utils.cc rename to runtime/engine/common/utils/file_utils.cc index c42a642c..385f2b65 100644 --- a/speechx/speechx/utils/file_utils.cc +++ b/runtime/engine/common/utils/file_utils.cc @@ -14,6 +14,8 @@ #include "utils/file_utils.h" +#include + namespace ppspeech { bool ReadFileToVector(const std::string& filename, @@ -40,4 +42,31 @@ std::string ReadFile2String(const std::string& path) { return std::string((std::istreambuf_iterator(input_file)), std::istreambuf_iterator()); } + +bool FileExists(const std::string& strFilename) { + // this funciton if from: + // https://github.com/kaldi-asr/kaldi/blob/master/src/fstext/deterministic-fst-test.cc + struct stat stFileInfo; + bool blnReturn; + int intStat; + + // Attempt to get the file attributes + intStat = stat(strFilename.c_str(), &stFileInfo); + if (intStat == 0) { + // We were able to get the file attributes + // so the file obviously exists. + blnReturn = true; + } else { + // We were not able to get the file attributes. + // This may mean that we don't have permission to + // access the folder which contains this file. If you + // need to do that level of checking, lookup the + // return values of stat which will give you + // more details on why stat failed. 
+ blnReturn = false; + } + + return blnReturn; +} + } // namespace ppspeech diff --git a/speechx/speechx/utils/file_utils.h b/runtime/engine/common/utils/file_utils.h similarity index 94% rename from speechx/speechx/utils/file_utils.h rename to runtime/engine/common/utils/file_utils.h index a471e024..420740db 100644 --- a/speechx/speechx/utils/file_utils.h +++ b/runtime/engine/common/utils/file_utils.h @@ -20,4 +20,7 @@ bool ReadFileToVector(const std::string& filename, std::vector* data); std::string ReadFile2String(const std::string& path); + +bool FileExists(const std::string& filename); + } // namespace ppspeech diff --git a/speechx/speechx/utils/math.cc b/runtime/engine/common/utils/math.cc similarity index 97% rename from speechx/speechx/utils/math.cc rename to runtime/engine/common/utils/math.cc index 71656cb3..1f0c9c93 100644 --- a/speechx/speechx/utils/math.cc +++ b/runtime/engine/common/utils/math.cc @@ -15,13 +15,14 @@ // limitations under the License. #include "utils/math.h" +#include "base/basic_types.h" #include #include #include +#include #include - -#include "base/common.h" +#include namespace ppspeech { diff --git a/speechx/speechx/utils/math.h b/runtime/engine/common/utils/math.h similarity index 100% rename from speechx/speechx/utils/math.h rename to runtime/engine/common/utils/math.h diff --git a/runtime/engine/common/utils/picojson.h b/runtime/engine/common/utils/picojson.h new file mode 100644 index 00000000..2ac265f5 --- /dev/null +++ b/runtime/engine/common/utils/picojson.h @@ -0,0 +1,1230 @@ +/* + * Copyright 2009-2010 Cybozu Labs, Inc. + * Copyright 2011-2014 Kazuho Oku + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ +#ifndef picojson_h +#define picojson_h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define PICOJSON_USE_INT64 1 + +// for isnan/isinf +#if __cplusplus >= 201103L +#include +#else +extern "C" { +#ifdef _MSC_VER +#include +#elif defined(__INTEL_COMPILER) +#include +#else +#include +#endif +} +#endif + +#ifndef PICOJSON_USE_RVALUE_REFERENCE +#if (defined(__cpp_rvalue_references) && __cpp_rvalue_references >= 200610) || \ + (defined(_MSC_VER) && _MSC_VER >= 1600) +#define PICOJSON_USE_RVALUE_REFERENCE 1 +#else +#define PICOJSON_USE_RVALUE_REFERENCE 0 +#endif +#endif // PICOJSON_USE_RVALUE_REFERENCE + +#ifndef PICOJSON_NOEXCEPT +#if PICOJSON_USE_RVALUE_REFERENCE +#define PICOJSON_NOEXCEPT noexcept +#else +#define PICOJSON_NOEXCEPT throw() +#endif +#endif + +// experimental support for int64_t (see README.mkdn for detail) +#ifdef PICOJSON_USE_INT64 +#define __STDC_FORMAT_MACROS +#include +#if __cplusplus >= 201103L +#include +#else +extern "C" { +#include +} +#endif +#endif + +// to disable the use of localeconv(3), set PICOJSON_USE_LOCALE to 0 +#ifndef PICOJSON_USE_LOCALE +#define PICOJSON_USE_LOCALE 1 +#endif +#if PICOJSON_USE_LOCALE +extern "C" { +#include +} +#endif + +#ifndef PICOJSON_ASSERT +#define PICOJSON_ASSERT(e) \ + do { \ + if (!(e)) throw std::runtime_error(#e); \ + } while (0) +#endif + +#ifdef _MSC_VER +#define SNPRINTF _snprintf_s +#pragma warning(push) +#pragma warning(disable : 4244) // conversion from int to char +#pragma warning(disable : 4127) // conditional expression is constant +#pragma warning(disable : 4702) // unreachable code +#pragma warning(disable : 4706) // assignment within conditional expression +#else +#define SNPRINTF snprintf +#endif + +namespace picojson { + +enum { + null_type, + boolean_type, + number_type, + string_type, + array_type, + object_type +#ifdef PICOJSON_USE_INT64 + , + int64_type +#endif +}; + +enum { INDENT_WIDTH = 2, 
DEFAULT_MAX_DEPTHS = 100 }; + +struct null {}; + +class value { + public: + typedef std::vector array; + typedef std::map object; + union _storage { + bool boolean_; + double number_; +#ifdef PICOJSON_USE_INT64 + int64_t int64_; +#endif + std::string *string_; + array *array_; + object *object_; + }; + + protected: + int type_; + _storage u_; + + public: + value(); + value(int type, bool); + explicit value(bool b); +#ifdef PICOJSON_USE_INT64 + explicit value(int64_t i); +#endif + explicit value(double n); + explicit value(const std::string &s); + explicit value(const array &a); + explicit value(const object &o); +#if PICOJSON_USE_RVALUE_REFERENCE + explicit value(std::string &&s); + explicit value(array &&a); + explicit value(object &&o); +#endif + explicit value(const char *s); + value(const char *s, size_t len); + ~value(); + value(const value &x); + value &operator=(const value &x); +#if PICOJSON_USE_RVALUE_REFERENCE + value(value &&x) PICOJSON_NOEXCEPT; + value &operator=(value &&x) PICOJSON_NOEXCEPT; +#endif + void swap(value &x) PICOJSON_NOEXCEPT; + template + bool is() const; + template + const T &get() const; + template + T &get(); + template + void set(const T &); +#if PICOJSON_USE_RVALUE_REFERENCE + template + void set(T &&); +#endif + bool evaluate_as_boolean() const; + const value &get(const size_t idx) const; + const value &get(const std::string &key) const; + value &get(const size_t idx); + value &get(const std::string &key); + + bool contains(const size_t idx) const; + bool contains(const std::string &key) const; + std::string to_str() const; + template + void serialize(Iter os, bool prettify = false) const; + std::string serialize(bool prettify = false) const; + + private: + template + value(const T *); // intentionally defined to block implicit conversion of + // pointer to bool + template + static void _indent(Iter os, int indent); + template + void _serialize(Iter os, int indent) const; + std::string _serialize(int indent) const; + void clear(); 
+}; + +typedef value::array array; +typedef value::object object; + +inline value::value() : type_(null_type), u_() {} + +inline value::value(int type, bool) : type_(type), u_() { + switch (type) { +#define INIT(p, v) \ + case p##type: \ + u_.p = v; \ + break + INIT(boolean_, false); + INIT(number_, 0.0); +#ifdef PICOJSON_USE_INT64 + INIT(int64_, 0); +#endif + INIT(string_, new std::string()); + INIT(array_, new array()); + INIT(object_, new object()); +#undef INIT + default: + break; + } +} + +inline value::value(bool b) : type_(boolean_type), u_() { u_.boolean_ = b; } + +#ifdef PICOJSON_USE_INT64 +inline value::value(int64_t i) : type_(int64_type), u_() { u_.int64_ = i; } +#endif + +inline value::value(double n) : type_(number_type), u_() { + if ( +#ifdef _MSC_VER + !_finite(n) +#elif __cplusplus >= 201103L + std::isnan(n) || std::isinf(n) +#else + isnan(n) || isinf(n) +#endif + ) { + throw std::overflow_error(""); + } + u_.number_ = n; +} + +inline value::value(const std::string &s) : type_(string_type), u_() { + u_.string_ = new std::string(s); +} + +inline value::value(const array &a) : type_(array_type), u_() { + u_.array_ = new array(a); +} + +inline value::value(const object &o) : type_(object_type), u_() { + u_.object_ = new object(o); +} + +#if PICOJSON_USE_RVALUE_REFERENCE +inline value::value(std::string &&s) : type_(string_type), u_() { + u_.string_ = new std::string(std::move(s)); +} + +inline value::value(array &&a) : type_(array_type), u_() { + u_.array_ = new array(std::move(a)); +} + +inline value::value(object &&o) : type_(object_type), u_() { + u_.object_ = new object(std::move(o)); +} +#endif + +inline value::value(const char *s) : type_(string_type), u_() { + u_.string_ = new std::string(s); +} + +inline value::value(const char *s, size_t len) : type_(string_type), u_() { + u_.string_ = new std::string(s, len); +} + +inline void value::clear() { + switch (type_) { +#define DEINIT(p) \ + case p##type: \ + delete u_.p; \ + break + 
DEINIT(string_); + DEINIT(array_); + DEINIT(object_); +#undef DEINIT + default: + break; + } +} + +inline value::~value() { clear(); } + +inline value::value(const value &x) : type_(x.type_), u_() { + switch (type_) { +#define INIT(p, v) \ + case p##type: \ + u_.p = v; \ + break + INIT(string_, new std::string(*x.u_.string_)); + INIT(array_, new array(*x.u_.array_)); + INIT(object_, new object(*x.u_.object_)); +#undef INIT + default: + u_ = x.u_; + break; + } +} + +inline value &value::operator=(const value &x) { + if (this != &x) { + value t(x); + swap(t); + } + return *this; +} + +#if PICOJSON_USE_RVALUE_REFERENCE +inline value::value(value &&x) PICOJSON_NOEXCEPT : type_(null_type), u_() { + swap(x); +} +inline value &value::operator=(value &&x) PICOJSON_NOEXCEPT { + swap(x); + return *this; +} +#endif +inline void value::swap(value &x) PICOJSON_NOEXCEPT { + std::swap(type_, x.type_); + std::swap(u_, x.u_); +} + +#define IS(ctype, jtype) \ + template <> \ + inline bool value::is() const { \ + return type_ == jtype##_type; \ + } +IS(null, null) +IS(bool, boolean) +#ifdef PICOJSON_USE_INT64 +IS(int64_t, int64) +#endif +IS(std::string, string) +IS(array, array) +IS(object, object) +#undef IS +template <> +inline bool value::is() const { + return type_ == number_type +#ifdef PICOJSON_USE_INT64 + || type_ == int64_type +#endif + ; +} + +#define GET(ctype, var) \ + template <> \ + inline const ctype &value::get() const { \ + PICOJSON_ASSERT("type mismatch! call is() before get()" && \ + is()); \ + return var; \ + } \ + template <> \ + inline ctype &value::get() { \ + PICOJSON_ASSERT("type mismatch! 
call is() before get()" && \ + is()); \ + return var; \ + } +GET(bool, u_.boolean_) +GET(std::string, *u_.string_) +GET(array, *u_.array_) +GET(object, *u_.object_) +#ifdef PICOJSON_USE_INT64 +GET(double, + (type_ == int64_type && + (const_cast(this)->type_ = number_type, + (const_cast(this)->u_.number_ = u_.int64_)), + u_.number_)) +GET(int64_t, u_.int64_) +#else +GET(double, u_.number_) +#endif +#undef GET + +#define SET(ctype, jtype, setter) \ + template <> \ + inline void value::set(const ctype &_val) { \ + clear(); \ + type_ = jtype##_type; \ + setter \ + } +SET(bool, boolean, u_.boolean_ = _val;) +SET(std::string, string, u_.string_ = new std::string(_val);) +SET(array, array, u_.array_ = new array(_val);) +SET(object, object, u_.object_ = new object(_val);) +SET(double, number, u_.number_ = _val;) +#ifdef PICOJSON_USE_INT64 +SET(int64_t, int64, u_.int64_ = _val;) +#endif +#undef SET + +#if PICOJSON_USE_RVALUE_REFERENCE +#define MOVESET(ctype, jtype, setter) \ + template <> \ + inline void value::set(ctype && _val) { \ + clear(); \ + type_ = jtype##_type; \ + setter \ + } +MOVESET(std::string, string, u_.string_ = new std::string(std::move(_val));) +MOVESET(array, array, u_.array_ = new array(std::move(_val));) +MOVESET(object, object, u_.object_ = new object(std::move(_val));) +#undef MOVESET +#endif + +inline bool value::evaluate_as_boolean() const { + switch (type_) { + case null_type: + return false; + case boolean_type: + return u_.boolean_; + case number_type: + return u_.number_ != 0; +#ifdef PICOJSON_USE_INT64 + case int64_type: + return u_.int64_ != 0; +#endif + case string_type: + return !u_.string_->empty(); + default: + return true; + } +} + +inline const value &value::get(const size_t idx) const { + static value s_null; + PICOJSON_ASSERT(is()); + return idx < u_.array_->size() ? (*u_.array_)[idx] : s_null; +} + +inline value &value::get(const size_t idx) { + static value s_null; + PICOJSON_ASSERT(is()); + return idx < u_.array_->size() ? 
(*u_.array_)[idx] : s_null; +} + +inline const value &value::get(const std::string &key) const { + static value s_null; + PICOJSON_ASSERT(is()); + object::const_iterator i = u_.object_->find(key); + return i != u_.object_->end() ? i->second : s_null; +} + +inline value &value::get(const std::string &key) { + static value s_null; + PICOJSON_ASSERT(is()); + object::iterator i = u_.object_->find(key); + return i != u_.object_->end() ? i->second : s_null; +} + +inline bool value::contains(const size_t idx) const { + PICOJSON_ASSERT(is()); + return idx < u_.array_->size(); +} + +inline bool value::contains(const std::string &key) const { + PICOJSON_ASSERT(is()); + object::const_iterator i = u_.object_->find(key); + return i != u_.object_->end(); +} + +inline std::string value::to_str() const { + switch (type_) { + case null_type: + return "null"; + case boolean_type: + return u_.boolean_ ? "true" : "false"; +#ifdef PICOJSON_USE_INT64 + case int64_type: { + char buf[sizeof("-9223372036854775808")]; + SNPRINTF(buf, sizeof(buf), "%" PRId64, u_.int64_); + return buf; + } +#endif + case number_type: { + char buf[256]; + double tmp; + SNPRINTF( + buf, + sizeof(buf), + fabs(u_.number_) < (1ULL << 53) && modf(u_.number_, &tmp) == 0 + ? "%.f" + : "%.17g", + u_.number_); +#if PICOJSON_USE_LOCALE + char *decimal_point = localeconv()->decimal_point; + if (strcmp(decimal_point, ".") != 0) { + size_t decimal_point_len = strlen(decimal_point); + for (char *p = buf; *p != '\0'; ++p) { + if (strncmp(p, decimal_point, decimal_point_len) == 0) { + return std::string(buf, p) + "." 
+ + (p + decimal_point_len); + } + } + } +#endif + return buf; + } + case string_type: + return *u_.string_; + case array_type: + return "array"; + case object_type: + return "object"; + default: + PICOJSON_ASSERT(0); +#ifdef _MSC_VER + __assume(0); +#endif + } + return std::string(); +} + +template +void copy(const std::string &s, Iter oi) { + std::copy(s.begin(), s.end(), oi); +} + +template +struct serialize_str_char { + Iter oi; + void operator()(char c) { + switch (c) { +#define MAP(val, sym) \ + case val: \ + copy(sym, oi); \ + break + MAP('"', "\\\""); + MAP('\\', "\\\\"); + MAP('/', "\\/"); + MAP('\b', "\\b"); + MAP('\f', "\\f"); + MAP('\n', "\\n"); + MAP('\r', "\\r"); + MAP('\t', "\\t"); +#undef MAP + default: + if (static_cast(c) < 0x20 || c == 0x7f) { + char buf[7]; + SNPRINTF(buf, sizeof(buf), "\\u%04x", c & 0xff); + copy(buf, buf + 6, oi); + } else { + *oi++ = c; + } + break; + } + } +}; + +template +void serialize_str(const std::string &s, Iter oi) { + *oi++ = '"'; + serialize_str_char process_char = {oi}; + std::for_each(s.begin(), s.end(), process_char); + *oi++ = '"'; +} + +template +void value::serialize(Iter oi, bool prettify) const { + return _serialize(oi, prettify ? 0 : -1); +} + +inline std::string value::serialize(bool prettify) const { + return _serialize(prettify ? 
0 : -1); +} + +template +void value::_indent(Iter oi, int indent) { + *oi++ = '\n'; + for (int i = 0; i < indent * INDENT_WIDTH; ++i) { + *oi++ = ' '; + } +} + +template +void value::_serialize(Iter oi, int indent) const { + switch (type_) { + case string_type: + serialize_str(*u_.string_, oi); + break; + case array_type: { + *oi++ = '['; + if (indent != -1) { + ++indent; + } + for (array::const_iterator i = u_.array_->begin(); + i != u_.array_->end(); + ++i) { + if (i != u_.array_->begin()) { + *oi++ = ','; + } + if (indent != -1) { + _indent(oi, indent); + } + i->_serialize(oi, indent); + } + if (indent != -1) { + --indent; + if (!u_.array_->empty()) { + _indent(oi, indent); + } + } + *oi++ = ']'; + break; + } + case object_type: { + *oi++ = '{'; + if (indent != -1) { + ++indent; + } + for (object::const_iterator i = u_.object_->begin(); + i != u_.object_->end(); + ++i) { + if (i != u_.object_->begin()) { + *oi++ = ','; + } + if (indent != -1) { + _indent(oi, indent); + } + serialize_str(i->first, oi); + *oi++ = ':'; + if (indent != -1) { + *oi++ = ' '; + } + i->second._serialize(oi, indent); + } + if (indent != -1) { + --indent; + if (!u_.object_->empty()) { + _indent(oi, indent); + } + } + *oi++ = '}'; + break; + } + default: + copy(to_str(), oi); + break; + } + if (indent == 0) { + *oi++ = '\n'; + } +} + +inline std::string value::_serialize(int indent) const { + std::string s; + _serialize(std::back_inserter(s), indent); + return s; +} + +template +class input { + protected: + Iter cur_, end_; + bool consumed_; + int line_; + + public: + input(const Iter &first, const Iter &last) + : cur_(first), end_(last), consumed_(false), line_(1) {} + int getc() { + if (consumed_) { + if (*cur_ == '\n') { + ++line_; + } + ++cur_; + } + if (cur_ == end_) { + consumed_ = false; + return -1; + } + consumed_ = true; + return *cur_ & 0xff; + } + void ungetc() { consumed_ = false; } + Iter cur() const { + if (consumed_) { + input *self = const_cast *>(this); + self->consumed_ 
= false; + ++self->cur_; + } + return cur_; + } + int line() const { return line_; } + void skip_ws() { + while (1) { + int ch = getc(); + if (!(ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r')) { + ungetc(); + break; + } + } + } + bool expect(const int expected) { + skip_ws(); + if (getc() != expected) { + ungetc(); + return false; + } + return true; + } + bool match(const std::string &pattern) { + for (std::string::const_iterator pi(pattern.begin()); + pi != pattern.end(); + ++pi) { + if (getc() != *pi) { + ungetc(); + return false; + } + } + return true; + } +}; + +template +inline int _parse_quadhex(input &in) { + int uni_ch = 0, hex; + for (int i = 0; i < 4; i++) { + if ((hex = in.getc()) == -1) { + return -1; + } + if ('0' <= hex && hex <= '9') { + hex -= '0'; + } else if ('A' <= hex && hex <= 'F') { + hex -= 'A' - 0xa; + } else if ('a' <= hex && hex <= 'f') { + hex -= 'a' - 0xa; + } else { + in.ungetc(); + return -1; + } + uni_ch = uni_ch * 16 + hex; + } + return uni_ch; +} + +template +inline bool _parse_codepoint(String &out, input &in) { + int uni_ch; + if ((uni_ch = _parse_quadhex(in)) == -1) { + return false; + } + if (0xd800 <= uni_ch && uni_ch <= 0xdfff) { + if (0xdc00 <= uni_ch) { + // a second 16-bit of a surrogate pair appeared + return false; + } + // first 16-bit of surrogate pair, get the next one + if (in.getc() != '\\' || in.getc() != 'u') { + in.ungetc(); + return false; + } + int second = _parse_quadhex(in); + if (!(0xdc00 <= second && second <= 0xdfff)) { + return false; + } + uni_ch = ((uni_ch - 0xd800) << 10) | ((second - 0xdc00) & 0x3ff); + uni_ch += 0x10000; + } + if (uni_ch < 0x80) { + out.push_back(static_cast(uni_ch)); + } else { + if (uni_ch < 0x800) { + out.push_back(static_cast(0xc0 | (uni_ch >> 6))); + } else { + if (uni_ch < 0x10000) { + out.push_back(static_cast(0xe0 | (uni_ch >> 12))); + } else { + out.push_back(static_cast(0xf0 | (uni_ch >> 18))); + out.push_back( + static_cast(0x80 | ((uni_ch >> 12) & 0x3f))); + } + 
out.push_back(static_cast(0x80 | ((uni_ch >> 6) & 0x3f))); + } + out.push_back(static_cast(0x80 | (uni_ch & 0x3f))); + } + return true; +} + +template +inline bool _parse_string(String &out, input &in) { + while (1) { + int ch = in.getc(); + if (ch < ' ') { + in.ungetc(); + return false; + } else if (ch == '"') { + return true; + } else if (ch == '\\') { + if ((ch = in.getc()) == -1) { + return false; + } + switch (ch) { +#define MAP(sym, val) \ + case sym: \ + out.push_back(val); \ + break + MAP('"', '\"'); + MAP('\\', '\\'); + MAP('/', '/'); + MAP('b', '\b'); + MAP('f', '\f'); + MAP('n', '\n'); + MAP('r', '\r'); + MAP('t', '\t'); +#undef MAP + case 'u': + if (!_parse_codepoint(out, in)) { + return false; + } + break; + default: + return false; + } + } else { + out.push_back(static_cast(ch)); + } + } + return false; +} + +template +inline bool _parse_array(Context &ctx, input &in) { + if (!ctx.parse_array_start()) { + return false; + } + size_t idx = 0; + if (in.expect(']')) { + return ctx.parse_array_stop(idx); + } + do { + if (!ctx.parse_array_item(in, idx)) { + return false; + } + idx++; + } while (in.expect(',')); + return in.expect(']') && ctx.parse_array_stop(idx); +} + +template +inline bool _parse_object(Context &ctx, input &in) { + if (!ctx.parse_object_start()) { + return false; + } + if (in.expect('}')) { + return ctx.parse_object_stop(); + } + do { + std::string key; + if (!in.expect('"') || !_parse_string(key, in) || !in.expect(':')) { + return false; + } + if (!ctx.parse_object_item(in, key)) { + return false; + } + } while (in.expect(',')); + return in.expect('}') && ctx.parse_object_stop(); +} + +template +inline std::string _parse_number(input &in) { + std::string num_str; + while (1) { + int ch = in.getc(); + if (('0' <= ch && ch <= '9') || ch == '+' || ch == '-' || ch == 'e' || + ch == 'E') { + num_str.push_back(static_cast(ch)); + } else if (ch == '.') { +#if PICOJSON_USE_LOCALE + num_str += localeconv()->decimal_point; +#else + 
num_str.push_back('.'); +#endif + } else { + in.ungetc(); + break; + } + } + return num_str; +} + +template +inline bool _parse(Context &ctx, input &in) { + in.skip_ws(); + int ch = in.getc(); + switch (ch) { +#define IS(ch, text, op) \ + case ch: \ + if (in.match(text) && op) { \ + return true; \ + } else { \ + return false; \ + } + IS('n', "ull", ctx.set_null()); + IS('f', "alse", ctx.set_bool(false)); + IS('t', "rue", ctx.set_bool(true)); +#undef IS + case '"': + return ctx.parse_string(in); + case '[': + return _parse_array(ctx, in); + case '{': + return _parse_object(ctx, in); + default: + if (('0' <= ch && ch <= '9') || ch == '-') { + double f; + char *endp; + in.ungetc(); + std::string num_str(_parse_number(in)); + if (num_str.empty()) { + return false; + } +#ifdef PICOJSON_USE_INT64 + { + errno = 0; + intmax_t ival = strtoimax(num_str.c_str(), &endp, 10); + if (errno == 0 && + std::numeric_limits::min() <= ival && + ival <= std::numeric_limits::max() && + endp == num_str.c_str() + num_str.size()) { + ctx.set_int64(ival); + return true; + } + } +#endif + f = strtod(num_str.c_str(), &endp); + if (endp == num_str.c_str() + num_str.size()) { + ctx.set_number(f); + return true; + } + return false; + } + break; + } + in.ungetc(); + return false; +} + +class deny_parse_context { + public: + bool set_null() { return false; } + bool set_bool(bool) { return false; } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t) { return false; } +#endif + bool set_number(double) { return false; } + template + bool parse_string(input &) { + return false; + } + bool parse_array_start() { return false; } + template + bool parse_array_item(input &, size_t) { + return false; + } + bool parse_array_stop(size_t) { return false; } + bool parse_object_start() { return false; } + template + bool parse_object_item(input &, const std::string &) { + return false; + } +}; + +class default_parse_context { + protected: + value *out_; + size_t depths_; + + public: + default_parse_context(value 
*out, size_t depths = DEFAULT_MAX_DEPTHS) + : out_(out), depths_(depths) {} + bool set_null() { + *out_ = value(); + return true; + } + bool set_bool(bool b) { + *out_ = value(b); + return true; + } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t i) { + *out_ = value(i); + return true; + } +#endif + bool set_number(double f) { + *out_ = value(f); + return true; + } + template + bool parse_string(input &in) { + *out_ = value(string_type, false); + return _parse_string(out_->get(), in); + } + bool parse_array_start() { + if (depths_ == 0) return false; + --depths_; + *out_ = value(array_type, false); + return true; + } + template + bool parse_array_item(input &in, size_t) { + array &a = out_->get(); + a.push_back(value()); + default_parse_context ctx(&a.back(), depths_); + return _parse(ctx, in); + } + bool parse_array_stop(size_t) { + ++depths_; + return true; + } + bool parse_object_start() { + if (depths_ == 0) return false; + *out_ = value(object_type, false); + return true; + } + template + bool parse_object_item(input &in, const std::string &key) { + object &o = out_->get(); + default_parse_context ctx(&o[key], depths_); + return _parse(ctx, in); + } + bool parse_object_stop() { + ++depths_; + return true; + } + + private: + default_parse_context(const default_parse_context &); + default_parse_context &operator=(const default_parse_context &); +}; + +class null_parse_context { + protected: + size_t depths_; + + public: + struct dummy_str { + void push_back(int) {} + }; + + public: + null_parse_context(size_t depths = DEFAULT_MAX_DEPTHS) : depths_(depths) {} + bool set_null() { return true; } + bool set_bool(bool) { return true; } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t) { return true; } +#endif + bool set_number(double) { return true; } + template + bool parse_string(input &in) { + dummy_str s; + return _parse_string(s, in); + } + bool parse_array_start() { + if (depths_ == 0) return false; + --depths_; + return true; + } + template + bool 
parse_array_item(input &in, size_t) { + return _parse(*this, in); + } + bool parse_array_stop(size_t) { + ++depths_; + return true; + } + bool parse_object_start() { + if (depths_ == 0) return false; + --depths_; + return true; + } + template + bool parse_object_item(input &in, const std::string &) { + ++depths_; + return _parse(*this, in); + } + bool parse_object_stop() { return true; } + + private: + null_parse_context(const null_parse_context &); + null_parse_context &operator=(const null_parse_context &); +}; + +// obsolete, use the version below +template +inline std::string parse(value &out, Iter &pos, const Iter &last) { + std::string err; + pos = parse(out, pos, last, &err); + return err; +} + +template +inline Iter _parse(Context &ctx, + const Iter &first, + const Iter &last, + std::string *err) { + input in(first, last); + if (!_parse(ctx, in) && err != NULL) { + char buf[64]; + SNPRINTF(buf, sizeof(buf), "syntax error at line %d near: ", in.line()); + *err = buf; + while (1) { + int ch = in.getc(); + if (ch == -1 || ch == '\n') { + break; + } else if (ch >= ' ') { + err->push_back(static_cast(ch)); + } + } + } + return in.cur(); +} + +template +inline Iter parse(value &out, + const Iter &first, + const Iter &last, + std::string *err) { + default_parse_context ctx(&out); + return _parse(ctx, first, last, err); +} + +inline std::string parse(value &out, const std::string &s) { + std::string err; + parse(out, s.begin(), s.end(), &err); + return err; +} + +inline std::string parse(value &out, std::istream &is) { + std::string err; + parse(out, + std::istreambuf_iterator(is.rdbuf()), + std::istreambuf_iterator(), + &err); + return err; +} + +template +struct last_error_t { + static std::string s; +}; +template +std::string last_error_t::s; + +inline void set_last_error(const std::string &s) { last_error_t::s = s; } + +inline const std::string &get_last_error() { return last_error_t::s; } + +inline bool operator==(const value &x, const value &y) { + if 
(x.is()) return y.is(); +#define PICOJSON_CMP(type) \ + if (x.is()) return y.is() && x.get() == y.get() + PICOJSON_CMP(bool); + PICOJSON_CMP(double); + PICOJSON_CMP(std::string); + PICOJSON_CMP(array); + PICOJSON_CMP(object); +#undef PICOJSON_CMP + PICOJSON_ASSERT(0); +#ifdef _MSC_VER + __assume(0); +#endif + return false; +} + +inline bool operator!=(const value &x, const value &y) { return !(x == y); } +} + +#if !PICOJSON_USE_RVALUE_REFERENCE +namespace std { +template <> +inline void swap(picojson::value &x, picojson::value &y) { + x.swap(y); +} +} +#endif + +inline std::istream &operator>>(std::istream &is, picojson::value &x) { + picojson::set_last_error(std::string()); + const std::string err(picojson::parse(x, is)); + if (!err.empty()) { + picojson::set_last_error(err); + is.setstate(std::ios::failbit); + } + return is; +} + +inline std::ostream &operator<<(std::ostream &os, const picojson::value &x) { + x.serialize(std::ostream_iterator(os)); + return os; +} +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +#endif \ No newline at end of file diff --git a/runtime/engine/common/utils/strings.cc b/runtime/engine/common/utils/strings.cc new file mode 100644 index 00000000..91954d64 --- /dev/null +++ b/runtime/engine/common/utils/strings.cc @@ -0,0 +1,133 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "utils/strings.h" + +namespace ppspeech { + +std::vector StrSplit(const std::string& str, + const char* delim, + bool omit_empty_string) { + std::vector outs; + int start = 0; + int end = str.size(); + int found = 0; + while (found != std::string::npos) { + found = str.find_first_of(delim, start); + // start != end condition is for when the delimiter is at the end + if (!omit_empty_string || (found != start && start != end)) { + outs.push_back(str.substr(start, found - start)); + } + start = found + 1; + } + + return outs; +} + + +std::string StrJoin(const std::vector& strs, const char* delim) { + std::stringstream ss; + for (ssize_t i = 0; i < strs.size(); ++i) { + ss << strs[i]; + if (i < strs.size() - 1) { + ss << std::string(delim); + } + } + return ss.str(); +} + +std::string DelBlank(const std::string& str) { + std::string out = ""; + int ptr_in = 0; // the pointer of input string (for traversal) + int end = str.size(); + int ptr_out = -1; // the pointer of output string (last char) + while (ptr_in != end) { + while (ptr_in != end && str[ptr_in] == ' ') { + ptr_in += 1; + } + if (ptr_in == end) + return out; + if (ptr_out != -1 && isalpha(str[ptr_in]) && isalpha(str[ptr_out]) && str[ptr_in-1] == ' ') + // add a space when the last and current chars are in English and there have space(s) between them + out += ' '; + out += str[ptr_in]; + ptr_out = ptr_in; + ptr_in += 1; + } + return out; +} + +std::string AddBlank(const std::string& str) { + std::string out = ""; + int ptr = 0; // the pointer of the input string + int end = str.size(); + while (ptr != end) { + if (isalpha(str[ptr])) { + if (ptr == 0 or str[ptr-1] != ' ') + out += " "; // add pre-space for an English word + while (isalpha(str[ptr])) { + out += str[ptr]; + ptr += 1; + } + out += " "; // add post-space for an English word + } else { + out += str[ptr]; + ptr += 1; + } + } + return out; +} + +std::string ReverseFraction(const std::string& str) { + std::string out = ""; + int 
ptr = 0; // the pointer of the input string + int end = str.size(); + int left, right, frac; // the start index of the left tag, right tag and '/'. + left = right = frac = 0; + int len_tag = 5; // length of "" + + while (ptr != end) { + // find the position of left tag, right tag and '/'. (xxxnum1/num2) + left = str.find("", ptr); + if (left == -1) + break; + out += str.substr(ptr, left - ptr); // content before left tag (xxx) + frac = str.find("/", left); + right = str.find("", frac); + + out += str.substr(frac + 1, right - frac - 1) + '/' + + str.substr(left + len_tag, frac - left - len_tag); // num2/num1 + ptr = right + len_tag; + } + if (ptr != end) { + out += str.substr(ptr, end - ptr); + } + return out; +} + +#ifdef _MSC_VER +std::wstring ToWString(const std::string& str) { + unsigned len = str.size() * 2; + setlocale(LC_CTYPE, ""); + wchar_t* p = new wchar_t[len]; + mbstowcs(p, str.c_str(), len); + std::wstring wstr(p); + delete[] p; + return wstr; +} +#endif + +} // namespace ppspeech diff --git a/runtime/engine/common/utils/strings.h b/runtime/engine/common/utils/strings.h new file mode 100644 index 00000000..cd79ae4f --- /dev/null +++ b/runtime/engine/common/utils/strings.h @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +namespace ppspeech { + +std::vector StrSplit(const std::string& str, + const char* delim, + bool omit_empty_string = true); + +std::string StrJoin(const std::vector& strs, const char* delim); + +std::string DelBlank(const std::string& str); + +std::string AddBlank(const std::string& str); + +std::string ReverseFraction(const std::string& str); + +#ifdef _MSC_VER +std::wstring ToWString(const std::string& str); +#endif + +} // namespace ppspeech diff --git a/runtime/engine/common/utils/strings_test.cc b/runtime/engine/common/utils/strings_test.cc new file mode 100644 index 00000000..058b6a01 --- /dev/null +++ b/runtime/engine/common/utils/strings_test.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "utils/strings.h" + +#include +#include + + +TEST(StringTest, StrSplitTest) { + using ::testing::ElementsAre; + + std::string test_str = "hello world"; + std::vector outs = ppspeech::StrSplit(test_str, " \t"); + EXPECT_THAT(outs, ElementsAre("hello", "world")); +} + + +TEST(StringTest, StrJoinTest) { + std::vector ins{"hello", "world"}; + std::string out = ppspeech::StrJoin(ins, " "); + EXPECT_THAT(out, "hello world"); +} + +TEST(StringText, DelBlankTest) { + std::string test_str = "我 今天 去 了 超市 花了 120 元。"; + std::string out_str = ppspeech::DelBlank(test_str); + int ret = out_str.compare("我今天去了超市花了120元。"); + EXPECT_EQ(ret, 0); + + test_str = "how are you today"; + out_str = ppspeech::DelBlank(test_str); + ret = out_str.compare("how are you today"); + EXPECT_EQ(ret, 0); + + test_str = "我 的 paper 在 哪里?"; + out_str = ppspeech::DelBlank(test_str); + ret = out_str.compare("我的paper在哪里?"); + EXPECT_EQ(ret, 0); +} + +TEST(StringTest, AddBlankTest) { + std::string test_str = "how are you"; + std::string out_str = ppspeech::AddBlank(test_str); + int ret = out_str.compare(" how are you "); + EXPECT_EQ(ret, 0); + + test_str = "欢迎来到China。"; + out_str = ppspeech::AddBlank(test_str); + ret = out_str.compare("欢迎来到 China 。"); + EXPECT_EQ(ret, 0); +} + +TEST(StringTest, ReverseFractionTest) { + std::string test_str = "3/1"; + std::string out_str = ppspeech::ReverseFraction(test_str); + int ret = out_str.compare("1/3"); + std::cout< + +#include "common/utils/timer.h" + +namespace ppspeech{ + +struct TimerImpl{ + TimerImpl() = default; + virtual ~TimerImpl() = default; + virtual void Reset() = 0; + // time in seconds + virtual double Elapsed() = 0; +}; + +class CpuTimerImpl : public TimerImpl { + public: + CpuTimerImpl() { Reset(); } + + using high_resolution_clock = std::chrono::high_resolution_clock; + + void Reset() override { begin_ = high_resolution_clock::now(); } + + // time in seconds + double Elapsed() override { + auto end = high_resolution_clock::now(); + auto 
dur = + std::chrono::duration_cast(end - begin_); + return dur.count() / 1000000.0; + } + + private: + high_resolution_clock::time_point begin_; +}; + +Timer::Timer() { + impl_ = std::make_unique(); +} + +Timer::~Timer() = default; + +void Timer::Reset() const { impl_->Reset(); } + +double Timer::Elapsed() const { return impl_->Elapsed(); } + + +} //namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/common/utils/timer.h b/runtime/engine/common/utils/timer.h new file mode 100644 index 00000000..6f4ae1f8 --- /dev/null +++ b/runtime/engine/common/utils/timer.h @@ -0,0 +1,39 @@ +// Copyright 2020 Xiaomi Corporation (authors: Haowen Qiu) +// Mobvoi Inc. (authors: Fangjun Kuang) +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +namespace ppspeech { + +struct TimerImpl; + +class Timer { + public: + Timer(); + ~Timer(); + + void Reset() const; + + // time in seconds + double Elapsed() const; + + private: + std::unique_ptr impl_; +}; + +} //namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/kaldi/CMakeLists.txt b/runtime/engine/kaldi/CMakeLists.txt new file mode 100644 index 00000000..e55cecbb --- /dev/null +++ b/runtime/engine/kaldi/CMakeLists.txt @@ -0,0 +1,15 @@ +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +) + +add_subdirectory(base) +add_subdirectory(util) +if(WITH_ASR) + add_subdirectory(lat) + add_subdirectory(fstext) + add_subdirectory(decoder) + add_subdirectory(lm) + + add_subdirectory(fstbin) + add_subdirectory(lmbin) +endif() diff --git a/speechx/speechx/kaldi/base/CMakeLists.txt b/runtime/engine/kaldi/base/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/base/CMakeLists.txt rename to runtime/engine/kaldi/base/CMakeLists.txt diff --git a/speechx/speechx/kaldi/base/io-funcs-inl.h b/runtime/engine/kaldi/base/io-funcs-inl.h similarity index 100% rename from speechx/speechx/kaldi/base/io-funcs-inl.h rename to runtime/engine/kaldi/base/io-funcs-inl.h diff --git a/speechx/speechx/kaldi/base/io-funcs.cc b/runtime/engine/kaldi/base/io-funcs.cc similarity index 100% rename from speechx/speechx/kaldi/base/io-funcs.cc rename to runtime/engine/kaldi/base/io-funcs.cc diff --git a/speechx/speechx/kaldi/base/io-funcs.h b/runtime/engine/kaldi/base/io-funcs.h similarity index 100% rename from speechx/speechx/kaldi/base/io-funcs.h rename to runtime/engine/kaldi/base/io-funcs.h diff --git a/speechx/speechx/kaldi/base/kaldi-common.h b/runtime/engine/kaldi/base/kaldi-common.h similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-common.h rename to runtime/engine/kaldi/base/kaldi-common.h diff --git a/speechx/speechx/kaldi/base/kaldi-error.cc b/runtime/engine/kaldi/base/kaldi-error.cc similarity index 100% 
rename from speechx/speechx/kaldi/base/kaldi-error.cc rename to runtime/engine/kaldi/base/kaldi-error.cc diff --git a/speechx/speechx/kaldi/base/kaldi-error.h b/runtime/engine/kaldi/base/kaldi-error.h similarity index 99% rename from speechx/speechx/kaldi/base/kaldi-error.h rename to runtime/engine/kaldi/base/kaldi-error.h index a9904a75..98bef74f 100644 --- a/speechx/speechx/kaldi/base/kaldi-error.h +++ b/runtime/engine/kaldi/base/kaldi-error.h @@ -181,7 +181,7 @@ private: // Also see KALDI_COMPILE_TIME_ASSERT, defined in base/kaldi-utils.h, and // KALDI_ASSERT_IS_INTEGER_TYPE and KALDI_ASSERT_IS_FLOATING_TYPE, also defined // there. -#ifndef NDEBUG +#ifdef PPS_DEBUG #define KALDI_ASSERT(cond) \ do { \ if (cond) \ diff --git a/speechx/speechx/kaldi/base/kaldi-math.cc b/runtime/engine/kaldi/base/kaldi-math.cc similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-math.cc rename to runtime/engine/kaldi/base/kaldi-math.cc diff --git a/speechx/speechx/kaldi/base/kaldi-math.h b/runtime/engine/kaldi/base/kaldi-math.h similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-math.h rename to runtime/engine/kaldi/base/kaldi-math.h diff --git a/speechx/speechx/kaldi/base/kaldi-types.h b/runtime/engine/kaldi/base/kaldi-types.h similarity index 90% rename from speechx/speechx/kaldi/base/kaldi-types.h rename to runtime/engine/kaldi/base/kaldi-types.h index 07381cf2..bf8a2722 100644 --- a/speechx/speechx/kaldi/base/kaldi-types.h +++ b/runtime/engine/kaldi/base/kaldi-types.h @@ -44,7 +44,19 @@ typedef float BaseFloat; #ifndef COMPILE_WITHOUT_OPENFST +#ifdef WITH_ASR #include +#else +using int8 = int8_t; +using int16 = int16_t; +using int32 = int32_t; +using int64 = int64_t; + +using uint8 = uint8_t; +using uint16 = uint16_t; +using uint32 = uint32_t; +using uint64 = uint64_t; +#endif namespace kaldi { using ::int16; diff --git a/speechx/speechx/kaldi/base/kaldi-utils.cc b/runtime/engine/kaldi/base/kaldi-utils.cc similarity index 100% rename from 
speechx/speechx/kaldi/base/kaldi-utils.cc rename to runtime/engine/kaldi/base/kaldi-utils.cc diff --git a/speechx/speechx/kaldi/base/kaldi-utils.h b/runtime/engine/kaldi/base/kaldi-utils.h similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-utils.h rename to runtime/engine/kaldi/base/kaldi-utils.h diff --git a/speechx/speechx/kaldi/base/timer.cc b/runtime/engine/kaldi/base/timer.cc similarity index 100% rename from speechx/speechx/kaldi/base/timer.cc rename to runtime/engine/kaldi/base/timer.cc diff --git a/speechx/speechx/kaldi/base/timer.h b/runtime/engine/kaldi/base/timer.h similarity index 100% rename from speechx/speechx/kaldi/base/timer.h rename to runtime/engine/kaldi/base/timer.h diff --git a/speechx/speechx/kaldi/base/version.h b/runtime/engine/kaldi/base/version.h similarity index 100% rename from speechx/speechx/kaldi/base/version.h rename to runtime/engine/kaldi/base/version.h diff --git a/speechx/speechx/kaldi/decoder/CMakeLists.txt b/runtime/engine/kaldi/decoder/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/decoder/CMakeLists.txt rename to runtime/engine/kaldi/decoder/CMakeLists.txt diff --git a/speechx/speechx/kaldi/decoder/decodable-itf.h b/runtime/engine/kaldi/decoder/decodable-itf.h similarity index 100% rename from speechx/speechx/kaldi/decoder/decodable-itf.h rename to runtime/engine/kaldi/decoder/decodable-itf.h diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc b/runtime/engine/kaldi/decoder/lattice-faster-decoder.cc similarity index 100% rename from speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc rename to runtime/engine/kaldi/decoder/lattice-faster-decoder.cc diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h b/runtime/engine/kaldi/decoder/lattice-faster-decoder.h similarity index 100% rename from speechx/speechx/kaldi/decoder/lattice-faster-decoder.h rename to runtime/engine/kaldi/decoder/lattice-faster-decoder.h diff --git 
a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc b/runtime/engine/kaldi/decoder/lattice-faster-online-decoder.cc similarity index 100% rename from speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc rename to runtime/engine/kaldi/decoder/lattice-faster-online-decoder.cc diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h b/runtime/engine/kaldi/decoder/lattice-faster-online-decoder.h similarity index 100% rename from speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h rename to runtime/engine/kaldi/decoder/lattice-faster-online-decoder.h diff --git a/speechx/speechx/kaldi/fstbin/CMakeLists.txt b/runtime/engine/kaldi/fstbin/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/fstbin/CMakeLists.txt rename to runtime/engine/kaldi/fstbin/CMakeLists.txt diff --git a/speechx/speechx/kaldi/fstbin/fstaddselfloops.cc b/runtime/engine/kaldi/fstbin/fstaddselfloops.cc similarity index 100% rename from speechx/speechx/kaldi/fstbin/fstaddselfloops.cc rename to runtime/engine/kaldi/fstbin/fstaddselfloops.cc diff --git a/speechx/speechx/kaldi/fstbin/fstdeterminizestar.cc b/runtime/engine/kaldi/fstbin/fstdeterminizestar.cc similarity index 100% rename from speechx/speechx/kaldi/fstbin/fstdeterminizestar.cc rename to runtime/engine/kaldi/fstbin/fstdeterminizestar.cc diff --git a/speechx/speechx/kaldi/fstbin/fstisstochastic.cc b/runtime/engine/kaldi/fstbin/fstisstochastic.cc similarity index 100% rename from speechx/speechx/kaldi/fstbin/fstisstochastic.cc rename to runtime/engine/kaldi/fstbin/fstisstochastic.cc diff --git a/speechx/speechx/kaldi/fstbin/fstminimizeencoded.cc b/runtime/engine/kaldi/fstbin/fstminimizeencoded.cc similarity index 100% rename from speechx/speechx/kaldi/fstbin/fstminimizeencoded.cc rename to runtime/engine/kaldi/fstbin/fstminimizeencoded.cc diff --git a/speechx/speechx/kaldi/fstbin/fsttablecompose.cc b/runtime/engine/kaldi/fstbin/fsttablecompose.cc similarity index 100% rename from 
speechx/speechx/kaldi/fstbin/fsttablecompose.cc rename to runtime/engine/kaldi/fstbin/fsttablecompose.cc diff --git a/speechx/speechx/kaldi/fstext/CMakeLists.txt b/runtime/engine/kaldi/fstext/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/fstext/CMakeLists.txt rename to runtime/engine/kaldi/fstext/CMakeLists.txt diff --git a/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h b/runtime/engine/kaldi/fstext/determinize-lattice-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/determinize-lattice-inl.h rename to runtime/engine/kaldi/fstext/determinize-lattice-inl.h diff --git a/speechx/speechx/kaldi/fstext/determinize-lattice.h b/runtime/engine/kaldi/fstext/determinize-lattice.h similarity index 100% rename from speechx/speechx/kaldi/fstext/determinize-lattice.h rename to runtime/engine/kaldi/fstext/determinize-lattice.h diff --git a/speechx/speechx/kaldi/fstext/determinize-star-inl.h b/runtime/engine/kaldi/fstext/determinize-star-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/determinize-star-inl.h rename to runtime/engine/kaldi/fstext/determinize-star-inl.h diff --git a/speechx/speechx/kaldi/fstext/determinize-star.h b/runtime/engine/kaldi/fstext/determinize-star.h similarity index 100% rename from speechx/speechx/kaldi/fstext/determinize-star.h rename to runtime/engine/kaldi/fstext/determinize-star.h diff --git a/speechx/speechx/kaldi/fstext/fstext-lib.h b/runtime/engine/kaldi/fstext/fstext-lib.h similarity index 100% rename from speechx/speechx/kaldi/fstext/fstext-lib.h rename to runtime/engine/kaldi/fstext/fstext-lib.h diff --git a/speechx/speechx/kaldi/fstext/fstext-utils-inl.h b/runtime/engine/kaldi/fstext/fstext-utils-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/fstext-utils-inl.h rename to runtime/engine/kaldi/fstext/fstext-utils-inl.h diff --git a/speechx/speechx/kaldi/fstext/fstext-utils.h b/runtime/engine/kaldi/fstext/fstext-utils.h similarity index 100% rename from 
speechx/speechx/kaldi/fstext/fstext-utils.h rename to runtime/engine/kaldi/fstext/fstext-utils.h diff --git a/speechx/speechx/kaldi/fstext/kaldi-fst-io-inl.h b/runtime/engine/kaldi/fstext/kaldi-fst-io-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/kaldi-fst-io-inl.h rename to runtime/engine/kaldi/fstext/kaldi-fst-io-inl.h diff --git a/speechx/speechx/kaldi/fstext/kaldi-fst-io.cc b/runtime/engine/kaldi/fstext/kaldi-fst-io.cc similarity index 100% rename from speechx/speechx/kaldi/fstext/kaldi-fst-io.cc rename to runtime/engine/kaldi/fstext/kaldi-fst-io.cc diff --git a/speechx/speechx/kaldi/fstext/kaldi-fst-io.h b/runtime/engine/kaldi/fstext/kaldi-fst-io.h similarity index 100% rename from speechx/speechx/kaldi/fstext/kaldi-fst-io.h rename to runtime/engine/kaldi/fstext/kaldi-fst-io.h diff --git a/speechx/speechx/kaldi/fstext/lattice-utils-inl.h b/runtime/engine/kaldi/fstext/lattice-utils-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/lattice-utils-inl.h rename to runtime/engine/kaldi/fstext/lattice-utils-inl.h diff --git a/speechx/speechx/kaldi/fstext/lattice-utils.h b/runtime/engine/kaldi/fstext/lattice-utils.h similarity index 100% rename from speechx/speechx/kaldi/fstext/lattice-utils.h rename to runtime/engine/kaldi/fstext/lattice-utils.h diff --git a/speechx/speechx/kaldi/fstext/lattice-weight.h b/runtime/engine/kaldi/fstext/lattice-weight.h similarity index 100% rename from speechx/speechx/kaldi/fstext/lattice-weight.h rename to runtime/engine/kaldi/fstext/lattice-weight.h diff --git a/speechx/speechx/kaldi/fstext/pre-determinize-inl.h b/runtime/engine/kaldi/fstext/pre-determinize-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/pre-determinize-inl.h rename to runtime/engine/kaldi/fstext/pre-determinize-inl.h diff --git a/speechx/speechx/kaldi/fstext/pre-determinize.h b/runtime/engine/kaldi/fstext/pre-determinize.h similarity index 100% rename from speechx/speechx/kaldi/fstext/pre-determinize.h 
rename to runtime/engine/kaldi/fstext/pre-determinize.h diff --git a/speechx/speechx/kaldi/fstext/remove-eps-local-inl.h b/runtime/engine/kaldi/fstext/remove-eps-local-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/remove-eps-local-inl.h rename to runtime/engine/kaldi/fstext/remove-eps-local-inl.h diff --git a/speechx/speechx/kaldi/fstext/remove-eps-local.h b/runtime/engine/kaldi/fstext/remove-eps-local.h similarity index 100% rename from speechx/speechx/kaldi/fstext/remove-eps-local.h rename to runtime/engine/kaldi/fstext/remove-eps-local.h diff --git a/speechx/speechx/kaldi/fstext/table-matcher.h b/runtime/engine/kaldi/fstext/table-matcher.h similarity index 100% rename from speechx/speechx/kaldi/fstext/table-matcher.h rename to runtime/engine/kaldi/fstext/table-matcher.h diff --git a/speechx/speechx/kaldi/lat/CMakeLists.txt b/runtime/engine/kaldi/lat/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/lat/CMakeLists.txt rename to runtime/engine/kaldi/lat/CMakeLists.txt diff --git a/speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc b/runtime/engine/kaldi/lat/determinize-lattice-pruned.cc similarity index 100% rename from speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc rename to runtime/engine/kaldi/lat/determinize-lattice-pruned.cc diff --git a/speechx/speechx/kaldi/lat/determinize-lattice-pruned.h b/runtime/engine/kaldi/lat/determinize-lattice-pruned.h similarity index 100% rename from speechx/speechx/kaldi/lat/determinize-lattice-pruned.h rename to runtime/engine/kaldi/lat/determinize-lattice-pruned.h diff --git a/speechx/speechx/kaldi/lat/kaldi-lattice.cc b/runtime/engine/kaldi/lat/kaldi-lattice.cc similarity index 100% rename from speechx/speechx/kaldi/lat/kaldi-lattice.cc rename to runtime/engine/kaldi/lat/kaldi-lattice.cc diff --git a/speechx/speechx/kaldi/lat/kaldi-lattice.h b/runtime/engine/kaldi/lat/kaldi-lattice.h similarity index 100% rename from speechx/speechx/kaldi/lat/kaldi-lattice.h rename to 
runtime/engine/kaldi/lat/kaldi-lattice.h diff --git a/speechx/speechx/kaldi/lat/lattice-functions.cc b/runtime/engine/kaldi/lat/lattice-functions.cc similarity index 100% rename from speechx/speechx/kaldi/lat/lattice-functions.cc rename to runtime/engine/kaldi/lat/lattice-functions.cc diff --git a/speechx/speechx/kaldi/lat/lattice-functions.h b/runtime/engine/kaldi/lat/lattice-functions.h similarity index 100% rename from speechx/speechx/kaldi/lat/lattice-functions.h rename to runtime/engine/kaldi/lat/lattice-functions.h diff --git a/speechx/speechx/kaldi/lm/CMakeLists.txt b/runtime/engine/kaldi/lm/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/lm/CMakeLists.txt rename to runtime/engine/kaldi/lm/CMakeLists.txt diff --git a/speechx/speechx/kaldi/lm/arpa-file-parser.cc b/runtime/engine/kaldi/lm/arpa-file-parser.cc similarity index 100% rename from speechx/speechx/kaldi/lm/arpa-file-parser.cc rename to runtime/engine/kaldi/lm/arpa-file-parser.cc diff --git a/speechx/speechx/kaldi/lm/arpa-file-parser.h b/runtime/engine/kaldi/lm/arpa-file-parser.h similarity index 100% rename from speechx/speechx/kaldi/lm/arpa-file-parser.h rename to runtime/engine/kaldi/lm/arpa-file-parser.h diff --git a/speechx/speechx/kaldi/lm/arpa-lm-compiler.cc b/runtime/engine/kaldi/lm/arpa-lm-compiler.cc similarity index 100% rename from speechx/speechx/kaldi/lm/arpa-lm-compiler.cc rename to runtime/engine/kaldi/lm/arpa-lm-compiler.cc diff --git a/speechx/speechx/kaldi/lm/arpa-lm-compiler.h b/runtime/engine/kaldi/lm/arpa-lm-compiler.h similarity index 100% rename from speechx/speechx/kaldi/lm/arpa-lm-compiler.h rename to runtime/engine/kaldi/lm/arpa-lm-compiler.h diff --git a/speechx/speechx/kaldi/lmbin/CMakeLists.txt b/runtime/engine/kaldi/lmbin/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/lmbin/CMakeLists.txt rename to runtime/engine/kaldi/lmbin/CMakeLists.txt diff --git a/speechx/speechx/kaldi/lmbin/arpa2fst.cc 
b/runtime/engine/kaldi/lmbin/arpa2fst.cc similarity index 100% rename from speechx/speechx/kaldi/lmbin/arpa2fst.cc rename to runtime/engine/kaldi/lmbin/arpa2fst.cc diff --git a/speechx/speechx/kaldi/util/CMakeLists.txt b/runtime/engine/kaldi/util/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/util/CMakeLists.txt rename to runtime/engine/kaldi/util/CMakeLists.txt diff --git a/speechx/speechx/kaldi/util/basic-filebuf.h b/runtime/engine/kaldi/util/basic-filebuf.h similarity index 100% rename from speechx/speechx/kaldi/util/basic-filebuf.h rename to runtime/engine/kaldi/util/basic-filebuf.h diff --git a/speechx/speechx/kaldi/util/common-utils.h b/runtime/engine/kaldi/util/common-utils.h similarity index 100% rename from speechx/speechx/kaldi/util/common-utils.h rename to runtime/engine/kaldi/util/common-utils.h diff --git a/speechx/speechx/kaldi/util/const-integer-set-inl.h b/runtime/engine/kaldi/util/const-integer-set-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/const-integer-set-inl.h rename to runtime/engine/kaldi/util/const-integer-set-inl.h diff --git a/speechx/speechx/kaldi/util/const-integer-set.h b/runtime/engine/kaldi/util/const-integer-set.h similarity index 100% rename from speechx/speechx/kaldi/util/const-integer-set.h rename to runtime/engine/kaldi/util/const-integer-set.h diff --git a/speechx/speechx/kaldi/util/edit-distance-inl.h b/runtime/engine/kaldi/util/edit-distance-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/edit-distance-inl.h rename to runtime/engine/kaldi/util/edit-distance-inl.h diff --git a/speechx/speechx/kaldi/util/edit-distance.h b/runtime/engine/kaldi/util/edit-distance.h similarity index 100% rename from speechx/speechx/kaldi/util/edit-distance.h rename to runtime/engine/kaldi/util/edit-distance.h diff --git a/speechx/speechx/kaldi/util/hash-list-inl.h b/runtime/engine/kaldi/util/hash-list-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/hash-list-inl.h 
rename to runtime/engine/kaldi/util/hash-list-inl.h diff --git a/speechx/speechx/kaldi/util/hash-list.h b/runtime/engine/kaldi/util/hash-list.h similarity index 100% rename from speechx/speechx/kaldi/util/hash-list.h rename to runtime/engine/kaldi/util/hash-list.h diff --git a/speechx/speechx/kaldi/util/kaldi-cygwin-io-inl.h b/runtime/engine/kaldi/util/kaldi-cygwin-io-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-cygwin-io-inl.h rename to runtime/engine/kaldi/util/kaldi-cygwin-io-inl.h diff --git a/speechx/speechx/kaldi/util/kaldi-holder-inl.h b/runtime/engine/kaldi/util/kaldi-holder-inl.h similarity index 85% rename from speechx/speechx/kaldi/util/kaldi-holder-inl.h rename to runtime/engine/kaldi/util/kaldi-holder-inl.h index 134cdd93..9b441ad4 100644 --- a/speechx/speechx/kaldi/util/kaldi-holder-inl.h +++ b/runtime/engine/kaldi/util/kaldi-holder-inl.h @@ -754,53 +754,53 @@ class TokenVectorHolder { }; -class HtkMatrixHolder { - public: - typedef std::pair, HtkHeader> T; - - HtkMatrixHolder() {} - - static bool Write(std::ostream &os, bool binary, const T &t) { - if (!binary) - KALDI_ERR << "Non-binary HTK-format write not supported."; - bool ans = WriteHtk(os, t.first, t.second); - if (!ans) - KALDI_WARN << "Error detected writing HTK-format matrix."; - return ans; - } - - void Clear() { t_.first.Resize(0, 0); } - - // Reads into the holder. - bool Read(std::istream &is) { - bool ans = ReadHtk(is, &t_.first, &t_.second); - if (!ans) { - KALDI_WARN << "Error detected reading HTK-format matrix."; - return false; - } - return ans; - } - - // HTK-format matrices only read in binary. 
- static bool IsReadInBinary() { return true; } - - T &Value() { return t_; } - - void Swap(HtkMatrixHolder *other) { - t_.first.Swap(&(other->t_.first)); - std::swap(t_.second, other->t_.second); - } - - bool ExtractRange(const HtkMatrixHolder &other, - const std::string &range) { - KALDI_ERR << "ExtractRange is not defined for this type of holder."; - return false; - } - // Default destructor. - private: - KALDI_DISALLOW_COPY_AND_ASSIGN(HtkMatrixHolder); - T t_; -}; +//class HtkMatrixHolder { + //public: + //typedef std::pair, HtkHeader> T; + + //HtkMatrixHolder() {} + + //static bool Write(std::ostream &os, bool binary, const T &t) { + //if (!binary) + //KALDI_ERR << "Non-binary HTK-format write not supported."; + //bool ans = WriteHtk(os, t.first, t.second); + //if (!ans) + //KALDI_WARN << "Error detected writing HTK-format matrix."; + //return ans; + //} + + //void Clear() { t_.first.Resize(0, 0); } + + //// Reads into the holder. + //bool Read(std::istream &is) { + //bool ans = ReadHtk(is, &t_.first, &t_.second); + //if (!ans) { + //KALDI_WARN << "Error detected reading HTK-format matrix."; + //return false; + //} + //return ans; + //} + + //// HTK-format matrices only read in binary. + //static bool IsReadInBinary() { return true; } + + //T &Value() { return t_; } + + //void Swap(HtkMatrixHolder *other) { + //t_.first.Swap(&(other->t_.first)); + //std::swap(t_.second, other->t_.second); + //} + + //bool ExtractRange(const HtkMatrixHolder &other, + //const std::string &range) { + //KALDI_ERR << "ExtractRange is not defined for this type of holder."; + //return false; + //} + //// Default destructor. + //private: + //KALDI_DISALLOW_COPY_AND_ASSIGN(HtkMatrixHolder); + //T t_; +//}; // SphinxMatrixHolder can be used to read and write feature files in // CMU Sphinx format. 13-dimensional big-endian features are assumed. 
@@ -813,104 +813,104 @@ class HtkMatrixHolder { // be no problem, because the usage help of Sphinx' "wave2feat" for example // says that Sphinx features are always big endian. // Note: the kFeatDim defaults to 13, see forward declaration in kaldi-holder.h -template class SphinxMatrixHolder { - public: - typedef Matrix T; - - SphinxMatrixHolder() {} - - void Clear() { feats_.Resize(0, 0); } - - // Writes Sphinx-format features - static bool Write(std::ostream &os, bool binary, const T &m) { - if (!binary) { - KALDI_WARN << "SphinxMatrixHolder can't write Sphinx features in text "; - return false; - } - - int32 size = m.NumRows() * m.NumCols(); - if (MachineIsLittleEndian()) - KALDI_SWAP4(size); - // write the header - os.write(reinterpret_cast (&size), sizeof(size)); - - for (MatrixIndexT i = 0; i < m.NumRows(); i++) { - std::vector tmp(m.NumCols()); - for (MatrixIndexT j = 0; j < m.NumCols(); j++) { - tmp[j] = static_cast(m(i, j)); - if (MachineIsLittleEndian()) - KALDI_SWAP4(tmp[j]); - } - os.write(reinterpret_cast(&(tmp[0])), - tmp.size() * 4); - } - return true; - } - - // Reads the features into a Kaldi Matrix - bool Read(std::istream &is) { - int32 nmfcc; - - is.read(reinterpret_cast (&nmfcc), sizeof(nmfcc)); - if (MachineIsLittleEndian()) - KALDI_SWAP4(nmfcc); - KALDI_VLOG(2) << "#feats: " << nmfcc; - int32 nfvec = nmfcc / kFeatDim; - if ((nmfcc % kFeatDim) != 0) { - KALDI_WARN << "Sphinx feature count is inconsistent with vector length "; - return false; - } - - feats_.Resize(nfvec, kFeatDim); - for (MatrixIndexT i = 0; i < feats_.NumRows(); i++) { - if (sizeof(BaseFloat) == sizeof(float32)) { - is.read(reinterpret_cast (feats_.RowData(i)), - kFeatDim * sizeof(float32)); - if (!is.good()) { - KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; - return false; - } - if (MachineIsLittleEndian()) { - for (MatrixIndexT j = 0; j < kFeatDim; j++) - KALDI_SWAP4(feats_(i, j)); - } - } else { // KALDI_DOUBLEPRECISION=1 - float32 tmp[kFeatDim]; - 
is.read(reinterpret_cast (tmp), sizeof(tmp)); - if (!is.good()) { - KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; - return false; - } - for (MatrixIndexT j = 0; j < kFeatDim; j++) { - if (MachineIsLittleEndian()) - KALDI_SWAP4(tmp[j]); - feats_(i, j) = static_cast(tmp[j]); - } - } - } - - return true; - } - - // Only read in binary - static bool IsReadInBinary() { return true; } - - T &Value() { return feats_; } - - void Swap(SphinxMatrixHolder *other) { - feats_.Swap(&(other->feats_)); - } - - bool ExtractRange(const SphinxMatrixHolder &other, - const std::string &range) { - KALDI_ERR << "ExtractRange is not defined for this type of holder."; - return false; - } - - private: - KALDI_DISALLOW_COPY_AND_ASSIGN(SphinxMatrixHolder); - T feats_; -}; +//template class SphinxMatrixHolder { + //public: + //typedef Matrix T; + + //SphinxMatrixHolder() {} + + //void Clear() { feats_.Resize(0, 0); } + + //// Writes Sphinx-format features + //static bool Write(std::ostream &os, bool binary, const T &m) { + //if (!binary) { + //KALDI_WARN << "SphinxMatrixHolder can't write Sphinx features in text "; + //return false; + //} + + //int32 size = m.NumRows() * m.NumCols(); + //if (MachineIsLittleEndian()) + //KALDI_SWAP4(size); + //// write the header + //os.write(reinterpret_cast (&size), sizeof(size)); + + //for (MatrixIndexT i = 0; i < m.NumRows(); i++) { + //std::vector tmp(m.NumCols()); + //for (MatrixIndexT j = 0; j < m.NumCols(); j++) { + //tmp[j] = static_cast(m(i, j)); + //if (MachineIsLittleEndian()) + //KALDI_SWAP4(tmp[j]); + //} + //os.write(reinterpret_cast(&(tmp[0])), + //tmp.size() * 4); + //} + //return true; + //} + + //// Reads the features into a Kaldi Matrix + //bool Read(std::istream &is) { + //int32 nmfcc; + + //is.read(reinterpret_cast (&nmfcc), sizeof(nmfcc)); + //if (MachineIsLittleEndian()) + //KALDI_SWAP4(nmfcc); + //KALDI_VLOG(2) << "#feats: " << nmfcc; + //int32 nfvec = nmfcc / kFeatDim; + //if ((nmfcc % kFeatDim) != 0) { + 
//KALDI_WARN << "Sphinx feature count is inconsistent with vector length "; + //return false; + //} + + //feats_.Resize(nfvec, kFeatDim); + //for (MatrixIndexT i = 0; i < feats_.NumRows(); i++) { + //if (sizeof(BaseFloat) == sizeof(float32)) { + //is.read(reinterpret_cast (feats_.RowData(i)), + //kFeatDim * sizeof(float32)); + //if (!is.good()) { + //KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; + //return false; + //} + //if (MachineIsLittleEndian()) { + //for (MatrixIndexT j = 0; j < kFeatDim; j++) + //KALDI_SWAP4(feats_(i, j)); + //} + //} else { // KALDI_DOUBLEPRECISION=1 + //float32 tmp[kFeatDim]; + //is.read(reinterpret_cast (tmp), sizeof(tmp)); + //if (!is.good()) { + //KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; + //return false; + //} + //for (MatrixIndexT j = 0; j < kFeatDim; j++) { + //if (MachineIsLittleEndian()) + //KALDI_SWAP4(tmp[j]); + //feats_(i, j) = static_cast(tmp[j]); + //} + //} + //} + + //return true; + //} + + //// Only read in binary + //static bool IsReadInBinary() { return true; } + + //T &Value() { return feats_; } + + //void Swap(SphinxMatrixHolder *other) { + //feats_.Swap(&(other->feats_)); + //} + + //bool ExtractRange(const SphinxMatrixHolder &other, + //const std::string &range) { + //KALDI_ERR << "ExtractRange is not defined for this type of holder."; + //return false; + //} + + //private: + //KALDI_DISALLOW_COPY_AND_ASSIGN(SphinxMatrixHolder); + //T feats_; +//}; /// @} end "addtogroup holders" diff --git a/speechx/speechx/kaldi/util/kaldi-holder.cc b/runtime/engine/kaldi/util/kaldi-holder.cc similarity index 99% rename from speechx/speechx/kaldi/util/kaldi-holder.cc rename to runtime/engine/kaldi/util/kaldi-holder.cc index 577679ef..6b0eebb9 100644 --- a/speechx/speechx/kaldi/util/kaldi-holder.cc +++ b/runtime/engine/kaldi/util/kaldi-holder.cc @@ -85,7 +85,7 @@ bool ParseMatrixRangeSpecifier(const std::string &range, return status; } -bool ExtractObjectRange(const GeneralMatrix 
&input, const std::string &range, +/*bool ExtractObjectRange(const GeneralMatrix &input, const std::string &range, GeneralMatrix *output) { // We just inspect input's type and forward to the correct implementation // if available. For kSparseMatrix, we do just fairly inefficient conversion @@ -135,6 +135,7 @@ template bool ExtractObjectRange(const CompressedMatrix &, const std::string &, template bool ExtractObjectRange(const CompressedMatrix &, const std::string &, Matrix *); +*/ template bool ExtractObjectRange(const Matrix &input, const std::string &range, Matrix *output) { diff --git a/speechx/speechx/kaldi/util/kaldi-holder.h b/runtime/engine/kaldi/util/kaldi-holder.h similarity index 96% rename from speechx/speechx/kaldi/util/kaldi-holder.h rename to runtime/engine/kaldi/util/kaldi-holder.h index f495f27f..a8c42c9f 100644 --- a/speechx/speechx/kaldi/util/kaldi-holder.h +++ b/runtime/engine/kaldi/util/kaldi-holder.h @@ -27,7 +27,6 @@ #include "util/kaldi-io.h" #include "util/text-utils.h" #include "matrix/kaldi-vector.h" -#include "matrix/sparse-matrix.h" namespace kaldi { @@ -214,10 +213,10 @@ class TokenVectorHolder; /// A class for reading/writing HTK-format matrices. /// T == std::pair, HtkHeader> -class HtkMatrixHolder; +//class HtkMatrixHolder; /// A class for reading/writing Sphinx format matrices. 
-template class SphinxMatrixHolder; +//template class SphinxMatrixHolder; /// This templated function exists so that we can write .scp files with /// 'object ranges' specified: the canonical example is a [first:last] range @@ -249,15 +248,15 @@ bool ExtractObjectRange(const Vector &input, const std::string &range, Vector *output); /// GeneralMatrix is always of type BaseFloat -bool ExtractObjectRange(const GeneralMatrix &input, const std::string &range, - GeneralMatrix *output); +//bool ExtractObjectRange(const GeneralMatrix &input, const std::string &range, + // GeneralMatrix *output); /// CompressedMatrix is always of the type BaseFloat but it is more /// efficient to provide template as it uses CompressedMatrix's own /// conversion to Matrix -template -bool ExtractObjectRange(const CompressedMatrix &input, const std::string &range, - Matrix *output); +//template +//bool ExtractObjectRange(const CompressedMatrix &input, const std::string &range, + // Matrix *output); // In SequentialTableReaderScriptImpl and RandomAccessTableReaderScriptImpl, for // cases where the scp contained 'range specifiers' (things in square brackets diff --git a/speechx/speechx/kaldi/util/kaldi-io-inl.h b/runtime/engine/kaldi/util/kaldi-io-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-io-inl.h rename to runtime/engine/kaldi/util/kaldi-io-inl.h diff --git a/speechx/speechx/kaldi/util/kaldi-io.cc b/runtime/engine/kaldi/util/kaldi-io.cc similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-io.cc rename to runtime/engine/kaldi/util/kaldi-io.cc diff --git a/speechx/speechx/kaldi/util/kaldi-io.h b/runtime/engine/kaldi/util/kaldi-io.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-io.h rename to runtime/engine/kaldi/util/kaldi-io.h diff --git a/speechx/speechx/kaldi/util/kaldi-pipebuf.h b/runtime/engine/kaldi/util/kaldi-pipebuf.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-pipebuf.h rename to 
runtime/engine/kaldi/util/kaldi-pipebuf.h diff --git a/speechx/speechx/kaldi/util/kaldi-semaphore.cc b/runtime/engine/kaldi/util/kaldi-semaphore.cc similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-semaphore.cc rename to runtime/engine/kaldi/util/kaldi-semaphore.cc diff --git a/speechx/speechx/kaldi/util/kaldi-semaphore.h b/runtime/engine/kaldi/util/kaldi-semaphore.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-semaphore.h rename to runtime/engine/kaldi/util/kaldi-semaphore.h diff --git a/speechx/speechx/kaldi/util/kaldi-table-inl.h b/runtime/engine/kaldi/util/kaldi-table-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-table-inl.h rename to runtime/engine/kaldi/util/kaldi-table-inl.h diff --git a/speechx/speechx/kaldi/util/kaldi-table.cc b/runtime/engine/kaldi/util/kaldi-table.cc similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-table.cc rename to runtime/engine/kaldi/util/kaldi-table.cc diff --git a/speechx/speechx/kaldi/util/kaldi-table.h b/runtime/engine/kaldi/util/kaldi-table.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-table.h rename to runtime/engine/kaldi/util/kaldi-table.h diff --git a/speechx/speechx/kaldi/util/kaldi-thread.cc b/runtime/engine/kaldi/util/kaldi-thread.cc similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-thread.cc rename to runtime/engine/kaldi/util/kaldi-thread.cc diff --git a/speechx/speechx/kaldi/util/kaldi-thread.h b/runtime/engine/kaldi/util/kaldi-thread.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-thread.h rename to runtime/engine/kaldi/util/kaldi-thread.h diff --git a/speechx/speechx/kaldi/util/options-itf.h b/runtime/engine/kaldi/util/options-itf.h similarity index 100% rename from speechx/speechx/kaldi/util/options-itf.h rename to runtime/engine/kaldi/util/options-itf.h diff --git a/speechx/speechx/kaldi/util/parse-options.cc b/runtime/engine/kaldi/util/parse-options.cc similarity index 
100% rename from speechx/speechx/kaldi/util/parse-options.cc rename to runtime/engine/kaldi/util/parse-options.cc diff --git a/speechx/speechx/kaldi/util/parse-options.h b/runtime/engine/kaldi/util/parse-options.h similarity index 100% rename from speechx/speechx/kaldi/util/parse-options.h rename to runtime/engine/kaldi/util/parse-options.h diff --git a/speechx/speechx/kaldi/util/simple-io-funcs.cc b/runtime/engine/kaldi/util/simple-io-funcs.cc similarity index 100% rename from speechx/speechx/kaldi/util/simple-io-funcs.cc rename to runtime/engine/kaldi/util/simple-io-funcs.cc diff --git a/speechx/speechx/kaldi/util/simple-io-funcs.h b/runtime/engine/kaldi/util/simple-io-funcs.h similarity index 100% rename from speechx/speechx/kaldi/util/simple-io-funcs.h rename to runtime/engine/kaldi/util/simple-io-funcs.h diff --git a/speechx/speechx/kaldi/util/simple-options.cc b/runtime/engine/kaldi/util/simple-options.cc similarity index 100% rename from speechx/speechx/kaldi/util/simple-options.cc rename to runtime/engine/kaldi/util/simple-options.cc diff --git a/speechx/speechx/kaldi/util/simple-options.h b/runtime/engine/kaldi/util/simple-options.h similarity index 100% rename from speechx/speechx/kaldi/util/simple-options.h rename to runtime/engine/kaldi/util/simple-options.h diff --git a/speechx/speechx/kaldi/util/stl-utils.h b/runtime/engine/kaldi/util/stl-utils.h similarity index 100% rename from speechx/speechx/kaldi/util/stl-utils.h rename to runtime/engine/kaldi/util/stl-utils.h diff --git a/speechx/speechx/kaldi/util/table-types.h b/runtime/engine/kaldi/util/table-types.h similarity index 69% rename from speechx/speechx/kaldi/util/table-types.h rename to runtime/engine/kaldi/util/table-types.h index efcdf1b5..665a1327 100644 --- a/speechx/speechx/kaldi/util/table-types.h +++ b/runtime/engine/kaldi/util/table-types.h @@ -23,7 +23,8 @@ #include "base/kaldi-common.h" #include "util/kaldi-table.h" #include "util/kaldi-holder.h" -#include "matrix/matrix-lib.h" 
+#include "matrix/kaldi-matrix.h" +#include "matrix/kaldi-vector.h" namespace kaldi { @@ -51,8 +52,8 @@ typedef RandomAccessTableReader > > typedef RandomAccessTableReaderMapped > > RandomAccessDoubleMatrixReaderMapped; -typedef TableWriter > - CompressedMatrixWriter; +//typedef TableWriter > + //CompressedMatrixWriter; typedef TableWriter > > BaseFloatVectorWriter; @@ -70,39 +71,39 @@ typedef SequentialTableReader > > typedef RandomAccessTableReader > > RandomAccessDoubleVectorReader; -typedef TableWriter > > - BaseFloatCuMatrixWriter; -typedef SequentialTableReader > > - SequentialBaseFloatCuMatrixReader; -typedef RandomAccessTableReader > > - RandomAccessBaseFloatCuMatrixReader; -typedef RandomAccessTableReaderMapped > > - RandomAccessBaseFloatCuMatrixReaderMapped; - -typedef TableWriter > > - DoubleCuMatrixWriter; -typedef SequentialTableReader > > - SequentialDoubleCuMatrixReader; -typedef RandomAccessTableReader > > - RandomAccessDoubleCuMatrixReader; -typedef RandomAccessTableReaderMapped > > - RandomAccessDoubleCuMatrixReaderMapped; - -typedef TableWriter > > - BaseFloatCuVectorWriter; -typedef SequentialTableReader > > - SequentialBaseFloatCuVectorReader; -typedef RandomAccessTableReader > > - RandomAccessBaseFloatCuVectorReader; -typedef RandomAccessTableReaderMapped > > - RandomAccessBaseFloatCuVectorReaderMapped; - -typedef TableWriter > > - DoubleCuVectorWriter; -typedef SequentialTableReader > > - SequentialDoubleCuVectorReader; -typedef RandomAccessTableReader > > - RandomAccessDoubleCuVectorReader; +//typedef TableWriter > > + //BaseFloatCuMatrixWriter; +//typedef SequentialTableReader > > + //SequentialBaseFloatCuMatrixReader; +//typedef RandomAccessTableReader > > + //RandomAccessBaseFloatCuMatrixReader; +//typedef RandomAccessTableReaderMapped > > + //RandomAccessBaseFloatCuMatrixReaderMapped; + +//typedef TableWriter > > + //DoubleCuMatrixWriter; +//typedef SequentialTableReader > > + //SequentialDoubleCuMatrixReader; +//typedef 
RandomAccessTableReader > > + //RandomAccessDoubleCuMatrixReader; +//typedef RandomAccessTableReaderMapped > > + //RandomAccessDoubleCuMatrixReaderMapped; + +//typedef TableWriter > > + //BaseFloatCuVectorWriter; +//typedef SequentialTableReader > > + //SequentialBaseFloatCuVectorReader; +//typedef RandomAccessTableReader > > + //RandomAccessBaseFloatCuVectorReader; +//typedef RandomAccessTableReaderMapped > > + //RandomAccessBaseFloatCuVectorReaderMapped; + +//typedef TableWriter > > + //DoubleCuVectorWriter; +//typedef SequentialTableReader > > + //SequentialDoubleCuVectorReader; +//typedef RandomAccessTableReader > > + //RandomAccessDoubleCuVectorReader; typedef TableWriter > Int32Writer; @@ -150,8 +151,6 @@ typedef TableWriter > BoolWriter; typedef SequentialTableReader > SequentialBoolReader; typedef RandomAccessTableReader > RandomAccessBoolReader; - - /// TokenWriter is a writer specialized for std::string where the strings /// are nonempty and whitespace-free. T == std::string typedef TableWriter TokenWriter; @@ -169,14 +168,14 @@ typedef RandomAccessTableReader RandomAccessTokenVectorReader; -typedef TableWriter > - GeneralMatrixWriter; -typedef SequentialTableReader > - SequentialGeneralMatrixReader; -typedef RandomAccessTableReader > - RandomAccessGeneralMatrixReader; -typedef RandomAccessTableReaderMapped > - RandomAccessGeneralMatrixReaderMapped; +//typedef TableWriter > +// GeneralMatrixWriter; +//typedef SequentialTableReader > + // SequentialGeneralMatrixReader; +//typedef RandomAccessTableReader > + // RandomAccessGeneralMatrixReader; +//typedef RandomAccessTableReaderMapped > + // RandomAccessGeneralMatrixReaderMapped; diff --git a/speechx/speechx/kaldi/util/text-utils.cc b/runtime/engine/kaldi/util/text-utils.cc similarity index 100% rename from speechx/speechx/kaldi/util/text-utils.cc rename to runtime/engine/kaldi/util/text-utils.cc diff --git a/speechx/speechx/kaldi/util/text-utils.h b/runtime/engine/kaldi/util/text-utils.h similarity index 
100% rename from speechx/speechx/kaldi/util/text-utils.h rename to runtime/engine/kaldi/util/text-utils.h diff --git a/runtime/engine/vad/CMakeLists.txt b/runtime/engine/vad/CMakeLists.txt new file mode 100644 index 00000000..442acbd8 --- /dev/null +++ b/runtime/engine/vad/CMakeLists.txt @@ -0,0 +1,5 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +add_subdirectory(nnet) + +add_subdirectory(interface) \ No newline at end of file diff --git a/runtime/engine/vad/frontend/wav.h b/runtime/engine/vad/frontend/wav.h new file mode 100644 index 00000000..f9b7bee2 --- /dev/null +++ b/runtime/engine/vad/frontend/wav.h @@ -0,0 +1,199 @@ +// Copyright (c) 2016 Personal (Binbin Zhang) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once +#include +#include +#include +#include +#include + +#include +#include + +namespace wav { + +struct WavHeader { + char riff[4]; // "riff" + unsigned int size; + char wav[4]; // "WAVE" + char fmt[4]; // "fmt " + unsigned int fmt_size; + uint16_t format; + uint16_t channels; + unsigned int sample_rate; + unsigned int bytes_per_second; + uint16_t block_size; + uint16_t bit; + char data[4]; // "data" + unsigned int data_size; +}; + +class WavReader { + public: + WavReader() : data_(nullptr) {} + explicit WavReader(const std::string& filename) { Open(filename); } + + bool Open(const std::string& filename) { + FILE* fp = fopen(filename.c_str(), "rb"); + if (NULL == fp) { + std::cout << "Error in read " << filename; + return false; + } + + WavHeader header; + fread(&header, 1, sizeof(header), fp); + if (header.fmt_size < 16) { + fprintf(stderr, + "WaveData: expect PCM format data " + "to have fmt chunk of at least size 16.\n"); + return false; + } else if (header.fmt_size > 16) { + int offset = 44 - 8 + header.fmt_size - 16; + fseek(fp, offset, SEEK_SET); + fread(header.data, 8, sizeof(char), fp); + } + // check "riff" "WAVE" "fmt " "data" + + // Skip any sub-chunks between "fmt" and "data". Usually there will + // be a single "fact" sub chunk, but on Windows there can also be a + // "list" sub chunk. + while (0 != strncmp(header.data, "data", 4)) { + // We will just ignore the data in these chunks. 
+ fseek(fp, header.data_size, SEEK_CUR); + // read next sub chunk + fread(header.data, 8, sizeof(char), fp); + } + + num_channel_ = header.channels; + sample_rate_ = header.sample_rate; + bits_per_sample_ = header.bit; + int num_data = header.data_size / (bits_per_sample_ / 8); + data_ = new float[num_data]; // Create 1-dim array + num_samples_ = num_data / num_channel_; + + for (int i = 0; i < num_data; ++i) { + switch (bits_per_sample_) { + case 8: { + char sample; + fread(&sample, 1, sizeof(char), fp); + data_[i] = static_cast(sample); + break; + } + case 16: { + int16_t sample; + fread(&sample, 1, sizeof(int16_t), fp); + // std::cout << sample; + data_[i] = static_cast(sample); + // std::cout << data_[i]; + break; + } + case 32: { + int sample; + fread(&sample, 1, sizeof(int), fp); + data_[i] = static_cast(sample); + break; + } + default: + fprintf(stderr, "unsupported quantization bits"); + exit(1); + } + } + fclose(fp); + return true; + } + + int num_channel() const { return num_channel_; } + int sample_rate() const { return sample_rate_; } + int bits_per_sample() const { return bits_per_sample_; } + int num_samples() const { return num_samples_; } + const float* data() const { return data_; } + + private: + int num_channel_; + int sample_rate_; + int bits_per_sample_; + int num_samples_; // sample points per channel + float* data_; +}; + +class WavWriter { + public: + WavWriter(const float* data, + int num_samples, + int num_channel, + int sample_rate, + int bits_per_sample) + : data_(data), + num_samples_(num_samples), + num_channel_(num_channel), + sample_rate_(sample_rate), + bits_per_sample_(bits_per_sample) {} + + void Write(const std::string& filename) { + FILE* fp = fopen(filename.c_str(), "w"); + // init char 'riff' 'WAVE' 'fmt ' 'data' + WavHeader header; + char wav_header[44] = { + 0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x41, 0x56, + 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00}; + memcpy(&header, wav_header, sizeof(header)); + header.channels = num_channel_; + header.bit = bits_per_sample_; + header.sample_rate = sample_rate_; + header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8); + header.size = sizeof(header) - 8 + header.data_size; + header.bytes_per_second = + sample_rate_ * num_channel_ * (bits_per_sample_ / 8); + header.block_size = num_channel_ * (bits_per_sample_ / 8); + + fwrite(&header, 1, sizeof(header), fp); + + for (int i = 0; i < num_samples_; ++i) { + for (int j = 0; j < num_channel_; ++j) { + switch (bits_per_sample_) { + case 8: { + char sample = + static_cast(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + case 16: { + int16_t sample = + static_cast(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + case 32: { + int sample = + static_cast(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + } + } + } + fclose(fp); + } + + private: + const float* data_; + int num_samples_; // total float points in data_ + int num_channel_; + int sample_rate_; + int bits_per_sample_; +}; + +} // namespace wav diff --git a/runtime/engine/vad/interface/CMakeLists.txt b/runtime/engine/vad/interface/CMakeLists.txt new file mode 100644 index 00000000..e056ec39 --- /dev/null +++ b/runtime/engine/vad/interface/CMakeLists.txt @@ -0,0 +1,24 @@ +set(srcs + vad_interface.cc +) + +add_library(pps_vad_interface SHARED ${srcs}) +target_link_libraries(pps_vad_interface PUBLIC pps_vad extern_glog) + + +set(bin_name vad_interface_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_link_libraries(${bin_name} pps_vad_interface) +# set_target_properties(${bin_name} PROPERTIES PUBLIC_HEADER "vad_interface.h;../frontend/wav.h") + +file(RELATIVE_PATH DEST_DIR ${ENGINE_ROOT} 
${CMAKE_CURRENT_SOURCE_DIR}) +install(TARGETS pps_vad_interface DESTINATION lib) +install(FILES vad_interface.h DESTINATION include/${DEST_DIR}) + +install(TARGETS vad_interface_main + RUNTIME DESTINATION bin + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + PUBLIC_HEADER DESTINATION include/${DEST_DIR} +) +install(FILES vad_interface_main.cc DESTINATION demo/${DEST_DIR}) \ No newline at end of file diff --git a/runtime/engine/vad/interface/vad_interface.cc b/runtime/engine/vad/interface/vad_interface.cc new file mode 100644 index 00000000..2e5c9175 --- /dev/null +++ b/runtime/engine/vad/interface/vad_interface.cc @@ -0,0 +1,103 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "vad/interface/vad_interface.h" + +#include "common/base/config.h" +#include "vad/nnet/vad.h" + + +PPSHandle_t PPSVadCreateInstance(const char* conf_path) { + Config conf(conf_path); + ppspeech::VadNnetConf nnet_conf; + nnet_conf.sr = conf.Read("sr", 16000); + nnet_conf.frame_ms = conf.Read("frame_ms", 32); + nnet_conf.threshold = conf.Read("threshold", 0.45f); + nnet_conf.beam = conf.Read("beam", 0.15f); + nnet_conf.min_silence_duration_ms = + conf.Read("min_silence_duration_ms", 200); + nnet_conf.speech_pad_left_ms = conf.Read("speech_pad_left_ms", 0); + nnet_conf.speech_pad_right_ms = conf.Read("speech_pad_right_ms", 0); + + nnet_conf.model_file_path = conf.Read("model_path", std::string("")); + nnet_conf.param_file_path = conf.Read("param_path", std::string("")); + nnet_conf.num_cpu_thread = conf.Read("num_cpu_thread", 1); + + ppspeech::Vad* model = new ppspeech::Vad(nnet_conf.model_file_path); + + // custom config, but must be set before init + model->SetConfig(nnet_conf); + model->Init(); + + return static_cast(model); +} + + +int PPSVadDestroyInstance(PPSHandle_t instance) { + ppspeech::Vad* model = static_cast(instance); + if (model != nullptr) { + delete model; + model = nullptr; + } + return 0; +} + +int PPSVadChunkSizeSamples(PPSHandle_t instance) { + ppspeech::Vad* model = static_cast(instance); + if (model == nullptr) { + printf("instance is null\n"); + return -1; + } + + return model->WindowSizeSamples(); +} + +PPSVadState_t PPSVadFeedForward(PPSHandle_t instance, + float* chunk, + int num_element) { + ppspeech::Vad* model = static_cast(instance); + if (model == nullptr) { + printf("instance is null\n"); + return PPS_VAD_ILLEGAL; + } + + std::vector chunk_in(chunk, chunk + num_element); + if (!model->ForwardChunk(chunk_in)) { + printf("forward chunk failed\n"); + return PPS_VAD_ILLEGAL; + } + ppspeech::Vad::State s = model->Postprocess(); + PPSVadState_t ret = (PPSVadState_t)s; + return ret; +} + +int PPSVadReset(PPSHandle_t instance) { 
+ ppspeech::Vad* model = static_cast(instance); + if (model == nullptr) { + printf("instance is null\n"); + return -1; + } + model->Reset(); + return 0; +} + +int PPSVadGetResult(PPSHandle_t instance, char* result, int max_len){ + ppspeech::Vad* model = static_cast(instance); + if (model == nullptr) { + printf("instance is null\n"); + return -1; + } + return model->GetResult(result, max_len); +}; \ No newline at end of file diff --git a/runtime/engine/vad/interface/vad_interface.h b/runtime/engine/vad/interface/vad_interface.h new file mode 100644 index 00000000..15d0b811 --- /dev/null +++ b/runtime/engine/vad/interface/vad_interface.h @@ -0,0 +1,47 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void* PPSHandle_t; + +typedef enum { + PPS_VAD_ILLEGAL = 0, // error + PPS_VAD_SIL, // silence + PPS_VAD_START, // start speech + PPS_VAD_SPEECH, // in speech + PPS_VAD_END, // end speech + PPS_VAD_NUMSTATES, // number of states +} PPSVadState_t; + +PPSHandle_t PPSVadCreateInstance(const char* conf_path); + +int PPSVadDestroyInstance(PPSHandle_t instance); + +int PPSVadReset(PPSHandle_t instance); + +int PPSVadChunkSizeSamples(PPSHandle_t instance); + +PPSVadState_t PPSVadFeedForward(PPSHandle_t instance, + float* chunk, + int num_element); + +int PPSVadGetResult(PPSHandle_t instance, char* result, int max_len); +#ifdef __cplusplus +} +#endif // __cplusplus \ No newline at end of file diff --git a/runtime/engine/vad/interface/vad_interface_main.cc b/runtime/engine/vad/interface/vad_interface_main.cc new file mode 100644 index 00000000..6dba794d --- /dev/null +++ b/runtime/engine/vad/interface/vad_interface_main.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ + +#include +#include + +#include +#include "common/base/common.h" +#include "vad/frontend/wav.h" +#include "vad/interface/vad_interface.h" + +int main(int argc, char* argv[]) { + if (argc < 3) { + std::cout << "Usage: vad_interface_main path/to/config wav.scp" + "run_option, " + "e.g ./vad_interface_main config wav.scp" + << std::endl; + return -1; + } + + std::string config_path = argv[1]; + std::string wav_scp = argv[2]; + + PPSHandle_t handle = PPSVadCreateInstance(config_path.c_str()); + + std::ifstream fp_wav(wav_scp); + std::string line = ""; + while(getline(fp_wav, line)){ + std::vector inputWav; // [0, 1] + wav::WavReader wav_reader = wav::WavReader(line); + auto sr = wav_reader.sample_rate(); + CHECK(sr == 16000) << " sr is " << sr << " expect 16000"; + + auto num_samples = wav_reader.num_samples(); + inputWav.resize(num_samples); + for (int i = 0; i < num_samples; i++) { + inputWav[i] = wav_reader.data()[i] / 32768; + } + + ppspeech::Timer timer; + int window_size_samples = PPSVadChunkSizeSamples(handle); + for (int64_t j = 0; j < num_samples; j += window_size_samples) { + auto start = j; + auto end = start + window_size_samples >= num_samples + ? 
num_samples + : start + window_size_samples; + std::vector r(window_size_samples, 0); + auto current_chunk_size = end - start; + memcpy(r.data(), inputWav.data() + start, current_chunk_size * sizeof(float)); + + PPSVadState_t s = PPSVadFeedForward(handle, r.data(), r.size()); + } + + std::cout << "RTF=" << timer.Elapsed() / double(num_samples / sr) + << std::endl; + + char result[10240] = {0}; + PPSVadGetResult(handle, result, 10240); + std::cout << line << " " << result << std::endl; + + PPSVadReset(handle); + // getchar(); + } + PPSVadDestroyInstance(handle); + return 0; +} diff --git a/runtime/engine/vad/nnet/CMakeLists.txt b/runtime/engine/vad/nnet/CMakeLists.txt new file mode 100644 index 00000000..3ca951d9 --- /dev/null +++ b/runtime/engine/vad/nnet/CMakeLists.txt @@ -0,0 +1,19 @@ +set(srcs + vad.cc +) + +add_library(pps_vad ${srcs}) +target_link_libraries(pps_vad PUBLIC ${FASTDEPLOY_LIBS} common extern_glog) + + +set(bin_name vad_nnet_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_link_libraries(${bin_name} pps_vad) + +file(RELATIVE_PATH DEST_DIR ${ENGINE_ROOT} ${CMAKE_CURRENT_SOURCE_DIR}) +install(TARGETS pps_vad DESTINATION lib) +if(ANDROID) + install(TARGETS extern_glog DESTINATION lib) +else() # UNIX + install(TARGETS glog DESTINATION lib) +endif() diff --git a/runtime/engine/vad/nnet/vad.cc b/runtime/engine/vad/nnet/vad.cc new file mode 100644 index 00000000..101f2370 --- /dev/null +++ b/runtime/engine/vad/nnet/vad.cc @@ -0,0 +1,333 @@ +// Copyright (c) 2023 Chen Qianhe Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "vad/nnet/vad.h" + +#include +#include + +#include "common/base/common.h" + + +namespace ppspeech { + +Vad::Vad(const std::string& model_file, + const fastdeploy::RuntimeOption& + custom_option /* = fastdeploy::RuntimeOption() */) { + valid_cpu_backends = {fastdeploy::Backend::ORT, + fastdeploy::Backend::OPENVINO}; + valid_gpu_backends = {fastdeploy::Backend::ORT, fastdeploy::Backend::TRT}; + + runtime_option = custom_option; + // ORT backend + runtime_option.UseCpu(); + runtime_option.UseOrtBackend(); + runtime_option.model_format = fastdeploy::ModelFormat::ONNX; + // grap opt level + runtime_option.ort_option.graph_optimization_level = 99; + // one-thread + runtime_option.ort_option.intra_op_num_threads = 1; + runtime_option.ort_option.inter_op_num_threads = 1; + // model path + runtime_option.model_file = model_file; +} + +void Vad::Init() { + std::lock_guard lock(init_lock_); + Initialize(); +} + +std::string Vad::ModelName() const { return "VAD"; } + +void Vad::SetConfig(const VadNnetConf conf) { + SetConfig(conf.sr, + conf.frame_ms, + conf.threshold, + conf.beam, + conf.min_silence_duration_ms, + conf.speech_pad_left_ms, + conf.speech_pad_right_ms); +} + +void Vad::SetConfig(const int& sr, + const int& frame_ms, + const float& threshold, + const float& beam, + const int& min_silence_duration_ms, + const int& speech_pad_left_ms, + const int& speech_pad_right_ms) { + if (initialized_) { + fastdeploy::FDERROR << "SetConfig must be called before init" + << std::endl; + throw std::runtime_error("SetConfig must be called before init"); + 
} + sample_rate_ = sr; + sr_per_ms_ = sr / 1000; + threshold_ = threshold; + beam_ = beam; + frame_ms_ = frame_ms; + min_silence_samples_ = min_silence_duration_ms * sr_per_ms_; + speech_pad_left_samples_ = speech_pad_left_ms * sr_per_ms_; + speech_pad_right_samples_ = speech_pad_right_ms * sr_per_ms_; + + // init chunk size + window_size_samples_ = frame_ms * sr_per_ms_; + current_chunk_size_ = window_size_samples_; + + fastdeploy::FDINFO << "sr=" << sr_per_ms_ << " threshold=" << threshold_ + << " beam=" << beam_ << " frame_ms=" << frame_ms_ + << " min_silence_duration_ms=" << min_silence_duration_ms + << " speech_pad_left_ms=" << speech_pad_left_ms + << " speech_pad_right_ms=" << speech_pad_right_ms; +} + +void Vad::Reset() { + std::memset(h_.data(), 0.0f, h_.size() * sizeof(float)); + std::memset(c_.data(), 0.0f, c_.size() * sizeof(float)); + + triggerd_ = false; + temp_end_ = 0; + current_sample_ = 0; + + speechStart_.clear(); + speechEnd_.clear(); + + states_.clear(); +} + +bool Vad::Initialize() { + // input & output holder + inputTensors_.resize(4); + outputTensors_.resize(3); + + // input shape + input_node_dims_.emplace_back(1); + input_node_dims_.emplace_back(window_size_samples_); + // sr buffer + sr_.resize(1); + sr_[0] = sample_rate_; + // hidden state buffer + h_.resize(size_hc_); + c_.resize(size_hc_); + + Reset(); + + + // InitRuntime + if (!InitRuntime()) { + fastdeploy::FDERROR << "Failed to initialize fastdeploy backend." 
+ << std::endl; + return false; + } + + initialized_ = true; + + + fastdeploy::FDINFO << "init done."; + return true; +} + +bool Vad::ForwardChunk(std::vector& chunk) { + // last chunk may not be window_size_samples_ + input_node_dims_.back() = chunk.size(); + assert(window_size_samples_ >= chunk.size()); + current_chunk_size_ = chunk.size(); + + inputTensors_[0].name = "input"; + inputTensors_[0].SetExternalData( + input_node_dims_, fastdeploy::FDDataType::FP32, chunk.data()); + inputTensors_[1].name = "sr"; + inputTensors_[1].SetExternalData( + sr_node_dims_, fastdeploy::FDDataType::INT64, sr_.data()); + inputTensors_[2].name = "h"; + inputTensors_[2].SetExternalData( + hc_node_dims_, fastdeploy::FDDataType::FP32, h_.data()); + inputTensors_[3].name = "c"; + inputTensors_[3].SetExternalData( + hc_node_dims_, fastdeploy::FDDataType::FP32, c_.data()); + + if (!Infer(inputTensors_, &outputTensors_)) { + return false; + } + + // Push forward sample index + current_sample_ += current_chunk_size_; + return true; +} + +const Vad::State& Vad::Postprocess() { + // update prob, h, c + outputProb_ = *(float*)outputTensors_[0].Data(); + auto* hn = static_cast(outputTensors_[1].MutableData()); + std::memcpy(h_.data(), hn, h_.size() * sizeof(float)); + auto* cn = static_cast(outputTensors_[2].MutableData()); + std::memcpy(c_.data(), cn, c_.size() * sizeof(float)); + + if (outputProb_ < threshold_ && !triggerd_) { + // 1. Silence +#ifdef PPS_DEBUG + DLOG(INFO) << "{ silence: " << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; +#endif + states_.emplace_back(Vad::State::SIL); + } else if (outputProb_ >= threshold_ && !triggerd_) { + // 2. 
Start + triggerd_ = true; + speech_start_ = + current_sample_ - current_chunk_size_ - speech_pad_left_samples_; + speech_start_ = std::max(int(speech_start_), 0); + float start_sec = 1.0 * speech_start_ / sample_rate_; + speechStart_.emplace_back(start_sec); +#ifdef PPS_DEBUG + DLOG(INFO) << "{ speech start: " << start_sec + << " s; prob: " << outputProb_ << " }"; +#endif + states_.emplace_back(Vad::State::START); + } else if (outputProb_ >= threshold_ - beam_ && triggerd_) { + // 3. Continue + + if (temp_end_ != 0) { + // speech prob relaxation, speech continues again +#ifdef PPS_DEBUG + DLOG(INFO) + << "{ speech fake end(sil < min_silence_ms) to continue: " + << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; +#endif + temp_end_ = 0; + } else { + // speech prob relaxation, keep tracking speech +#ifdef PPS_DEBUG + DLOG(INFO) << "{ speech continue: " + << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; +#endif + } + + states_.emplace_back(Vad::State::SPEECH); + } else if (outputProb_ < threshold_ - beam_ && triggerd_) { + // 4. End + if (temp_end_ == 0) { + temp_end_ = current_sample_; + } + + // check possible speech end + if (current_sample_ - temp_end_ < min_silence_samples_) { + // a. silence < min_slience_samples, continue speaking +#ifdef PPS_DEBUG + DLOG(INFO) << "{ speech fake end(sil < min_silence_ms): " + << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; +#endif + states_.emplace_back(Vad::State::SIL); + } else { + // b. 
silence >= min_slience_samples, end speaking + speech_end_ = current_sample_ + speech_pad_right_samples_; + temp_end_ = 0; + triggerd_ = false; + auto end_sec = 1.0 * speech_end_ / sample_rate_; + speechEnd_.emplace_back(end_sec); +#ifdef PPS_DEBUG + DLOG(INFO) << "{ speech end: " << end_sec + << " s; prob: " << outputProb_ << " }"; +#endif + states_.emplace_back(Vad::State::END); + } + } + + return states_.back(); +} + +std::string Vad::ConvertTime(float time_s) const{ + float seconds_tmp, minutes_tmp, hours_tmp; + float seconds; + int minutes, hours; + + // 计算小时 + hours_tmp = time_s / 60 / 60; // 1 + hours = (int)hours_tmp; + + // 计算分钟 + minutes_tmp = time_s / 60; + if (minutes_tmp >= 60) { + minutes = minutes_tmp - 60 * (double)hours; + } + else { + minutes = minutes_tmp; + } + + // 计算秒数 + seconds_tmp = (60 * 60 * hours) + (60 * minutes); + seconds = time_s - seconds_tmp; + + // 输出格式 + std::stringstream ss; + ss << hours << ":" << minutes << ":" << seconds; + + return ss.str(); +} + +int Vad::GetResult(char* result, int max_len, + float removeThreshold, + float expandHeadThreshold, + float expandTailThreshold, + float mergeThreshold) const { + float audioLength = 1.0 * current_sample_ / sample_rate_; + if (speechStart_.empty() && speechEnd_.empty()) { + return {}; + } + if (speechEnd_.size() != speechStart_.size()) { + // set the audio length as the last end + speechEnd_.emplace_back(audioLength); + } + + std::string json = "["; + + for (int i = 0; i < speechStart_.size(); ++i) { + json += "{\"s\":\"" + ConvertTime(speechStart_[i]) + "\",\"e\":\"" + ConvertTime(speechEnd_[i]) + "\"},"; + } + json.pop_back(); + json += "]"; + + if(result != NULL){ + snprintf(result, max_len, "%s", json.c_str()); + } else { + DLOG(INFO) << "result is NULL"; + } + return 0; +} + +std::ostream& operator<<(std::ostream& os, const Vad::State& s) { + switch (s) { + case Vad::State::SIL: + os << "[SIL]"; + break; + case Vad::State::START: + os << "[STA]"; + break; + case 
Vad::State::SPEECH: + os << "[SPE]"; + break; + case Vad::State::END: + os << "[END]"; + break; + default: + // illegal state + os << "[ILL]"; + break; + } + return os; +} + +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/vad/nnet/vad.h b/runtime/engine/vad/nnet/vad.h new file mode 100644 index 00000000..31db78d2 --- /dev/null +++ b/runtime/engine/vad/nnet/vad.h @@ -0,0 +1,157 @@ +// Copyright (c) 2023 Chen Qianhe Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +#include "fastdeploy/fastdeploy_model.h" +#include "fastdeploy/runtime.h" +#include "vad/frontend/wav.h" + +namespace ppspeech { + +struct VadNnetConf { + // wav + int sr; + int frame_ms; + float threshold; + float beam; + int min_silence_duration_ms; + int speech_pad_left_ms; + int speech_pad_right_ms; + + // model + std::string model_file_path; + std::string param_file_path; + std::string dict_file_path; + int num_cpu_thread; // 1 thred + std::string backend; // ort,lite, etc. 
+}; + +class Vad : public fastdeploy::FastDeployModel { + public: + enum class State { ILLEGAL = 0, SIL, START, SPEECH, END }; + friend std::ostream& operator<<(std::ostream& os, const Vad::State& s); + + Vad(const std::string& model_file, + const fastdeploy::RuntimeOption& custom_option = + fastdeploy::RuntimeOption()); + + virtual ~Vad() {} + + void Init(); + + void Reset(); + + void SetConfig(const int& sr, + const int& frame_ms, + const float& threshold, + const float& beam, + const int& min_silence_duration_ms, + const int& speech_pad_left_ms, + const int& speech_pad_right_ms); + void SetConfig(const VadNnetConf conf); + + bool ForwardChunk(std::vector& chunk); + + const State& Postprocess(); + + int GetResult(char* result, int max_len, + float removeThreshold = 0.0, + float expandHeadThreshold = 0.0, + float expandTailThreshold = 0, + float mergeThreshold = 0.0) const; + + const std::vector GetStates() const { return states_; } + + int SampleRate() const { return sample_rate_; } + + int FrameMs() const { return frame_ms_; } + int64_t WindowSizeSamples() const { return window_size_samples_; } + + float Threshold() const { return threshold_; } + + int MinSilenceDurationMs() const { + return min_silence_samples_ / sample_rate_; + } + int SpeechPadLeftMs() const { + return speech_pad_left_samples_ / sample_rate_; + } + int SpeechPadRightMs() const { + return speech_pad_right_samples_ / sample_rate_; + } + + int MinSilenceSamples() const { return min_silence_samples_; } + int SpeechPadLeftSamples() const { return speech_pad_left_samples_; } + int SpeechPadRightSamples() const { return speech_pad_right_samples_; } + + std::string ModelName() const override; + + private: + bool Initialize(); + std::string ConvertTime(float time_s) const; + + private: + std::mutex init_lock_; + bool initialized_{false}; + + // input and output + std::vector inputTensors_; + std::vector outputTensors_; + + // model states + bool triggerd_ = false; + unsigned int speech_start_ = 0; + 
unsigned int speech_end_ = 0; + unsigned int temp_end_ = 0; + unsigned int current_sample_ = 0; + unsigned int current_chunk_size_ = 0; + // MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes + float outputProb_; + + std::vector speechStart_; + mutable std::vector speechEnd_; + + std::vector states_; + + /* ======================================================================== + */ + int sample_rate_ = 16000; + int frame_ms_ = 32; // 32, 64, 96 for 16k + float threshold_ = 0.5f; + float beam_ = 0.15f; + + int64_t window_size_samples_; // support 256 512 768 for 8k; 512 1024 1536 + // for 16k. + int sr_per_ms_; // support 8 or 16 + int min_silence_samples_; // sr_per_ms_ * frame_ms_ + int speech_pad_left_samples_{0}; // usually 250ms + int speech_pad_right_samples_{0}; // usually 0 + + /* ======================================================================== + */ + std::vector sr_; + const size_t size_hc_ = 2 * 1 * 64; // It's FIXED. + std::vector h_; + std::vector c_; + + std::vector input_node_dims_; + const std::vector sr_node_dims_ = {1}; + const std::vector hc_node_dims_ = {2, 1, 64}; +}; + +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/vad/nnet/vad_nnet_main.cc b/runtime/engine/vad/nnet/vad_nnet_main.cc new file mode 100644 index 00000000..f3079b42 --- /dev/null +++ b/runtime/engine/vad/nnet/vad_nnet_main.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "common/base/common.h" +#include "vad/nnet/vad.h" + +int main(int argc, char* argv[]) { + if (argc < 3) { + std::cout << "Usage: vad_nnet_main path/to/model path/to/audio " + "run_option, " + "e.g ./vad_nnet_main silero_vad.onnx sample.wav" + << std::endl; + return -1; + } + + std::string model_file = argv[1]; + std::string audio_file = argv[2]; + + int sr = 16000; + ppspeech::Vad vad(model_file); + // custom config, but must be set before init + vad.SetConfig(sr, 32, 0.5f, 0.15, 200, 0, 0); + vad.Init(); + + std::vector inputWav; // [0, 1] + wav::WavReader wav_reader = wav::WavReader(audio_file); + assert(wav_reader.sample_rate() == sr); + + + auto num_samples = wav_reader.num_samples(); + inputWav.resize(num_samples); + for (int i = 0; i < num_samples; i++) { + inputWav[i] = wav_reader.data()[i] / 32768; + } + + ppspeech::Timer timer; + int window_size_samples = vad.WindowSizeSamples(); + for (int64_t j = 0; j < num_samples; j += window_size_samples) { + auto start = j; + auto end = start + window_size_samples >= num_samples + ? num_samples + : start + window_size_samples; + auto current_chunk_size = end - start; + + std::vector r{&inputWav[0] + start, &inputWav[0] + end}; + assert(r.size() == static_cast(current_chunk_size)); + + if (!vad.ForwardChunk(r)) { + std::cerr << "Failed to inference while using model:" + << vad.ModelName() << "." 
<< std::endl; + return false; + } + + ppspeech::Vad::State s = vad.Postprocess(); + std::cout << s << " "; + } + std::cout << std::endl; + + std::cout << "RTF=" << timer.Elapsed() / double(num_samples / sr) + << std::endl; + std::cout << "\b\b " << std::endl; + + vad.Reset(); + + return 0; +} diff --git a/speechx/examples/.gitignore b/runtime/examples/.gitignore similarity index 80% rename from speechx/examples/.gitignore rename to runtime/examples/.gitignore index b7075fa5..38290f34 100644 --- a/speechx/examples/.gitignore +++ b/runtime/examples/.gitignore @@ -1,2 +1,3 @@ *.ark +*.scp paddle_asr_model/ diff --git a/speechx/examples/README.md b/runtime/examples/README.md similarity index 100% rename from speechx/examples/README.md rename to runtime/examples/README.md diff --git a/runtime/examples/android/VadJni/.gitignore b/runtime/examples/android/VadJni/.gitignore new file mode 100644 index 00000000..aa724b77 --- /dev/null +++ b/runtime/examples/android/VadJni/.gitignore @@ -0,0 +1,15 @@ +*.iml +.gradle +/local.properties +/.idea/caches +/.idea/libraries +/.idea/modules.xml +/.idea/workspace.xml +/.idea/navEditor.xml +/.idea/assetWizardSettings.xml +.DS_Store +/build +/captures +.externalNativeBuild +.cxx +local.properties diff --git a/runtime/examples/android/VadJni/.idea/.gitignore b/runtime/examples/android/VadJni/.idea/.gitignore new file mode 100644 index 00000000..26d33521 --- /dev/null +++ b/runtime/examples/android/VadJni/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/runtime/examples/android/VadJni/.idea/.name b/runtime/examples/android/VadJni/.idea/.name new file mode 100644 index 00000000..b5712d1e --- /dev/null +++ b/runtime/examples/android/VadJni/.idea/.name @@ -0,0 +1 @@ +VadJni \ No newline at end of file diff --git a/runtime/examples/android/VadJni/.idea/compiler.xml b/runtime/examples/android/VadJni/.idea/compiler.xml new file mode 100644 index 00000000..fb7f4a8a --- /dev/null +++ 
b/runtime/examples/android/VadJni/.idea/compiler.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/runtime/examples/android/VadJni/.idea/deploymentTargetDropDown.xml b/runtime/examples/android/VadJni/.idea/deploymentTargetDropDown.xml new file mode 100644 index 00000000..f26362be --- /dev/null +++ b/runtime/examples/android/VadJni/.idea/deploymentTargetDropDown.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/runtime/examples/android/VadJni/.idea/gradle.xml b/runtime/examples/android/VadJni/.idea/gradle.xml new file mode 100644 index 00000000..a2d7c213 --- /dev/null +++ b/runtime/examples/android/VadJni/.idea/gradle.xml @@ -0,0 +1,19 @@ + + + + + + + \ No newline at end of file diff --git a/runtime/examples/android/VadJni/.idea/misc.xml b/runtime/examples/android/VadJni/.idea/misc.xml new file mode 100644 index 00000000..bdd92780 --- /dev/null +++ b/runtime/examples/android/VadJni/.idea/misc.xml @@ -0,0 +1,10 @@ + + + + + + + + + \ No newline at end of file diff --git a/runtime/examples/android/VadJni/.idea/vcs.xml b/runtime/examples/android/VadJni/.idea/vcs.xml new file mode 100644 index 00000000..4fce1d86 --- /dev/null +++ b/runtime/examples/android/VadJni/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/.gitignore b/runtime/examples/android/VadJni/app/.gitignore new file mode 100644 index 00000000..44399f1d --- /dev/null +++ b/runtime/examples/android/VadJni/app/.gitignore @@ -0,0 +1,2 @@ +/build +/cache diff --git a/runtime/examples/android/VadJni/app/build.gradle b/runtime/examples/android/VadJni/app/build.gradle new file mode 100644 index 00000000..f2025a21 --- /dev/null +++ b/runtime/examples/android/VadJni/app/build.gradle @@ -0,0 +1,129 @@ +plugins { + id 'com.android.application' +} + +android { + namespace 'com.baidu.paddlespeech.vadjni' + compileSdk 33 + ndkVersion '23.1.7779620' + + defaultConfig { + applicationId 
"com.baidu.paddlespeech.vadjni" + minSdk 21 + targetSdk 33 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + externalNativeBuild { + cmake { + arguments '-DANDROID_PLATFORM=android-21', '-DANDROID_STL=c++_shared', "-DANDROID_TOOLCHAIN=clang" + abiFilters 'arm64-v8a' + cppFlags "-std=c++11" + } + } + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } + externalNativeBuild { + cmake { + path file('src/main/cpp/CMakeLists.txt') + version '3.22.1' + } + } + buildFeatures { + viewBinding true + } + sourceSets { + main { + jniLibs.srcDirs = ['libs'] + } + } +} + +dependencies { + // Dependency on local binaries + implementation fileTree(dir: 'libs', include: ['*.jar']) + // Dependency on a remote binary + implementation 'androidx.appcompat:appcompat:1.4.1' + implementation 'com.google.android.material:material:1.5.0' + implementation 'androidx.constraintlayout:constraintlayout:2.1.3' + testImplementation 'junit:junit:4.13.2' + androidTestImplementation 'androidx.test.ext:junit:1.1.3' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0' +} + +def CXX_LIB = [ +// [ +// 'src' : 'https://bj.bcebos.com/fastdeploy/dev/android/fastdeploy-android-with-text-0.0.0-shared.tgz', +// 'dest': 'libs', +// 'name': 'fastdeploy-android-latest-shared-dev' +// ] +] + +task downloadAndExtractLibs(type: DefaultTask) { + doFirst { + println "[INFO] Downloading and extracting fastdeploy android c++ lib ..." 
+ } + doLast { + String cachePath = "cache" + if (!file("${cachePath}").exists()) { + mkdir "${cachePath}" + } + + CXX_LIB.eachWithIndex { lib, index -> + + String[] libPaths = lib.src.split("/") + String sdkName = lib.name + String libName = libPaths[libPaths.length - 1] + libName = libName.substring(0, libName.indexOf("tgz") - 1) + String cacheName = cachePath + "/" + "${libName}.tgz" + + String libDir = lib.dest + "/" + libName + String sdkDir = lib.dest + "/" + sdkName + + boolean copyFiles = false + if (!file("${sdkDir}").exists()) { + // Download lib and rename to sdk name later. + if (!file("${cacheName}").exists()) { + println "[INFO] Downloading ${lib.src} -> ${cacheName}" + ant.get(src: lib.src, dest: file("${cacheName}")) + } + copyFiles = true + } + + if (copyFiles) { + println "[INFO] Taring ${cacheName} -> ${libDir}" + copy { from(tarTree("${cacheName}")) into("${lib.dest}") } + if (!libName.equals(sdkName)) { + if (file("${sdkDir}").exists()) { + delete("${sdkDir}") + println "[INFO] Remove old ${sdkDir}" + } + mkdir "${sdkDir}" + println "[INFO] Coping ${libDir} -> ${sdkDir}" + copy { from("${libDir}") into("${sdkDir}") } + delete("${libDir}") + println "[INFO] Removed ${libDir}" + println "[INFO] Update ${sdkDir} done!" + } + } else { + println "[INFO] ${sdkDir} already exists!" + println "[WARN] Please delete ${cacheName} and ${sdkDir} " + + "if you want to UPDATE ${sdkName} c++ lib. Then, rebuild this sdk." 
+ } + } + } +} + +preBuild.dependsOn downloadAndExtractLibs \ No newline at end of file diff --git a/speechx/speechx/kaldi/.gitkeep b/runtime/examples/android/VadJni/app/libs/.gitkeep similarity index 100% rename from speechx/speechx/kaldi/.gitkeep rename to runtime/examples/android/VadJni/app/libs/.gitkeep diff --git a/runtime/examples/android/VadJni/app/proguard-rules.pro b/runtime/examples/android/VadJni/app/proguard-rules.pro new file mode 100644 index 00000000..481bb434 --- /dev/null +++ b/runtime/examples/android/VadJni/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/androidTest/java/com/baidu/paddlespeech/vadjni/ExampleInstrumentedTest.java b/runtime/examples/android/VadJni/app/src/androidTest/java/com/baidu/paddlespeech/vadjni/ExampleInstrumentedTest.java new file mode 100644 index 00000000..5c02120b --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/androidTest/java/com/baidu/paddlespeech/vadjni/ExampleInstrumentedTest.java @@ -0,0 +1,26 @@ +package com.baidu.paddlespeech.vadjni; + +import android.content.Context; + +import androidx.test.platform.app.InstrumentationRegistry; +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import static org.junit.Assert.*; + +/** + * Instrumented test, which will execute on an Android device. + * + * @see Testing documentation + */ +@RunWith(AndroidJUnit4.class) +public class ExampleInstrumentedTest { + @Test + public void useAppContext() { + // Context of the app under test. 
+ Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext(); + assertEquals("com.baidu.paddlespeech.vadjni", appContext.getPackageName()); + } +} \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/main/AndroidManifest.xml b/runtime/examples/android/VadJni/app/src/main/AndroidManifest.xml new file mode 100644 index 00000000..d8076922 --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/AndroidManifest.xml @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + \ No newline at end of file diff --git a/speechx/speechx/third_party/CMakeLists.txt b/runtime/examples/android/VadJni/app/src/main/assets/.gitkeep similarity index 100% rename from speechx/speechx/third_party/CMakeLists.txt rename to runtime/examples/android/VadJni/app/src/main/assets/.gitkeep diff --git a/runtime/examples/android/VadJni/app/src/main/cpp/CMakeLists.txt b/runtime/examples/android/VadJni/app/src/main/cpp/CMakeLists.txt new file mode 100644 index 00000000..5eaa053b --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/cpp/CMakeLists.txt @@ -0,0 +1,59 @@ +# For more information about using CMake with Android Studio, read the +# documentation: https://d.android.com/studio/projects/add-native-code.html + +# Sets the minimum version of CMake required to build the native library. + +cmake_minimum_required(VERSION 3.22.1) + +# Declares and names the project. + +project("vadjni") + + +set(PPS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../libs/${ANDROID_ABI}) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +# Creates and names a library, sets it as either STATIC +# or SHARED, and provides the relative paths to its source code. +# You can define multiple libraries, and CMake builds them for you. +# Gradle automatically packages shared libraries with your APK. + +add_library( # Sets the name of the library. + vadjni + + # Sets the library as a shared library. + SHARED + + # Provides a relative path to your source file(s). 
+ native-lib.cpp) + +# Searches for a specified prebuilt library and stores the path as a +# variable. Because CMake includes system libraries in the search path by +# default, you only need to specify the name of the public NDK library +# you want to add. CMake verifies that the library exists before +# completing its build. + +find_library( # Sets the name of the path variable. + log-lib + + # Specifies the name of the NDK library that + # you want CMake to locate. + log) + +# Specifies libraries CMake should link to your target library. You +# can link multiple libraries, such as libraries you define in this +# build script, prebuilt third-party libraries, or system libraries. + +message(STATUS "PPS_DIR=${PPS_DIR}") +target_link_libraries( # Specifies the target library. + vadjni + ${PPS_DIR}/libfastdeploy.so + ${PPS_DIR}/libonnxruntime.so + ${PPS_DIR}/libgflags_nothreads.a + ${PPS_DIR}/libbase.a + ${PPS_DIR}/libpps_vad.a + ${PPS_DIR}/libpps_vad_interface.a + # Links the target library to the log library + # included in the NDK. 
+ ${log-lib}) \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/main/cpp/native-lib.cpp b/runtime/examples/android/VadJni/app/src/main/cpp/native-lib.cpp new file mode 100644 index 00000000..e80ac2e4 --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/cpp/native-lib.cpp @@ -0,0 +1,57 @@ + +#include +#include "vad_interface.h" +#include + +extern "C" +JNIEXPORT jstring JNICALL +Java_com_baidu_paddlespeech_vadjni_MainActivity_stringFromJNI( + JNIEnv* env, + jobject /* this */) { + std::string hello = "Hello from C++"; + return env->NewStringUTF(hello.c_str()); +} + +extern "C" +JNIEXPORT jlong JNICALL +Java_com_baidu_paddlespeech_vadjni_MainActivity_createInstance( + JNIEnv* env, + jobject thiz, + jstring conf_path){ + const char* path = env->GetStringUTFChars(conf_path, JNI_FALSE); + PPSHandle_t handle = PPSVadCreateInstance(path); + + return (jlong)(handle); + return 0; +} + + +extern "C" +JNIEXPORT jint JNICALL +Java_com_baidu_paddlespeech_vadjni_MainActivity_destroyInstance(JNIEnv *env, jobject thiz, + jlong instance) { + PPSHandle_t handle = (PPSHandle_t)(instance); + return (jint)PPSVadDestroyInstance(handle); +} +extern "C" +JNIEXPORT jint JNICALL +Java_com_baidu_paddlespeech_vadjni_MainActivity_reset(JNIEnv *env, jobject thiz, jlong instance) { + PPSHandle_t handle = (PPSHandle_t)(instance); + return (jint)PPSVadReset(handle); +} +extern "C" +JNIEXPORT jint JNICALL +Java_com_baidu_paddlespeech_vadjni_MainActivity_chunkSizeSamples(JNIEnv *env, jobject thiz, + jlong instance) { + PPSHandle_t handle = (PPSHandle_t)(instance); + return (jint)PPSVadChunkSizeSamples(handle); +} +extern "C" +JNIEXPORT jint JNICALL +Java_com_baidu_paddlespeech_vadjni_MainActivity_feedForward(JNIEnv *env, jobject thiz, + jlong instance, jfloatArray chunk) { + PPSHandle_t handle = (PPSHandle_t)(instance); + jsize num_elms = env->GetArrayLength(chunk); + jfloat* chunk_ptr = env->GetFloatArrayElements(chunk, JNI_FALSE); + return 
(jint)PPSVadFeedForward(handle, (float*)chunk_ptr, (int)num_elms); +} \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/main/cpp/vad_interface.h b/runtime/examples/android/VadJni/app/src/main/cpp/vad_interface.h new file mode 100644 index 00000000..5d7ca709 --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/cpp/vad_interface.h @@ -0,0 +1,46 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void* PPSHandle_t; + +typedef enum { + PPS_VAD_ILLEGAL = 0, // error + PPS_VAD_SIL, // silence + PPS_VAD_START, // start speech + PPS_VAD_SPEECH, // in speech + PPS_VAD_END, // end speech + PPS_VAD_NUMSTATES, // number of states +} PPSVadState_t; + +PPSHandle_t PPSVadCreateInstance(const char* conf_path); + +int PPSVadDestroyInstance(PPSHandle_t instance); + +int PPSVadReset(PPSHandle_t instance); + +int PPSVadChunkSizeSamples(PPSHandle_t instance); + +PPSVadState_t PPSVadFeedForward(PPSHandle_t instance, + float* chunk, + int num_element); + +#ifdef __cplusplus +} +#endif // __cplusplus \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/main/java/com/baidu/paddlespeech/vadjni/MainActivity.java b/runtime/examples/android/VadJni/app/src/main/java/com/baidu/paddlespeech/vadjni/MainActivity.java new file mode 100644 index 00000000..3b463280 --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/java/com/baidu/paddlespeech/vadjni/MainActivity.java @@ -0,0 +1,50 @@ +package com.baidu.paddlespeech.vadjni; + +import androidx.appcompat.app.AppCompatActivity; + +import android.os.Bundle; +import android.widget.Button; +import android.widget.TextView; + +import com.baidu.paddlespeech.vadjni.databinding.ActivityMainBinding; + +public class MainActivity extends AppCompatActivity { + + // Used to load the 'vadjni' library on application startup. 
+ static { + System.loadLibrary("vadjni"); + } + + private ActivityMainBinding binding; + private long instance; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + + binding = ActivityMainBinding.inflate(getLayoutInflater()); + setContentView(binding.getRoot()); + + // Example of a call to a native method + TextView tv = binding.sampleText; + tv.setText(stringFromJNI()); + + Button lw = binding.loadWav; + } + + /** + * A native method that is implemented by the 'vadjni' native library, + * which is packaged with this application. + */ + public native String stringFromJNI(); + + public static native long createInstance(String config_path); + + public static native int destroyInstance(long instance); + + public static native int reset(long instance); + + public static native int chunkSizeSamples(long instance); + + public static native int feedForward(long instance, float[] chunk); +} \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/main/res/drawable-v24/ic_launcher_foreground.xml b/runtime/examples/android/VadJni/app/src/main/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 00000000..2b068d11 --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,30 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/main/res/drawable/ic_launcher_background.xml b/runtime/examples/android/VadJni/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 00000000..07d5da9c --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/runtime/examples/android/VadJni/app/src/main/res/layout/activity_main.xml b/runtime/examples/android/VadJni/app/src/main/res/layout/activity_main.xml new file mode 
100644 index 00000000..c9938516 --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,28 @@ + + + + + +