diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index 3b8c8788..e54ba3a4 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -26,7 +26,6 @@ option(TEST_DEBUG "option for debug" OFF) # Include third party ############################################################################### # #example for include third party -# FetchContent_Declare() # # FetchContent_MakeAvailable was not added until CMake 3.14 # FetchContent_MakeAvailable() # include_directories() @@ -50,20 +49,25 @@ include_directories(${absl_SOURCE_DIR}) #) #FetchContent_MakeAvailable(libsndfile) -# todo boost build -#include(FetchContent) -#FetchContent_Declare( -# Boost -# URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.zip -# URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a -#) -#FetchContent_MakeAvailable(Boost) -#include_directories(${Boost_SOURCE_DIR}) - +#boost +set(boost_SOURCE_DIR ${fc_patch}/boost-src) +set(boost_PREFIX_DIR ${fc_patch}/boost-subbuild/boost-prefix) +include(ExternalProject) +ExternalProject_Add( + boost + URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz + URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a + SOURCE_DIR ${boost_SOURCE_DIR} + PREFIX ${boost_PREFIX_DIR} + BUILD_IN_SOURCE 1 + CONFIGURE_COMMAND ./bootstrap.sh + BUILD_COMMAND ./b2 + INSTALL_COMMAND "" +) +link_directories(${boost_SOURCE_DIR}/stage/lib) +include_directories(${boost_SOURCE_DIR}) -set(BOOST_ROOT ${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0) -include_directories(${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0) -link_directories(${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0/stage/lib) +set(BOOST_ROOT ${boost_SOURCE_DIR}) include(FetchContent) FetchContent_Declare( kenlm @@ -71,9 +75,10 @@ FetchContent_Declare( GIT_TAG "df2d717e95183f79a90b2fa6e4307083a351ca6a" ) FetchContent_MakeAvailable(kenlm) -add_dependencies(kenlm Boost) +add_dependencies(kenlm boost) include_directories(${kenlm_SOURCE_DIR}) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb") # gflags FetchContent_Declare( gflags @@ -106,19 +111,34 @@ set(openfst_PREFIX_DIR ${fc_patch}/openfst-subbuild/openfst-populate-prefix) ExternalProject_Add(openfst URL https://github.com/mjansche/openfst/archive/refs/tags/1.7.2.zip URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6 +# PREFIX ${openfst_PREFIX_DIR} SOURCE_DIR ${openfst_SOURCE_DIR} BINARY_DIR ${openfst_BINARY_DIR} CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR} "CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}" "LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}" "LIBS=-lgflags_nothreads -lglog -lpthread" + COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR} BUILD_COMMAND make -j 4 ) add_dependencies(openfst gflags glog) link_directories(${openfst_PREFIX_DIR}/lib) include_directories(${openfst_PREFIX_DIR}/include) -set(PADDLE_LIB ${fc_patch}/paddle-lib/paddle_inference) +# paddle lib +set(paddle_SOURCE_DIR ${fc_patch}/paddle-lib) +set(paddle_PREFIX_DIR ${fc_patch}/paddle-lib-prefix) +ExternalProject_Add(paddle + URL https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/CPU/gcc8.2_avx_mkl/paddle_inference.tgz + URL_HASH SHA256=7c6399e778c6554a929b5a39ba2175e702e115145e8fa690d2af974101d98873 + PREFIX ${paddle_PREFIX_DIR} + SOURCE_DIR ${paddle_SOURCE_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" +) + +set(PADDLE_LIB ${fc_patch}/paddle-lib) include_directories("${PADDLE_LIB}/paddle/include") set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/") include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include") @@ -133,6 +153,23 @@ link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib") link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib") link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib") link_directories("${PADDLE_LIB}/paddle/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mklml/lib") + +##paddle with mkl +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") +set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml") +include_directories("${MATH_LIB_PATH}/include") +set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} + ${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) +set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn") +include_directories("${MKLDNN_PATH}/include") +set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) +set(EXTERNAL_LIB "-lrt -ldl -lpthread") +set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX}) +set(DEPS ${DEPS} + ${MATH_LIB} ${MKLDNN_LIB} + glog gflags protobuf xxhash cryptopp + ${EXTERNAL_LIB}) add_subdirectory(speechx) diff --git a/speechx/patch/CPPLINT.cfg b/speechx/patch/CPPLINT.cfg new file mode 100644 index 00000000..51ff339c --- /dev/null +++ b/speechx/patch/CPPLINT.cfg @@ -0,0 +1 @@ +exclude_files=.* diff --git a/speechx/patch/openfst/src/include/fst/flags.h b/speechx/patch/openfst/src/include/fst/flags.h new file mode 100644 index 00000000..b5ec8ff7 --- /dev/null +++ b/speechx/patch/openfst/src/include/fst/flags.h @@ -0,0 +1,228 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Google-style flag handling declarations and inline definitions. + +#ifndef FST_LIB_FLAGS_H_ +#define FST_LIB_FLAGS_H_ + +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include "gflags/gflags.h" +#include "glog/logging.h" + +using std::string; + +// FLAGS USAGE: +// +// Definition example: +// +// DEFINE_int32(length, 0, "length"); +// +// This defines variable FLAGS_length, initialized to 0. +// +// Declaration example: +// +// DECLARE_int32(length); +// +// SET_FLAGS() can be used to set flags from the command line +// using, for example, '--length=2'. +// +// ShowUsage() can be used to print out command and flag usage. + +// #define DECLARE_bool(name) extern bool FLAGS_ ## name +// #define DECLARE_string(name) extern string FLAGS_ ## name +// #define DECLARE_int32(name) extern int32 FLAGS_ ## name +// #define DECLARE_int64(name) extern int64 FLAGS_ ## name +// #define DECLARE_double(name) extern double FLAGS_ ## name + +template +struct FlagDescription { + FlagDescription(T *addr, const char *doc, const char *type, + const char *file, const T val) + : address(addr), + doc_string(doc), + type_name(type), + file_name(file), + default_value(val) {} + + T *address; + const char *doc_string; + const char *type_name; + const char *file_name; + const T default_value; +}; + +template +class FlagRegister { + public: + static FlagRegister *GetRegister() { + static auto reg = new FlagRegister; + return reg; + } + + const FlagDescription &GetFlagDescription(const string &name) const { + fst::MutexLock l(&flag_lock_); + auto it = flag_table_.find(name); + return it != flag_table_.end() ? it->second : 0; + } + + void SetDescription(const string &name, + const FlagDescription &desc) { + fst::MutexLock l(&flag_lock_); + flag_table_.insert(make_pair(name, desc)); + } + + bool SetFlag(const string &val, bool *address) const { + if (val == "true" || val == "1" || val.empty()) { + *address = true; + return true; + } else if (val == "false" || val == "0") { + *address = false; + return true; + } + else { + return false; + } + } + + bool SetFlag(const string &val, string *address) const { + *address = val; + return true; + } + + bool SetFlag(const string &val, int32 *address) const { + char *p = 0; + *address = strtol(val.c_str(), &p, 0); + return !val.empty() && *p == '\0'; + } + + bool SetFlag(const string &val, int64 *address) const { + char *p = 0; + *address = strtoll(val.c_str(), &p, 0); + return !val.empty() && *p == '\0'; + } + + bool SetFlag(const string &val, double *address) const { + char *p = 0; + *address = strtod(val.c_str(), &p); + return !val.empty() && *p == '\0'; + } + + bool SetFlag(const string &arg, const string &val) const { + for (typename std::map< string, FlagDescription >::const_iterator it = + flag_table_.begin(); + it != flag_table_.end(); + ++it) { + const string &name = it->first; + const FlagDescription &desc = it->second; + if (arg == name) + return SetFlag(val, desc.address); + } + return false; + } + + void GetUsage(std::set> *usage_set) const { + for (auto it = flag_table_.begin(); it != flag_table_.end(); ++it) { + const string &name = it->first; + const FlagDescription &desc = it->second; + string usage = " --" + name; + usage += ": type = "; + usage += desc.type_name; + usage += ", default = "; + usage += GetDefault(desc.default_value) + "\n "; + usage += desc.doc_string; + usage_set->insert(make_pair(desc.file_name, usage)); + } + } + + private: + string GetDefault(bool default_value) const { + return default_value ? "true" : "false"; + } + + string GetDefault(const string &default_value) const { + return "\"" + default_value + "\""; + } + + template + string GetDefault(const V &default_value) const { + std::ostringstream strm; + strm << default_value; + return strm.str(); + } + + mutable fst::Mutex flag_lock_; // Multithreading lock. + std::map> flag_table_; +}; + +template +class FlagRegisterer { + public: + FlagRegisterer(const string &name, const FlagDescription &desc) { + auto registr = FlagRegister::GetRegister(); + registr->SetDescription(name, desc); + } + + private: + FlagRegisterer(const FlagRegisterer &) = delete; + FlagRegisterer &operator=(const FlagRegisterer &) = delete; +}; + + +#define DEFINE_VAR(type, name, value, doc) \ + type FLAGS_ ## name = value; \ + static FlagRegisterer \ + name ## _flags_registerer(#name, FlagDescription(&FLAGS_ ## name, \ + doc, \ + #type, \ + __FILE__, \ + value)) + +// #define DEFINE_bool(name, value, doc) DEFINE_VAR(bool, name, value, doc) +// #define DEFINE_string(name, value, doc) \ +// DEFINE_VAR(string, name, value, doc) +// #define DEFINE_int32(name, value, doc) DEFINE_VAR(int32, name, value, doc) +// #define DEFINE_int64(name, value, doc) DEFINE_VAR(int64, name, value, doc) +// #define DEFINE_double(name, value, doc) DEFINE_VAR(double, name, value, doc) + + +// Temporary directory. +DECLARE_string(tmpdir); + +void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags, + const char *src = ""); + +#define SET_FLAGS(usage, argc, argv, rmflags) \ +gflags::ParseCommandLineFlags(argc, argv, true) +// SetFlags(usage, argc, argv, rmflags, __FILE__) + +// Deprecated; for backward compatibility. +inline void InitFst(const char *usage, int *argc, char ***argv, bool rmflags) { + return SetFlags(usage, argc, argv, rmflags); +} + +void ShowUsage(bool long_usage = true); + +#endif // FST_LIB_FLAGS_H_ diff --git a/speechx/patch/openfst/src/include/fst/log.h b/speechx/patch/openfst/src/include/fst/log.h new file mode 100644 index 00000000..bf041c58 --- /dev/null +++ b/speechx/patch/openfst/src/include/fst/log.h @@ -0,0 +1,82 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Google-style logging declarations and inline definitions. + +#ifndef FST_LIB_LOG_H_ +#define FST_LIB_LOG_H_ + +#include +#include +#include + +#include +#include + +using std::string; + +DECLARE_int32(v); + +class LogMessage { + public: + LogMessage(const string &type) : fatal_(type == "FATAL") { + std::cerr << type << ": "; + } + ~LogMessage() { + std::cerr << std::endl; + if(fatal_) + exit(1); + } + std::ostream &stream() { return std::cerr; } + + private: + bool fatal_; +}; + +// #define LOG(type) LogMessage(#type).stream() +// #define VLOG(level) if ((level) <= FLAGS_v) LOG(INFO) + +// Checks +inline void FstCheck(bool x, const char* expr, + const char *file, int line) { + if (!x) { + LOG(FATAL) << "Check failed: \"" << expr + << "\" file: " << file + << " line: " << line; + } +} + +// #define CHECK(x) FstCheck(static_cast(x), #x, __FILE__, __LINE__) +// #define CHECK_EQ(x, y) CHECK((x) == (y)) +// #define CHECK_LT(x, y) CHECK((x) < (y)) +// #define CHECK_GT(x, y) CHECK((x) > (y)) +// #define CHECK_LE(x, y) CHECK((x) <= (y)) +// #define CHECK_GE(x, y) CHECK((x) >= (y)) +// #define CHECK_NE(x, y) CHECK((x) != (y)) + +// Debug checks +// #define DCHECK(x) assert(x) +// #define DCHECK_EQ(x, y) DCHECK((x) == (y)) +// #define DCHECK_LT(x, y) DCHECK((x) < (y)) +// #define DCHECK_GT(x, y) DCHECK((x) > (y)) +// #define DCHECK_LE(x, y) DCHECK((x) <= (y)) +// #define DCHECK_GE(x, y) DCHECK((x) >= (y)) +// #define DCHECK_NE(x, y) DCHECK((x) != (y)) + + +// Ports +#define ATTRIBUTE_DEPRECATED __attribute__((deprecated)) + +#endif // FST_LIB_LOG_H_ diff --git a/speechx/patch/openfst/src/lib/flags.cc b/speechx/patch/openfst/src/lib/flags.cc new file mode 100644 index 00000000..95f7e2e9 --- /dev/null +++ b/speechx/patch/openfst/src/lib/flags.cc @@ -0,0 +1,166 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Google-style flag handling definitions. + +#include + +#if _MSC_VER +#include +#include +#endif + +#include +#include + +static const char *private_tmpdir = getenv("TMPDIR"); + +// DEFINE_int32(v, 0, "verbosity level"); +// DEFINE_bool(help, false, "show usage information"); +// DEFINE_bool(helpshort, false, "show brief usage information"); +#ifndef _MSC_VER +DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : "/tmp", + "temporary directory"); +#else +DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : getenv("TEMP"), + "temporary directory"); +#endif // !_MSC_VER + +using namespace std; + +static string flag_usage; +static string prog_src; + +// Sets prog_src to src. +static void SetProgSrc(const char *src) { + prog_src = src; +#if _MSC_VER + // This common code is invoked by all FST binaries, and only by them. Switch + // stdin and stdout into "binary" mode, so that 0x0A won't be translated into + // a 0x0D 0x0A byte pair in a pipe or a shell redirect. Other streams are + // already using ios::binary where binary files are read or written. + // Kudos to @daanzu for the suggested fix. + // https://github.com/kkm000/openfst/issues/20 + // https://github.com/kkm000/openfst/pull/23 + // https://github.com/kkm000/openfst/pull/32 + _setmode(_fileno(stdin), O_BINARY); + _setmode(_fileno(stdout), O_BINARY); +#endif + // Remove "-main" in src filename. Flags are defined in fstx.cc but SetFlags() + // is called in fstx-main.cc, which results in a filename mismatch in + // ShowUsageRestrict() below. + static constexpr char kMainSuffix[] = "-main.cc"; + const int prefix_length = prog_src.size() - strlen(kMainSuffix); + if (prefix_length > 0 && prog_src.substr(prefix_length) == kMainSuffix) { + prog_src.erase(prefix_length, strlen("-main")); + } +} + +void SetFlags(const char *usage, int *argc, char ***argv, + bool remove_flags, const char *src) { + flag_usage = usage; + SetProgSrc(src); + + int index = 1; + for (; index < *argc; ++index) { + string argval = (*argv)[index]; + if (argval[0] != '-' || argval == "-") break; + while (argval[0] == '-') argval = argval.substr(1); // Removes initial '-'. + string arg = argval; + string val = ""; + // Splits argval (arg=val) into arg and val. + auto pos = argval.find("="); + if (pos != string::npos) { + arg = argval.substr(0, pos); + val = argval.substr(pos + 1); + } + auto bool_register = FlagRegister::GetRegister(); + if (bool_register->SetFlag(arg, val)) + continue; + auto string_register = FlagRegister::GetRegister(); + if (string_register->SetFlag(arg, val)) + continue; + auto int32_register = FlagRegister::GetRegister(); + if (int32_register->SetFlag(arg, val)) + continue; + auto int64_register = FlagRegister::GetRegister(); + if (int64_register->SetFlag(arg, val)) + continue; + auto double_register = FlagRegister::GetRegister(); + if (double_register->SetFlag(arg, val)) + continue; + LOG(FATAL) << "SetFlags: Bad option: " << (*argv)[index]; + } + if (remove_flags) { + for (auto i = 0; i < *argc - index; ++i) { + (*argv)[i + 1] = (*argv)[i + index]; + } + *argc -= index - 1; + } + // if (FLAGS_help) { + // ShowUsage(true); + // exit(1); + // } + // if (FLAGS_helpshort) { + // ShowUsage(false); + // exit(1); + // } +} + +// If flag is defined in file 'src' and 'in_src' true or is not +// defined in file 'src' and 'in_src' is false, then print usage. +static void +ShowUsageRestrict(const std::set> &usage_set, + const string &src, bool in_src, bool show_file) { + string old_file; + bool file_out = false; + bool usage_out = false; + for (const auto &pair : usage_set) { + const auto &file = pair.first; + const auto &usage = pair.second; + bool match = file == src; + if ((match && !in_src) || (!match && in_src)) continue; + if (file != old_file) { + if (show_file) { + if (file_out) cout << "\n"; + cout << "Flags from: " << file << "\n"; + file_out = true; + } + old_file = file; + } + cout << usage << "\n"; + usage_out = true; + } + if (usage_out) cout << "\n"; +} + +void ShowUsage(bool long_usage) { + std::set> usage_set; + cout << flag_usage << "\n"; + auto bool_register = FlagRegister::GetRegister(); + bool_register->GetUsage(&usage_set); + auto string_register = FlagRegister::GetRegister(); + string_register->GetUsage(&usage_set); + auto int32_register = FlagRegister::GetRegister(); + int32_register->GetUsage(&usage_set); + auto int64_register = FlagRegister::GetRegister(); + int64_register->GetUsage(&usage_set); + auto double_register = FlagRegister::GetRegister(); + double_register->GetUsage(&usage_set); + if (!prog_src.empty()) { + cout << "PROGRAM FLAGS:\n\n"; + ShowUsageRestrict(usage_set, prog_src, true, false); + } + if (!long_usage) return; + if (!prog_src.empty()) cout << "LIBRARY FLAGS:\n\n"; + ShowUsageRestrict(usage_set, prog_src, false, true); +} diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt index bdf82146..55a782ac 100644 --- a/speechx/speechx/CMakeLists.txt +++ b/speechx/speechx/CMakeLists.txt @@ -2,10 +2,6 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) project(speechx LANGUAGES CXX) -link_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/openblas) - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14") - include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/kaldi @@ -37,10 +33,13 @@ ${CMAKE_CURRENT_SOURCE_DIR}/decoder add_subdirectory(decoder) add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc) -target_link_libraries(mfcc-test kaldi-mfcc) +target_link_libraries(mfcc-test kaldi-mfcc ${MATH_LIB}) add_executable(linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc) target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog) -#add_executable(offline_decoder_main codelab/decoder_test/offline_decoder_main.cc) -#target_link_libraries(offline_decoder_main nnet decoder gflags glog) +add_executable(offline_decoder_main codelab/decoder_test/offline_decoder_main.cc) +target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) + +add_executable(model_test codelab/nnet_test/model_test.cc) +target_link_libraries(model_test PUBLIC nnet gflags ${DEPS}) diff --git a/speechx/speechx/base/flags.h b/speechx/speechx/base/flags.h index 2808fac3..41df0d45 100644 --- a/speechx/speechx/base/flags.h +++ b/speechx/speechx/base/flags.h @@ -14,4 +14,4 @@ #pragma once -#include "gflags/gflags.h" +#include "fst/flags.h" diff --git a/speechx/speechx/base/log.h b/speechx/speechx/base/log.h index d1b7b169..c613b98c 100644 --- a/speechx/speechx/base/log.h +++ b/speechx/speechx/base/log.h @@ -14,4 +14,4 @@ #pragma once -#include "glog/logging.h" +#include "fst/log.h" diff --git a/speechx/speechx/codelab/decoder_test/offline_decoder_main.cc b/speechx/speechx/codelab/decoder_test/offline_decoder_main.cc index 1d7b09df..138f5eeb 100644 --- a/speechx/speechx/codelab/decoder_test/offline_decoder_main.cc +++ b/speechx/speechx/codelab/decoder_test/offline_decoder_main.cc @@ -4,16 +4,20 @@ #include "kaldi/util/table-types.h" #include "base/log.h" #include "base/flags.h" +#include "nnet/paddle_nnet.h" +#include "nnet/decodable.h" DEFINE_string(feature_respecifier, "", "test nnet prob"); using kaldi::BaseFloat; +using kaldi::Matrix; +using std::vector; -void SplitFeature(kaldi::Matrix feature, - int32 chunk_size, - std::vector> feature_chunks) { +//void SplitFeature(kaldi::Matrix feature, +// int32 chunk_size, +// std::vector* feature_chunks) { -} +//} int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -24,31 +28,26 @@ int main(int argc, char* argv[]) { // test nnet_output --> decoder result int32 num_done = 0, num_err = 0; - CTCBeamSearchOptions opts; - CTCBeamSearch decoder(opts); + ppspeech::CTCBeamSearchOptions opts; + ppspeech::CTCBeamSearch decoder(opts); - ModelOptions model_opts; - std::shared_ptr nnet(new PaddleNnet(model_opts)); + ppspeech::ModelOptions model_opts; + std::shared_ptr nnet(new ppspeech::PaddleNnet(model_opts)); - Decodable decodable(); - decodable.SetNnet(nnet); + std::shared_ptr decodable(new ppspeech::Decodable(nnet)); - int32 chunk_size = 0; + //int32 chunk_size = 35; + decoder.InitDecoder(); for (; !feature_reader.Done(); feature_reader.Next()) { string utt = feature_reader.Key(); const kaldi::Matrix feature = feature_reader.Value(); - vector> feature_chunks; - SplitFeature(feature, chunk_size, &feature_chunks); - for (auto feature_chunk : feature_chunks) { - decodable.FeedFeatures(feature_chunk); - decoder.InitDecoder(); - decoder.AdvanceDecode(decodable, chunk_size); - } - decodable.InputFinished(); + decodable->FeedFeatures(feature); + decoder.AdvanceDecode(decodable, 8); + decodable->InputFinished(); std::string result; result = decoder.GetFinalBestPath(); KALDI_LOG << " the result of " << utt << " is " << result; - decodable.Reset(); + decodable->Reset(); ++num_done; } diff --git a/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc b/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc index 00162abe..38a1a0b3 100644 --- a/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc +++ b/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc @@ -11,6 +11,10 @@ DEFINE_string(wav_rspecifier, "", "test wav path"); DEFINE_string(feature_wspecifier, "", "test wav ark"); +std::vector mean_{-13730251.531853663, -12982852.199316509, -13673844.299583456, -13089406.559646806, -12673095.524938712, -12823859.223276224, -13590267.158903603, -14257618.467152044, -14374605.116185192, -14490009.21822485, -14849827.158924166, -15354435.470563512, -15834149.206532761, -16172971.985514281, -16348740.496746974, -16423536.699409386, -16556246.263649225, -16744088.772748645, -16916184.08510357, -17054034.840031497, -17165612.509455364, -17255955.470915023, -17322572.527648456, -17408943.862033736, -17521554.799865916, -17620623.254924215, -17699792.395918526, -17723364.411134344, -17741483.4433254, -17747426.888704527, -17733315.928209435, -17748780.160905756, -17808336.883775543, -17895918.671983004, -18009812.59173023, -18098188.66548325, -18195798.958462656, -18293617.62980999, -18397432.92077201, -18505834.787318766, -18585451.8100908, -18652438.235649142, -18700960.306275308, -18734944.58792185, -18737426.313365128, -18735347.165987637, -18738813.444170244, -18737086.848890636, -18731576.2474336, -18717405.44095871, -18703089.25545657, -18691014.546456724, -18692460.568905357, -18702119.628629155, -18727710.621126678, -18761582.72034647, -18806745.835547544, -18850674.8692112, -18884431.510951452, -18919999.992506847, -18939303.799078144, -18952946.273760635, -18980289.22996379, -19011610.17803294, -19040948.61805145, -19061021.429847397, -19112055.53768819, -19149667.414264943, -19201127.05091321, -19270250.82564605, -19334606.883057203, -19390513.336589377, -19444176.259208687, -19502755.000038862, -19544333.014549147, -19612668.183176614, -19681902.19006569, -19771969.951249883, -19873329.723376893, -19996752.59235844, -20110031.131400537, -20231658.612529557, -20319378.894054495, -20378534.45718066, -20413332.089584175, -20438147.844177883, -20443710.248040095, -20465457.02238927, -20488610.969337028, -20516295.16424432, -20541423.795738827, -20553192.874953747, -20573605.50701977, -20577871.61936797, -20571807.008916274, -20556242.38912231, -20542199.30819195, -20521239.063551214, -20519150.80004532, -20527204.80248933, -20536933.769257784, -20543470.522332076, -20549700.089992985, -20551525.24958494, -20554873.406493705, -20564277.65794227, -20572211.740052115, -20574305.69550465, -20575494.450104576, -20567092.577932164, -20549302.929608088, -20545445.11878376, -20546625.326603737, -20549190.03499401, -20554824.947828256, -20568341.378989458, -20577582.331383612, -20577980.519402675, -20566603.03458152, -20560131.592262644, -20552166.469060015, -20549063.06763577, -20544490.562339947, -20539817.82346569, -20528747.715731595, -20518026.24576161, -20510977.844974525, -20506874.36087992, -20506731.11977665, -20510482.133420516, -20507760.92101862, -20494644.834457114, -20480107.89304893, -20461312.091867123, -20442941.75080173, -20426123.02834838, -20424607.675283, -20426810.369107097, -20434024.50097819, -20437404.75544205, -20447688.63916367, -20460893.335563846, -20482922.735127095, -20503610.119434915, -20527062.76448319, -20557830.035128627, -20593274.72068722, -20632528.452965066, -20673637.471334763, -20733106.97143075, -20842921.0447562, -21054357.83621519, -21416569.534189366, -21978460.272811692, -22753170.052172784, -23671344.10563395, -24613499.293358143, -25406477.12230188, -25884377.82156489, -26049040.62791664, -26996879.104431007}; +std::vector variance_{213747175.10846674, 188395815.34302503, 212706429.10966414, 199109025.81461075, 189235901.23864496, 194901336.53253657, 217481594.29306737, 238689869.12327808, 243977501.24115244, 248479623.6431067, 259766741.47116545, 275516766.7790273, 291271202.3691234, 302693239.8220509, 308627358.3997694, 311143911.38788426, 315446105.07731867, 321705430.9341829, 327458907.4659941, 332245072.43223983, 336251717.5935284, 339694069.7639722, 342188204.4322228, 345587110.31313115, 349903086.2875232, 353660214.20643026, 356700344.5270885, 357665362.3529641, 358493352.05658793, 358857951.620328, 358375239.52774596, 358899733.6342954, 361051818.3511561, 364361716.05025816, 368750322.3771452, 372047800.6462831, 375655861.1349018, 379358519.1980013, 383327605.3935181, 387458599.282341, 390434692.3406868, 392994486.35057056, 394874418.04603153, 396230525.79763395, 396365592.0414835, 396334819.8242737, 396488353.19250053, 396438877.00744957, 396197980.4459586, 395590921.6672991, 395001107.62072515, 394528291.7318225, 394593110.424006, 395018405.59353715, 396110577.5415993, 397506704.0371068, 399400197.4657644, 401243568.2468382, 402687134.7805103, 404136047.2872507, 404883170.001883, 405522253.219517, 406660365.3626476, 407919346.0991902, 409045348.5384909, 409759588.7889818, 411974821.8564483, 413489718.78201455, 415535392.56684107, 418466481.97674364, 421104678.35678065, 423405392.5200779, 425550570.40798235, 427929423.9579701, 429585274.253478, 432368493.55181056, 435193587.13513297, 438886855.20476013, 443058876.8633751, 448181232.5093362, 452883835.6332396, 458056721.77926534, 461816531.22735566, 464363620.1970998, 465886343.5057493, 466928872.0651, 467180536.42647296, 468111848.70714295, 469138695.3071312, 470378429.6930793, 471517958.7132626, 472109050.4262365, 473087417.0177867, 473381322.04648733, 473220195.85483915, 472666071.8998819, 472124669.87879956, 471298571.411737, 471251033.2902761, 471672676.43128747, 472177147.2193172, 472572361.7711908, 472968783.7751127, 473156295.4164052, 473398034.82676554, 473897703.5203811, 474328271.33112127, 474452670.98002136, 474549003.99284613, 474252887.13567275, 473557462.909069, 473483385.85193115, 473609738.04855174, 473746944.82085115, 474016729.91696435, 474617321.94138587, 475045097.237122, 475125402.586558, 474664112.9824912, 474426247.5800283, 474104075.42796475, 473978219.7273978, 473773171.7798875, 473578534.69508696, 473102924.16904145, 472651240.5232615, 472374383.1810912, 472209479.6956096, 472202298.8921673, 472370090.76781124, 472220933.99374026, 471625467.37106377, 470994646.51883453, 470182428.9637543, 469348211.5939578, 468570387.4467277, 468540442.7225135, 468672018.90414184, 468994346.9533251, 469138757.58201426, 469553915.95710236, 470134523.38582784, 471082421.62055486, 471962316.51804745, 472939745.1708408, 474250621.5944825, 475773933.43199486, 477465399.71087736, 479218782.61382693, 481752299.7930922, 486608947.8984568, 496119403.2067917, 512730085.5704984, 539048915.2641417, 576285298.3548826, 621610270.2240586, 669308196.4436442, 710656993.5957186, 736344437.3725077, 745481288.0241544, 801121432.9925804}; +int count_ = 912592; + int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); @@ -21,20 +25,62 @@ int main(int argc, char* argv[]) { // test feature linear_spectorgram: wave --> decibel_normalizer --> hanning window -->linear_spectrogram --> cmvn int32 num_done = 0, num_err = 0; ppspeech::LinearSpectrogramOptions opt; + opt.frame_opts.frame_length_ms = 20; + opt.frame_opts.frame_shift_ms = 10; ppspeech::DecibelNormalizerOptions db_norm_opt; std::unique_ptr base_feature_extractor( new ppspeech::DecibelNormalizer(db_norm_opt)); ppspeech::LinearSpectrogram linear_spectrogram(opt, std::move(base_feature_extractor)); + float streaming_chunk = 0.36; + int sample_rate = 16000; + int chunk_sample_size = streaming_chunk * sample_rate; + + LOG(INFO) << mean_.size(); + for (size_t i = 0; i < mean_.size(); i++) { + mean_[i] /= count_; + variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i]; + if (variance_[i] < 1.0e-20) { + variance_[i] = 1.0e-20; + } + variance_[i] = 1.0 / std::sqrt(variance_[i]); + } + for (; !wav_reader.Done(); wav_reader.Next()) { std::string utt = wav_reader.Key(); const kaldi::WaveData &wave_data = wav_reader.Value(); int32 this_channel = 0; kaldi::SubVector waveform(wave_data.Data(), this_channel); - kaldi::Matrix features; - linear_spectrogram.AcceptWaveform(waveform); - linear_spectrogram.ReadFeats(&features); + int tot_samples = waveform.Dim(); + int sample_offset = 0; + std::vector> feats; + int feature_rows = 0; + while (sample_offset < tot_samples) { + int cur_chunk_size = std::min(chunk_sample_size, tot_samples - sample_offset); + kaldi::Vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk(i) = waveform(sample_offset + i); + } + kaldi::Matrix features; + linear_spectrogram.AcceptWaveform(wav_chunk); + linear_spectrogram.ReadFeats(&features); + + feats.push_back(features); + sample_offset += cur_chunk_size; + feature_rows += features.NumRows(); + } + + int cur_idx = 0; + kaldi::Matrix features(feature_rows, feats[0].NumCols()); + for (auto feat : feats) { + for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) { + for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) { + features(cur_idx, col_idx) = (feat(row_idx, col_idx) - mean_[col_idx])*variance_[col_idx]; + } + ++cur_idx; + } + } feat_writer.Write(utt, features); if (num_done % 50 == 0 && num_done != 0) diff --git a/speechx/speechx/codelab/nnet_test/model_test.cc b/speechx/speechx/codelab/nnet_test/model_test.cc new file mode 100644 index 00000000..ce1e7fff --- /dev/null +++ b/speechx/speechx/codelab/nnet_test/model_test.cc @@ -0,0 +1,134 @@ +#include "paddle_inference_api.h" +#include +#include +#include +#include +#include +#include +#include +#include + +void produce_data(std::vector>* data); +void model_forward_test(); + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + model_forward_test(); + return 0; +} + +void model_forward_test() { + std::cout << "1. read the data" << std::endl; + std::vector> feats; + produce_data(&feats); + + std::cout << "2. load the model" << std::endl;; + std::string model_graph = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel"; + std::string model_params = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams"; + paddle_infer::Config config; + config.SetModel(model_graph, model_params); + config.SwitchIrOptim(false); + config.DisableFCPadding(); + auto predictor = paddle_infer::CreatePredictor(config); + + std::cout << "3. feat shape, row=" << feats.size() << ",col=" << feats[0].size() << std::endl; + std::vector paddle_input_feature_matrix; + for(const auto& item : feats) { + paddle_input_feature_matrix.insert(paddle_input_feature_matrix.end(), item.begin(), item.end()); + } + + std::cout << "4. fead the data to model" << std::endl; + int row = feats.size(); + int col = feats[0].size(); + std::vector input_names = predictor->GetInputNames(); + std::vector output_names = predictor->GetOutputNames(); + + std::unique_ptr input_tensor = + predictor->GetInputHandle(input_names[0]); + std::vector INPUT_SHAPE = {1, row, col}; + input_tensor->Reshape(INPUT_SHAPE); + input_tensor->CopyFromCpu(paddle_input_feature_matrix.data()); + + std::unique_ptr input_len = predictor->GetInputHandle(input_names[1]); + std::vector input_len_size = {1}; + input_len->Reshape(input_len_size); + std::vector audio_len; + audio_len.push_back(row); + input_len->CopyFromCpu(audio_len.data()); + + std::unique_ptr chunk_state_h_box = predictor->GetInputHandle(input_names[2]); + std::vector chunk_state_h_box_shape = {3, 1, 1024}; + chunk_state_h_box->Reshape(chunk_state_h_box_shape); + int chunk_state_h_box_size = std::accumulate(chunk_state_h_box_shape.begin(), chunk_state_h_box_shape.end(), + 1, std::multiplies()); + std::vector chunk_state_h_box_data(chunk_state_h_box_size, 0.0f); + chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data()); + + std::unique_ptr chunk_state_c_box = predictor->GetInputHandle(input_names[3]); + std::vector chunk_state_c_box_shape = {3, 1, 1024}; + chunk_state_c_box->Reshape(chunk_state_c_box_shape); + int chunk_state_c_box_size = std::accumulate(chunk_state_c_box_shape.begin(), chunk_state_c_box_shape.end(), + 1, std::multiplies()); + std::vector chunk_state_c_box_data(chunk_state_c_box_size, 0.0f); + chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data()); + + bool success = predictor->Run(); + + std::unique_ptr h_out = predictor->GetOutputHandle(output_names[2]); + std::vector h_out_shape = h_out->shape(); + int h_out_size = std::accumulate(h_out_shape.begin(), h_out_shape.end(), + 1, std::multiplies()); + std::vector h_out_data(h_out_size); + h_out->CopyToCpu(h_out_data.data()); + + std::unique_ptr c_out = predictor->GetOutputHandle(output_names[3]); + std::vector c_out_shape = c_out->shape(); + int c_out_size = std::accumulate(c_out_shape.begin(), c_out_shape.end(), + 1, std::multiplies()); + std::vector c_out_data(c_out_size); + c_out->CopyToCpu(c_out_data.data()); + + std::unique_ptr output_tensor = + predictor->GetOutputHandle(output_names[0]); + std::vector output_shape = output_tensor->shape(); + std::vector output_probs; + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), + 1, std::multiplies()); + output_probs.resize(output_size); + output_tensor->CopyToCpu(output_probs.data()); + row = output_shape[1]; + col = output_shape[2]; + + std::vector> probs; + probs.reserve(row); + for (int i = 0; i < row; i++) { + probs.push_back(std::vector()); + probs.back().reserve(col); + + for (int j = 0; j < col; j++) { + probs.back().push_back(output_probs[i * col + j]); + } + } + + std::vector> log_feat = probs; + std::cout << "probs, row: " << log_feat.size() << " col: " << log_feat[0].size() << std::endl; + for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) { + for (size_t col_idx = 0; col_idx < log_feat[row_idx].size(); ++col_idx) { + std::cout << log_feat[row_idx][col_idx] << " "; + } + std::cout << std::endl; + } +} + +void produce_data(std::vector>* data) { + int chunk_size = 35; + int col_size = 161; + data->reserve(chunk_size); + data->back().reserve(col_size); + for (int row = 0; row < chunk_size; ++row) { + data->push_back(std::vector()); + for (int col_idx = 0; col_idx < col_size; ++col_idx) { + data->back().push_back(0.201); + } + } +} diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index 8885dca9..7cd281b6 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -1,10 +1,10 @@ project(decoder) include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders}) -add_library(decoder +add_library(decoder STATIC ctc_beam_search_decoder.cc ctc_decoders/decoder_utils.cpp ctc_decoders/path_trie.cpp ctc_decoders/scorer.cpp ) -target_link_libraries(decoder kenlm) \ No newline at end of file +target_link_libraries(decoder PUBLIC kenlm utils fst) \ No newline at end of file diff --git a/speechx/speechx/frontend/CMakeLists.txt b/speechx/speechx/frontend/CMakeLists.txt index 48a5267b..da81a481 100644 --- a/speechx/speechx/frontend/CMakeLists.txt +++ b/speechx/speechx/frontend/CMakeLists.txt @@ -1,8 +1,8 @@ project(frontend) -add_library(frontend +add_library(frontend STATIC normalizer.cc linear_spectrogram.cc ) -target_link_libraries(frontend kaldi-matrix) \ No newline at end of file +target_link_libraries(frontend PUBLIC kaldi-matrix) \ No newline at end of file diff --git a/speechx/speechx/frontend/linear_spectrogram.cc b/speechx/speechx/frontend/linear_spectrogram.cc index a23b4494..8c20985d 100644 --- a/speechx/speechx/frontend/linear_spectrogram.cc +++ b/speechx/speechx/frontend/linear_spectrogram.cc @@ -47,6 +47,7 @@ void CopyStdVector2Vector_(const vector& input, LinearSpectrogram::LinearSpectrogram( const LinearSpectrogramOptions& opts, std::unique_ptr base_extractor) { + opts_ = opts; base_extractor_ = std::move(base_extractor); int32 window_size = opts.frame_opts.WindowSize(); int32 window_shift = opts.frame_opts.WindowShift(); @@ -105,7 +106,7 @@ void LinearSpectrogram::ReadFeats(Matrix* feats) { Compute(feats_vec, result); feats->Resize(result.size(), result[0].size()); for (int row_idx = 0; row_idx < result.size(); ++row_idx) { - for (int col_idx = 0; col_idx < result.size(); ++col_idx) { + for (int col_idx = 0; col_idx < result[0].size(); ++col_idx) { (*feats)(row_idx, col_idx) = result[row_idx][col_idx]; } } @@ -133,7 +134,7 @@ bool LinearSpectrogram::Compute(const vector& wave, const int& sample_rate = opts_.frame_opts.samp_freq; const int& frame_shift = opts_.frame_opts.WindowShift(); const int& fft_points = fft_points_; - const float scale = hanning_window_energy_ * frame_shift; + const float scale = hanning_window_energy_ * sample_rate; if (num_samples < frame_length) { return true; @@ -153,10 +154,7 @@ bool LinearSpectrogram::Compute(const vector& wave, fft_img.clear(); fft_real.clear(); v.assign(data.begin(), data.end()); - if (NumpyFft(&v, &fft_real, &fft_img)) { - LOG(ERROR)<< i << " fft compute occurs error, please checkout the input data"; - return false; - } + NumpyFft(&v, &fft_real, &fft_img); feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz for (int j = 0; j < (fft_points / 2 + 1); ++j) { diff --git a/speechx/speechx/frontend/normalizer.h b/speechx/speechx/frontend/normalizer.h index 3bf36cfc..78670bb4 100644 --- a/speechx/speechx/frontend/normalizer.h +++ b/speechx/speechx/frontend/normalizer.h @@ -29,7 +29,7 @@ class DecibelNormalizer : public FeatureExtractorInterface { explicit DecibelNormalizer(const DecibelNormalizerOptions& opts); virtual void AcceptWaveform(const kaldi::VectorBase& input); virtual void Read(kaldi::VectorBase* feat); - virtual size_t Dim() const { return 0; } + virtual size_t Dim() const { return dim_; } bool Compute(const kaldi::VectorBase& input, kaldi::VectorBase* feat) const; private: diff --git a/speechx/speechx/kaldi/matrix/BUILD b/speechx/speechx/kaldi/matrix/BUILD deleted file mode 100644 index cefac6fc..00000000 --- a/speechx/speechx/kaldi/matrix/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) 2020 PeachLab. All Rights Reserved. -# Author : goat.zhou@qq.com (Yang Zhou) - -package(default_visibility = ["//visibility:public"]) - -cc_library( - name = 'kaldi-matrix', - srcs = [ - 'compressed-matrix.cc', - 'kaldi-matrix.cc', - 'kaldi-vector.cc', - 'matrix-functions.cc', - 'optimization.cc', - 'packed-matrix.cc', - 'qr.cc', - 'sparse-matrix.cc', - 'sp-matrix.cc', - 'srfft.cc', - 'tp-matrix.cc', - ], - hdrs = glob(["*.h"]), - deps = [ - '//base:kaldi-base', - '//common/third_party/openblas:openblas', - ], - linkopts=['-lgfortran'], -) - -cc_binary( - name = 'matrix-lib-test', - srcs = [ - 'matrix-lib-test.cc', - ], - deps = [ - ':kaldi-matrix', - '//util:kaldi-util', - ], -) - diff --git a/speechx/speechx/kaldi/matrix/CMakeLists.txt b/speechx/speechx/kaldi/matrix/CMakeLists.txt index a4dbde2e..64f44864 100644 --- a/speechx/speechx/kaldi/matrix/CMakeLists.txt +++ b/speechx/speechx/kaldi/matrix/CMakeLists.txt @@ -13,4 +13,4 @@ srfft.cc tp-matrix.cc ) -target_link_libraries(kaldi-matrix gfortran kaldi-base libopenblas.a) +target_link_libraries(kaldi-matrix gfortran kaldi-base ${MATH_LIB}) diff --git a/speechx/speechx/kaldi/matrix/kaldi-blas.h b/speechx/speechx/kaldi/matrix/kaldi-blas.h index b08d8c51..cf4ad9df 100644 --- a/speechx/speechx/kaldi/matrix/kaldi-blas.h +++ b/speechx/speechx/kaldi/matrix/kaldi-blas.h @@ -42,7 +42,7 @@ -#define HAVE_OPENBLAS +#define HAVE_MKL #if (defined(HAVE_CLAPACK) && (defined(HAVE_ATLAS) || defined(HAVE_MKL))) \ || (defined(HAVE_ATLAS) && defined(HAVE_MKL)) diff --git a/speechx/speechx/nnet/CMakeLists.txt b/speechx/speechx/nnet/CMakeLists.txt index 4d336b86..cee881de 100644 --- a/speechx/speechx/nnet/CMakeLists.txt +++ b/speechx/speechx/nnet/CMakeLists.txt @@ -1,2 +1,7 @@ -aux_source_directory(. DIR_LIB_SRCS) -add_library(nnet STATIC ${DIR_LIB_SRCS}) +project(nnet) + +add_library(nnet STATIC + decodable.cc + paddle_nnet.cc +) +target_link_libraries(nnet absl::strings) \ No newline at end of file diff --git a/speechx/speechx/nnet/decodable-itf.h b/speechx/speechx/nnet/decodable-itf.h index 93f7db76..5f641b6c 100644 --- a/speechx/speechx/nnet/decodable-itf.h +++ b/speechx/speechx/nnet/decodable-itf.h @@ -114,7 +114,7 @@ class DecodableInterface { /// this is for compatibility with OpenFst). virtual int32 NumIndices() const = 0; - virtual std::vector FrameLogLikelihood(int32 frame); + virtual std::vector FrameLogLikelihood(int32 frame) = 0; virtual ~DecodableInterface() {} }; diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index 984f3ad3..45486bc0 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -38,6 +38,15 @@ void Decodable::FeedFeatures(const Matrix& features) { return ; } +std::vector Decodable::FrameLogLikelihood(int32 frame) { + std::vector result; + result.reserve(nnet_cache_.NumCols()); + for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) { + result[idx] = nnet_cache_(frame, idx); + } + return result; +} + void Decodable::Reset() { // frontend_.Reset(); nnet_->Reset(); diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h index 6f06d69a..eb9cddb9 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/speechx/speechx/nnet/decodable.h @@ -15,9 +15,9 @@ class Decodable : public kaldi::DecodableInterface { virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual bool IsLastFrame(int32 frame) const; virtual int32 NumIndices() const; + virtual std::vector FrameLogLikelihood(int32 frame); void Acceptlikelihood(const kaldi::Matrix& likelihood); // remove later void FeedFeatures(const kaldi::Matrix& feature); // only for test, todo remove later - std::vector FrameLogLikelihood(int32 frame); void Reset(); void InputFinished() { finished_ = true; } private: diff --git a/speechx/speechx/nnet/nnet_interface.h b/speechx/speechx/nnet/nnet_interface.h index 577662f3..cdc0a6f2 100644 --- a/speechx/speechx/nnet/nnet_interface.h +++ b/speechx/speechx/nnet/nnet_interface.h @@ -9,11 +9,11 @@ namespace ppspeech { class NnetInterface { public: - virtual ~NnetInterface() {} virtual void FeedForward(const kaldi::Matrix& features, - kaldi::Matrix* inferences); - virtual void Reset(); + kaldi::Matrix* inferences)= 0; + virtual void Reset() = 0; + virtual ~NnetInterface() {} }; -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/nnet/paddle_nnet.cc b/speechx/speechx/nnet/paddle_nnet.cc index 61690872..c45bb75e 100644 --- a/speechx/speechx/nnet/paddle_nnet.cc +++ b/speechx/speechx/nnet/paddle_nnet.cc @@ -10,14 +10,16 @@ using kaldi::Matrix; void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { std::vector cache_names; - cache_names = absl::StrSplit(opts.cache_names, ", "); + cache_names = absl::StrSplit(opts.cache_names, ","); std::vector cache_shapes; - cache_shapes = absl::StrSplit(opts.cache_shape, ", "); + cache_shapes = absl::StrSplit(opts.cache_shape, ","); assert(cache_shapes.size() == cache_names.size()); - + + cache_encouts_.clear(); + cache_names_idx_.clear(); for (size_t i = 0; i < cache_shapes.size(); i++) { std::vector tmp_shape; - tmp_shape = absl::StrSplit(cache_shapes[i], "- "); + tmp_shape = absl::StrSplit(cache_shapes[i], "-"); std::vector cur_shape; std::transform(tmp_shape.begin(), tmp_shape.end(), std::back_inserter(cur_shape), @@ -30,14 +32,14 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { } } -PaddleNnet::PaddleNnet(const ModelOptions& opts) { +PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) { paddle_infer::Config config; config.SetModel(opts.model_path, opts.params_path); if (opts.use_gpu) { config.EnableUseGpu(500, 0); } config.SwitchIrOptim(opts.switch_ir_optim); - if (opts.enable_fc_padding) { + if (opts.enable_fc_padding == false) { config.DisableFCPadding(); } if (opts.enable_profile) { @@ -54,8 +56,8 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts) { LOG(INFO) << "start to check the predictor input and output names"; LOG(INFO) << "input names: " << opts.input_names; LOG(INFO) << "output names: " << opts.output_names; - vector input_names_vec = absl::StrSplit(opts.input_names, ", "); - vector output_names_vec = absl::StrSplit(opts.output_names, ", "); + vector input_names_vec = absl::StrSplit(opts.input_names, ","); + vector output_names_vec = absl::StrSplit(opts.output_names, ","); paddle_infer::Predictor* predictor = GetPredictor(); std::vector model_input_names = predictor->GetInputNames(); @@ -70,10 +72,13 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts) { assert(output_names_vec[i] == model_output_names[i]); } ReleasePredictor(predictor); - InitCacheEncouts(opts); } +void PaddleNnet::Reset() { + InitCacheEncouts(opts_); +} + paddle_infer::Predictor* PaddleNnet::GetPredictor() { LOG(INFO) << "attempt to get a new predictor instance " << std::endl; paddle_infer::Predictor* predictor = nullptr; @@ -126,57 +131,71 @@ shared_ptr> PaddleNnet::GetCacheEncoder(const string& name) { } void PaddleNnet::FeedForward(const Matrix& features, Matrix* inferences) { - - paddle_infer::Predictor* predictor = GetPredictor(); - // 1. 得到所有的 input tensor 的名称 - int row = features.NumRows(); - int col = features.NumCols(); - std::vector input_names = predictor->GetInputNames(); - std::vector output_names = predictor->GetOutputNames(); - LOG(INFO) << "feat info: row=" << row << ", col=" << col; - - std::unique_ptr input_tensor = predictor->GetInputHandle(input_names[0]); - std::vector INPUT_SHAPE = {1, row, col}; - input_tensor->Reshape(INPUT_SHAPE); - input_tensor->CopyFromCpu(features.Data()); - // 3. 输入每个音频帧数 - std::unique_ptr input_len = predictor->GetInputHandle(input_names[1]); - std::vector input_len_size = {1}; - input_len->Reshape(input_len_size); - std::vector audio_len; - audio_len.push_back(row); - input_len->CopyFromCpu(audio_len.data()); - // 输入流式的缓存数据 - std::unique_ptr h_box = predictor->GetInputHandle(input_names[2]); - shared_ptr> h_cache = GetCacheEncoder(input_names[2]); - h_box->Reshape(h_cache->get_shape()); - h_box->CopyFromCpu(h_cache->get_data().data()); - std::unique_ptr c_box = predictor->GetInputHandle(input_names[3]); - shared_ptr> c_cache = GetCacheEncoder(input_names[3]); - c_box->Reshape(c_cache->get_shape()); - c_box->CopyFromCpu(c_cache->get_data().data()); - bool success = predictor->Run(); - - if (success == false) { - LOG(INFO) << "predictor run occurs error"; + paddle_infer::Predictor* predictor = GetPredictor(); + int row = features.NumRows(); + int col = features.NumCols(); + std::vector feed_feature; + // todo refactor feed feature: SmileGoat + feed_feature.reserve(row*col); + for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) { + for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) { + feed_feature.push_back(features(row_idx, col_idx)); } + } + std::vector input_names = predictor->GetInputNames(); + std::vector output_names = predictor->GetOutputNames(); + LOG(INFO) << "feat info: row=" << row << ", col= " << col; + + std::unique_ptr input_tensor = predictor->GetInputHandle(input_names[0]); + std::vector INPUT_SHAPE = {1, row, col}; + input_tensor->Reshape(INPUT_SHAPE); + input_tensor->CopyFromCpu(feed_feature.data()); + std::unique_ptr input_len = predictor->GetInputHandle(input_names[1]); + std::vector input_len_size = {1}; + input_len->Reshape(input_len_size); + std::vector audio_len; + audio_len.push_back(row); + input_len->CopyFromCpu(audio_len.data()); + + std::unique_ptr h_box = predictor->GetInputHandle(input_names[2]); + shared_ptr> h_cache = GetCacheEncoder(input_names[2]); + h_box->Reshape(h_cache->get_shape()); + h_box->CopyFromCpu(h_cache->get_data().data()); + std::unique_ptr c_box = predictor->GetInputHandle(input_names[3]); + shared_ptr> c_cache = GetCacheEncoder(input_names[3]); + c_box->Reshape(c_cache->get_shape()); + c_box->CopyFromCpu(c_cache->get_data().data()); + bool success = predictor->Run(); + + if (success == false) { + LOG(INFO) << "predictor run occurs error"; + } - LOG(INFO) << "get the model success"; - std::unique_ptr h_out = predictor->GetOutputHandle(output_names[2]); - assert(h_cache->get_shape() == h_out->shape()); - h_out->CopyToCpu(h_cache->get_data().data()); - std::unique_ptr c_out = predictor->GetOutputHandle(output_names[3]); - assert(c_cache->get_shape() == c_out->shape()); - c_out->CopyToCpu(c_cache->get_data().data()); - // 5. 得到最后的输出结果 - std::unique_ptr output_tensor = - predictor->GetOutputHandle(output_names[0]); - std::vector output_shape = output_tensor->shape(); - row = output_shape[1]; - col = output_shape[2]; - inferences->Resize(row, col); - output_tensor->CopyToCpu(inferences->Data()); - ReleasePredictor(predictor); + LOG(INFO) << "get the model success"; + std::unique_ptr h_out = predictor->GetOutputHandle(output_names[2]); + assert(h_cache->get_shape() == h_out->shape()); + h_out->CopyToCpu(h_cache->get_data().data()); + std::unique_ptr c_out = predictor->GetOutputHandle(output_names[3]); + assert(c_cache->get_shape() == c_out->shape()); + c_out->CopyToCpu(c_cache->get_data().data()); + + // get result + std::unique_ptr output_tensor = + predictor->GetOutputHandle(output_names[0]); + std::vector output_shape = output_tensor->shape(); + row = output_shape[1]; + col = output_shape[2]; + vector inferences_result; + inferences->Resize(row, col); + inferences_result.resize(row*col); + output_tensor->CopyToCpu(inferences_result.data()); + ReleasePredictor(predictor); + + for (int row_idx = 0; row_idx < row; ++row_idx) { + for (int col_idx = 0; col_idx < col; ++col_idx) { + (*inferences)(row_idx, col_idx) = inferences_result[col*row_idx + col_idx]; + } + } } } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/nnet/paddle_nnet.h b/speechx/speechx/nnet/paddle_nnet.h index b659a12b..8eaced07 100644 --- a/speechx/speechx/nnet/paddle_nnet.h +++ b/speechx/speechx/nnet/paddle_nnet.h @@ -25,14 +25,14 @@ struct ModelOptions { bool enable_fc_padding; bool enable_profile; ModelOptions() : - model_path("model/final.zip"), - params_path("model/avg_1.jit.pdmodel"), + model_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel"), + params_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams"), thread_num(2), use_gpu(false), - input_names("audio"), - output_names("probs"), - cache_names("enouts"), - cache_shape("1-1-1"), + input_names("audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box"), + output_names("save_infer_model/scale_0.tmp_1,save_infer_model/scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/scale_3.tmp_1"), + cache_names("chunk_state_h_box,chunk_state_c_box"), + cache_shape("3-1-1024,3-1-1024"), switch_ir_optim(false), enable_fc_padding(false), enable_profile(false) { @@ -87,6 +87,7 @@ class PaddleNnet : public NnetInterface { PaddleNnet(const ModelOptions& opts); virtual void FeedForward(const kaldi::Matrix& features, kaldi::Matrix* inferences); + virtual void Reset(); std::shared_ptr> GetCacheEncoder(const std::string& name); void InitCacheEncouts(const ModelOptions& opts); @@ -100,6 +101,7 @@ class PaddleNnet : public NnetInterface { std::map predictor_to_thread_id; std::map cache_names_idx_; std::vector>> cache_encouts_; + ModelOptions opts_; public: DISALLOW_COPY_AND_ASSIGN(PaddleNnet); diff --git a/speechx/speechx/utils/file_utils.cc b/speechx/speechx/utils/file_utils.cc index 8b2758ba..2d8be727 100644 --- a/speechx/speechx/utils/file_utils.cc +++ b/speechx/speechx/utils/file_utils.cc @@ -1,5 +1,7 @@ #include "utils/file_utils.h" +namespace ppspeech { + bool ReadFileToVector(const std::string& filename, std::vector* vocabulary) { std::ifstream file_in(filename); @@ -15,3 +17,5 @@ bool ReadFileToVector(const std::string& filename, return true; } + +} \ No newline at end of file