Merge pull request #1521 from SmileGoat/speechx_goat

[Speechx] align linear_feature & nnet & refactor cmakelist
pull/1531/head
Hui Zhang 3 years ago committed by GitHub
commit ca5ed023db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -26,7 +26,6 @@ option(TEST_DEBUG "option for debug" OFF)
# Include third party
###############################################################################
# #example for include third party
# FetchContent_Declare()
# # FetchContent_MakeAvailable was not added until CMake 3.14
# FetchContent_MakeAvailable()
# include_directories()
@ -50,20 +49,25 @@ include_directories(${absl_SOURCE_DIR})
#)
#FetchContent_MakeAvailable(libsndfile)
# todo boost build
#include(FetchContent)
#FetchContent_Declare(
# Boost
# URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.zip
# URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
#)
#FetchContent_MakeAvailable(Boost)
#include_directories(${Boost_SOURCE_DIR})
#boost
set(boost_SOURCE_DIR ${fc_patch}/boost-src)
set(boost_PREFIX_DIR ${fc_patch}/boost-subbuild/boost-prefix)
include(ExternalProject)
ExternalProject_Add(
boost
URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz
URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
SOURCE_DIR ${boost_SOURCE_DIR}
PREFIX ${boost_PREFIX_DIR}
BUILD_IN_SOURCE 1
CONFIGURE_COMMAND ./bootstrap.sh
BUILD_COMMAND ./b2
INSTALL_COMMAND ""
)
link_directories(${boost_SOURCE_DIR}/stage/lib)
include_directories(${boost_SOURCE_DIR})
set(BOOST_ROOT ${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0)
include_directories(${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0)
link_directories(${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0/stage/lib)
set(BOOST_ROOT ${boost_SOURCE_DIR})
include(FetchContent)
FetchContent_Declare(
kenlm
@ -71,9 +75,10 @@ FetchContent_Declare(
GIT_TAG "df2d717e95183f79a90b2fa6e4307083a351ca6a"
)
FetchContent_MakeAvailable(kenlm)
add_dependencies(kenlm Boost)
add_dependencies(kenlm boost)
include_directories(${kenlm_SOURCE_DIR})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb")
# gflags
FetchContent_Declare(
gflags
@ -106,19 +111,34 @@ set(openfst_PREFIX_DIR ${fc_patch}/openfst-subbuild/openfst-populate-prefix)
ExternalProject_Add(openfst
URL https://github.com/mjansche/openfst/archive/refs/tags/1.7.2.zip
URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6
# PREFIX ${openfst_PREFIX_DIR}
SOURCE_DIR ${openfst_SOURCE_DIR}
BINARY_DIR ${openfst_BINARY_DIR}
CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR}
"CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}"
"LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}"
"LIBS=-lgflags_nothreads -lglog -lpthread"
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
BUILD_COMMAND make -j 4
)
add_dependencies(openfst gflags glog)
link_directories(${openfst_PREFIX_DIR}/lib)
include_directories(${openfst_PREFIX_DIR}/include)
set(PADDLE_LIB ${fc_patch}/paddle-lib/paddle_inference)
# paddle lib
set(paddle_SOURCE_DIR ${fc_patch}/paddle-lib)
set(paddle_PREFIX_DIR ${fc_patch}/paddle-lib-prefix)
ExternalProject_Add(paddle
URL https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/CPU/gcc8.2_avx_mkl/paddle_inference.tgz
URL_HASH SHA256=7c6399e778c6554a929b5a39ba2175e702e115145e8fa690d2af974101d98873
PREFIX ${paddle_PREFIX_DIR}
SOURCE_DIR ${paddle_SOURCE_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
)
set(PADDLE_LIB ${fc_patch}/paddle-lib)
include_directories("${PADDLE_LIB}/paddle/include")
set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
@ -133,6 +153,23 @@ link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
link_directories("${PADDLE_LIB}/paddle/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mklml/lib")
##paddle with mkl
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml")
include_directories("${MATH_LIB_PATH}/include")
set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
include_directories("${MKLDNN_PATH}/include")
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
set(EXTERNAL_LIB "-lrt -ldl -lpthread")
set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB}
glog gflags protobuf xxhash cryptopp
${EXTERNAL_LIB})
add_subdirectory(speechx)

@ -0,0 +1 @@
exclude_files=.*

@ -0,0 +1,228 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Google-style flag handling declarations and inline definitions.
#ifndef FST_LIB_FLAGS_H_
#define FST_LIB_FLAGS_H_
#include <cstdlib>
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <fst/types.h>
#include <fst/lock.h>
#include "gflags/gflags.h"
#include "glog/logging.h"
using std::string;
// FLAGS USAGE:
//
// Definition example:
//
// DEFINE_int32(length, 0, "length");
//
// This defines variable FLAGS_length, initialized to 0.
//
// Declaration example:
//
// DECLARE_int32(length);
//
// SET_FLAGS() can be used to set flags from the command line
// using, for example, '--length=2'.
//
// ShowUsage() can be used to print out command and flag usage.
// #define DECLARE_bool(name) extern bool FLAGS_ ## name
// #define DECLARE_string(name) extern string FLAGS_ ## name
// #define DECLARE_int32(name) extern int32 FLAGS_ ## name
// #define DECLARE_int64(name) extern int64 FLAGS_ ## name
// #define DECLARE_double(name) extern double FLAGS_ ## name
template <typename T>
struct FlagDescription {
FlagDescription(T *addr, const char *doc, const char *type,
const char *file, const T val)
: address(addr),
doc_string(doc),
type_name(type),
file_name(file),
default_value(val) {}
T *address;
const char *doc_string;
const char *type_name;
const char *file_name;
const T default_value;
};
template <typename T>
class FlagRegister {
public:
static FlagRegister<T> *GetRegister() {
static auto reg = new FlagRegister<T>;
return reg;
}
const FlagDescription<T> &GetFlagDescription(const string &name) const {
fst::MutexLock l(&flag_lock_);
auto it = flag_table_.find(name);
return it != flag_table_.end() ? it->second : 0;
}
void SetDescription(const string &name,
const FlagDescription<T> &desc) {
fst::MutexLock l(&flag_lock_);
flag_table_.insert(make_pair(name, desc));
}
bool SetFlag(const string &val, bool *address) const {
if (val == "true" || val == "1" || val.empty()) {
*address = true;
return true;
} else if (val == "false" || val == "0") {
*address = false;
return true;
}
else {
return false;
}
}
bool SetFlag(const string &val, string *address) const {
*address = val;
return true;
}
bool SetFlag(const string &val, int32 *address) const {
char *p = 0;
*address = strtol(val.c_str(), &p, 0);
return !val.empty() && *p == '\0';
}
bool SetFlag(const string &val, int64 *address) const {
char *p = 0;
*address = strtoll(val.c_str(), &p, 0);
return !val.empty() && *p == '\0';
}
bool SetFlag(const string &val, double *address) const {
char *p = 0;
*address = strtod(val.c_str(), &p);
return !val.empty() && *p == '\0';
}
bool SetFlag(const string &arg, const string &val) const {
for (typename std::map< string, FlagDescription<T> >::const_iterator it =
flag_table_.begin();
it != flag_table_.end();
++it) {
const string &name = it->first;
const FlagDescription<T> &desc = it->second;
if (arg == name)
return SetFlag(val, desc.address);
}
return false;
}
void GetUsage(std::set<std::pair<string, string>> *usage_set) const {
for (auto it = flag_table_.begin(); it != flag_table_.end(); ++it) {
const string &name = it->first;
const FlagDescription<T> &desc = it->second;
string usage = " --" + name;
usage += ": type = ";
usage += desc.type_name;
usage += ", default = ";
usage += GetDefault(desc.default_value) + "\n ";
usage += desc.doc_string;
usage_set->insert(make_pair(desc.file_name, usage));
}
}
private:
string GetDefault(bool default_value) const {
return default_value ? "true" : "false";
}
string GetDefault(const string &default_value) const {
return "\"" + default_value + "\"";
}
template <class V>
string GetDefault(const V &default_value) const {
std::ostringstream strm;
strm << default_value;
return strm.str();
}
mutable fst::Mutex flag_lock_; // Multithreading lock.
std::map<string, FlagDescription<T>> flag_table_;
};
template <typename T>
class FlagRegisterer {
public:
FlagRegisterer(const string &name, const FlagDescription<T> &desc) {
auto registr = FlagRegister<T>::GetRegister();
registr->SetDescription(name, desc);
}
private:
FlagRegisterer(const FlagRegisterer &) = delete;
FlagRegisterer &operator=(const FlagRegisterer &) = delete;
};
#define DEFINE_VAR(type, name, value, doc) \
type FLAGS_ ## name = value; \
static FlagRegisterer<type> \
name ## _flags_registerer(#name, FlagDescription<type>(&FLAGS_ ## name, \
doc, \
#type, \
__FILE__, \
value))
// #define DEFINE_bool(name, value, doc) DEFINE_VAR(bool, name, value, doc)
// #define DEFINE_string(name, value, doc) \
// DEFINE_VAR(string, name, value, doc)
// #define DEFINE_int32(name, value, doc) DEFINE_VAR(int32, name, value, doc)
// #define DEFINE_int64(name, value, doc) DEFINE_VAR(int64, name, value, doc)
// #define DEFINE_double(name, value, doc) DEFINE_VAR(double, name, value, doc)
// Temporary directory.
DECLARE_string(tmpdir);
void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags,
const char *src = "");
#define SET_FLAGS(usage, argc, argv, rmflags) \
gflags::ParseCommandLineFlags(argc, argv, true)
// SetFlags(usage, argc, argv, rmflags, __FILE__)
// Deprecated; for backward compatibility.
inline void InitFst(const char *usage, int *argc, char ***argv, bool rmflags) {
return SetFlags(usage, argc, argv, rmflags);
}
void ShowUsage(bool long_usage = true);
#endif // FST_LIB_FLAGS_H_

@ -0,0 +1,82 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Google-style logging declarations and inline definitions.
#ifndef FST_LIB_LOG_H_
#define FST_LIB_LOG_H_
#include <cassert>
#include <iostream>
#include <string>
#include <fst/types.h>
#include <fst/flags.h>
using std::string;
DECLARE_int32(v);
class LogMessage {
public:
LogMessage(const string &type) : fatal_(type == "FATAL") {
std::cerr << type << ": ";
}
~LogMessage() {
std::cerr << std::endl;
if(fatal_)
exit(1);
}
std::ostream &stream() { return std::cerr; }
private:
bool fatal_;
};
// #define LOG(type) LogMessage(#type).stream()
// #define VLOG(level) if ((level) <= FLAGS_v) LOG(INFO)
// Checks
inline void FstCheck(bool x, const char* expr,
const char *file, int line) {
if (!x) {
LOG(FATAL) << "Check failed: \"" << expr
<< "\" file: " << file
<< " line: " << line;
}
}
// #define CHECK(x) FstCheck(static_cast<bool>(x), #x, __FILE__, __LINE__)
// #define CHECK_EQ(x, y) CHECK((x) == (y))
// #define CHECK_LT(x, y) CHECK((x) < (y))
// #define CHECK_GT(x, y) CHECK((x) > (y))
// #define CHECK_LE(x, y) CHECK((x) <= (y))
// #define CHECK_GE(x, y) CHECK((x) >= (y))
// #define CHECK_NE(x, y) CHECK((x) != (y))
// Debug checks
// #define DCHECK(x) assert(x)
// #define DCHECK_EQ(x, y) DCHECK((x) == (y))
// #define DCHECK_LT(x, y) DCHECK((x) < (y))
// #define DCHECK_GT(x, y) DCHECK((x) > (y))
// #define DCHECK_LE(x, y) DCHECK((x) <= (y))
// #define DCHECK_GE(x, y) DCHECK((x) >= (y))
// #define DCHECK_NE(x, y) DCHECK((x) != (y))
// Ports
#define ATTRIBUTE_DEPRECATED __attribute__((deprecated))
#endif // FST_LIB_LOG_H_

@ -0,0 +1,166 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Google-style flag handling definitions.
#include <cstring>
#if _MSC_VER
#include <io.h>
#include <fcntl.h>
#endif
#include <fst/compat.h>
#include <fst/flags.h>
static const char *private_tmpdir = getenv("TMPDIR");
// DEFINE_int32(v, 0, "verbosity level");
// DEFINE_bool(help, false, "show usage information");
// DEFINE_bool(helpshort, false, "show brief usage information");
#ifndef _MSC_VER
DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : "/tmp",
"temporary directory");
#else
DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : getenv("TEMP"),
"temporary directory");
#endif // !_MSC_VER
using namespace std;
static string flag_usage;
static string prog_src;
// Sets prog_src to src.
static void SetProgSrc(const char *src) {
prog_src = src;
#if _MSC_VER
// This common code is invoked by all FST binaries, and only by them. Switch
// stdin and stdout into "binary" mode, so that 0x0A won't be translated into
// a 0x0D 0x0A byte pair in a pipe or a shell redirect. Other streams are
// already using ios::binary where binary files are read or written.
// Kudos to @daanzu for the suggested fix.
// https://github.com/kkm000/openfst/issues/20
// https://github.com/kkm000/openfst/pull/23
// https://github.com/kkm000/openfst/pull/32
_setmode(_fileno(stdin), O_BINARY);
_setmode(_fileno(stdout), O_BINARY);
#endif
// Remove "-main" in src filename. Flags are defined in fstx.cc but SetFlags()
// is called in fstx-main.cc, which results in a filename mismatch in
// ShowUsageRestrict() below.
static constexpr char kMainSuffix[] = "-main.cc";
const int prefix_length = prog_src.size() - strlen(kMainSuffix);
if (prefix_length > 0 && prog_src.substr(prefix_length) == kMainSuffix) {
prog_src.erase(prefix_length, strlen("-main"));
}
}
void SetFlags(const char *usage, int *argc, char ***argv,
bool remove_flags, const char *src) {
flag_usage = usage;
SetProgSrc(src);
int index = 1;
for (; index < *argc; ++index) {
string argval = (*argv)[index];
if (argval[0] != '-' || argval == "-") break;
while (argval[0] == '-') argval = argval.substr(1); // Removes initial '-'.
string arg = argval;
string val = "";
// Splits argval (arg=val) into arg and val.
auto pos = argval.find("=");
if (pos != string::npos) {
arg = argval.substr(0, pos);
val = argval.substr(pos + 1);
}
auto bool_register = FlagRegister<bool>::GetRegister();
if (bool_register->SetFlag(arg, val))
continue;
auto string_register = FlagRegister<string>::GetRegister();
if (string_register->SetFlag(arg, val))
continue;
auto int32_register = FlagRegister<int32>::GetRegister();
if (int32_register->SetFlag(arg, val))
continue;
auto int64_register = FlagRegister<int64>::GetRegister();
if (int64_register->SetFlag(arg, val))
continue;
auto double_register = FlagRegister<double>::GetRegister();
if (double_register->SetFlag(arg, val))
continue;
LOG(FATAL) << "SetFlags: Bad option: " << (*argv)[index];
}
if (remove_flags) {
for (auto i = 0; i < *argc - index; ++i) {
(*argv)[i + 1] = (*argv)[i + index];
}
*argc -= index - 1;
}
// if (FLAGS_help) {
// ShowUsage(true);
// exit(1);
// }
// if (FLAGS_helpshort) {
// ShowUsage(false);
// exit(1);
// }
}
// If flag is defined in file 'src' and 'in_src' true or is not
// defined in file 'src' and 'in_src' is false, then print usage.
static void
ShowUsageRestrict(const std::set<pair<string, string>> &usage_set,
const string &src, bool in_src, bool show_file) {
string old_file;
bool file_out = false;
bool usage_out = false;
for (const auto &pair : usage_set) {
const auto &file = pair.first;
const auto &usage = pair.second;
bool match = file == src;
if ((match && !in_src) || (!match && in_src)) continue;
if (file != old_file) {
if (show_file) {
if (file_out) cout << "\n";
cout << "Flags from: " << file << "\n";
file_out = true;
}
old_file = file;
}
cout << usage << "\n";
usage_out = true;
}
if (usage_out) cout << "\n";
}
void ShowUsage(bool long_usage) {
std::set<pair<string, string>> usage_set;
cout << flag_usage << "\n";
auto bool_register = FlagRegister<bool>::GetRegister();
bool_register->GetUsage(&usage_set);
auto string_register = FlagRegister<string>::GetRegister();
string_register->GetUsage(&usage_set);
auto int32_register = FlagRegister<int32>::GetRegister();
int32_register->GetUsage(&usage_set);
auto int64_register = FlagRegister<int64>::GetRegister();
int64_register->GetUsage(&usage_set);
auto double_register = FlagRegister<double>::GetRegister();
double_register->GetUsage(&usage_set);
if (!prog_src.empty()) {
cout << "PROGRAM FLAGS:\n\n";
ShowUsageRestrict(usage_set, prog_src, true, false);
}
if (!long_usage) return;
if (!prog_src.empty()) cout << "LIBRARY FLAGS:\n\n";
ShowUsageRestrict(usage_set, prog_src, false, true);
}

@ -2,10 +2,6 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(speechx LANGUAGES CXX)
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/openblas)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/kaldi
@ -37,10 +33,13 @@ ${CMAKE_CURRENT_SOURCE_DIR}/decoder
add_subdirectory(decoder)
add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc)
target_link_libraries(mfcc-test kaldi-mfcc)
target_link_libraries(mfcc-test kaldi-mfcc ${MATH_LIB})
add_executable(linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc)
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
#add_executable(offline_decoder_main codelab/decoder_test/offline_decoder_main.cc)
#target_link_libraries(offline_decoder_main nnet decoder gflags glog)
add_executable(offline_decoder_main codelab/decoder_test/offline_decoder_main.cc)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(model_test codelab/nnet_test/model_test.cc)
target_link_libraries(model_test PUBLIC nnet gflags ${DEPS})

@ -14,4 +14,4 @@
#pragma once
#include "gflags/gflags.h"
#include "fst/flags.h"

@ -14,4 +14,4 @@
#pragma once
#include "glog/logging.h"
#include "fst/log.h"

@ -4,16 +4,20 @@
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "nnet/paddle_nnet.h"
#include "nnet/decodable.h"
DEFINE_string(feature_respecifier, "", "test nnet prob");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
void SplitFeature(kaldi::Matrix<BaseFloat> feature,
int32 chunk_size,
std::vector<kaldi::Matrix<BaseFloat>> feature_chunks) {
//void SplitFeature(kaldi::Matrix<BaseFloat> feature,
// int32 chunk_size,
// std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
}
//}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
@ -24,31 +28,26 @@ int main(int argc, char* argv[]) {
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0;
CTCBeamSearchOptions opts;
CTCBeamSearch decoder(opts);
ppspeech::CTCBeamSearchOptions opts;
ppspeech::CTCBeamSearch decoder(opts);
ModelOptions model_opts;
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(model_opts));
ppspeech::ModelOptions model_opts;
std::shared_ptr<ppspeech::PaddleNnet> nnet(new ppspeech::PaddleNnet(model_opts));
Decodable decodable();
decodable.SetNnet(nnet);
std::shared_ptr<ppspeech::Decodable> decodable(new ppspeech::Decodable(nnet));
int32 chunk_size = 0;
//int32 chunk_size = 35;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
vector<Matrix<BaseFloat>> feature_chunks;
SplitFeature(feature, chunk_size, &feature_chunks);
for (auto feature_chunk : feature_chunks) {
decodable.FeedFeatures(feature_chunk);
decoder.InitDecoder();
decoder.AdvanceDecode(decodable, chunk_size);
}
decodable.InputFinished();
decodable->FeedFeatures(feature);
decoder.AdvanceDecode(decodable, 8);
decodable->InputFinished();
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable.Reset();
decodable->Reset();
++num_done;
}

@ -11,6 +11,10 @@
DEFINE_string(wav_rspecifier, "", "test wav path");
DEFINE_string(feature_wspecifier, "", "test wav ark");
std::vector<float> mean_{-13730251.531853663, -12982852.199316509, -13673844.299583456, -13089406.559646806, -12673095.524938712, -12823859.223276224, -13590267.158903603, -14257618.467152044, -14374605.116185192, -14490009.21822485, -14849827.158924166, -15354435.470563512, -15834149.206532761, -16172971.985514281, -16348740.496746974, -16423536.699409386, -16556246.263649225, -16744088.772748645, -16916184.08510357, -17054034.840031497, -17165612.509455364, -17255955.470915023, -17322572.527648456, -17408943.862033736, -17521554.799865916, -17620623.254924215, -17699792.395918526, -17723364.411134344, -17741483.4433254, -17747426.888704527, -17733315.928209435, -17748780.160905756, -17808336.883775543, -17895918.671983004, -18009812.59173023, -18098188.66548325, -18195798.958462656, -18293617.62980999, -18397432.92077201, -18505834.787318766, -18585451.8100908, -18652438.235649142, -18700960.306275308, -18734944.58792185, -18737426.313365128, -18735347.165987637, -18738813.444170244, -18737086.848890636, -18731576.2474336, -18717405.44095871, -18703089.25545657, -18691014.546456724, -18692460.568905357, -18702119.628629155, -18727710.621126678, -18761582.72034647, -18806745.835547544, -18850674.8692112, -18884431.510951452, -18919999.992506847, -18939303.799078144, -18952946.273760635, -18980289.22996379, -19011610.17803294, -19040948.61805145, -19061021.429847397, -19112055.53768819, -19149667.414264943, -19201127.05091321, -19270250.82564605, -19334606.883057203, -19390513.336589377, -19444176.259208687, -19502755.000038862, -19544333.014549147, -19612668.183176614, -19681902.19006569, -19771969.951249883, -19873329.723376893, -19996752.59235844, -20110031.131400537, -20231658.612529557, -20319378.894054495, -20378534.45718066, -20413332.089584175, -20438147.844177883, -20443710.248040095, -20465457.02238927, -20488610.969337028, -20516295.16424432, -20541423.795738827, -20553192.874953747, -20573605.50701977, -20577871.61936797, -20571807.008916274, -20556242.38912231, -20542199.30819195, -20521239.063551214, -20519150.80004532, -20527204.80248933, -20536933.769257784, -20543470.522332076, -20549700.089992985, -20551525.24958494, -20554873.406493705, -20564277.65794227, -20572211.740052115, -20574305.69550465, -20575494.450104576, -20567092.577932164, -20549302.929608088, -20545445.11878376, -20546625.326603737, -20549190.03499401, -20554824.947828256, -20568341.378989458, -20577582.331383612, -20577980.519402675, -20566603.03458152, -20560131.592262644, -20552166.469060015, -20549063.06763577, -20544490.562339947, -20539817.82346569, -20528747.715731595, -20518026.24576161, -20510977.844974525, -20506874.36087992, -20506731.11977665, -20510482.133420516, -20507760.92101862, -20494644.834457114, -20480107.89304893, -20461312.091867123, -20442941.75080173, -20426123.02834838, -20424607.675283, -20426810.369107097, -20434024.50097819, -20437404.75544205, -20447688.63916367, -20460893.335563846, -20482922.735127095, -20503610.119434915, -20527062.76448319, -20557830.035128627, -20593274.72068722, -20632528.452965066, -20673637.471334763, -20733106.97143075, -20842921.0447562, -21054357.83621519, -21416569.534189366, -21978460.272811692, -22753170.052172784, -23671344.10563395, -24613499.293358143, -25406477.12230188, -25884377.82156489, -26049040.62791664, -26996879.104431007};
std::vector<float> variance_{213747175.10846674, 188395815.34302503, 212706429.10966414, 199109025.81461075, 189235901.23864496, 194901336.53253657, 217481594.29306737, 238689869.12327808, 243977501.24115244, 248479623.6431067, 259766741.47116545, 275516766.7790273, 291271202.3691234, 302693239.8220509, 308627358.3997694, 311143911.38788426, 315446105.07731867, 321705430.9341829, 327458907.4659941, 332245072.43223983, 336251717.5935284, 339694069.7639722, 342188204.4322228, 345587110.31313115, 349903086.2875232, 353660214.20643026, 356700344.5270885, 357665362.3529641, 358493352.05658793, 358857951.620328, 358375239.52774596, 358899733.6342954, 361051818.3511561, 364361716.05025816, 368750322.3771452, 372047800.6462831, 375655861.1349018, 379358519.1980013, 383327605.3935181, 387458599.282341, 390434692.3406868, 392994486.35057056, 394874418.04603153, 396230525.79763395, 396365592.0414835, 396334819.8242737, 396488353.19250053, 396438877.00744957, 396197980.4459586, 395590921.6672991, 395001107.62072515, 394528291.7318225, 394593110.424006, 395018405.59353715, 396110577.5415993, 397506704.0371068, 399400197.4657644, 401243568.2468382, 402687134.7805103, 404136047.2872507, 404883170.001883, 405522253.219517, 406660365.3626476, 407919346.0991902, 409045348.5384909, 409759588.7889818, 411974821.8564483, 413489718.78201455, 415535392.56684107, 418466481.97674364, 421104678.35678065, 423405392.5200779, 425550570.40798235, 427929423.9579701, 429585274.253478, 432368493.55181056, 435193587.13513297, 438886855.20476013, 443058876.8633751, 448181232.5093362, 452883835.6332396, 458056721.77926534, 461816531.22735566, 464363620.1970998, 465886343.5057493, 466928872.0651, 467180536.42647296, 468111848.70714295, 469138695.3071312, 470378429.6930793, 471517958.7132626, 472109050.4262365, 473087417.0177867, 473381322.04648733, 473220195.85483915, 472666071.8998819, 472124669.87879956, 471298571.411737, 471251033.2902761, 471672676.43128747, 472177147.2193172, 472572361.7711908, 472968783.7751127, 473156295.4164052, 473398034.82676554, 473897703.5203811, 474328271.33112127, 474452670.98002136, 474549003.99284613, 474252887.13567275, 473557462.909069, 473483385.85193115, 473609738.04855174, 473746944.82085115, 474016729.91696435, 474617321.94138587, 475045097.237122, 475125402.586558, 474664112.9824912, 474426247.5800283, 474104075.42796475, 473978219.7273978, 473773171.7798875, 473578534.69508696, 473102924.16904145, 472651240.5232615, 472374383.1810912, 472209479.6956096, 472202298.8921673, 472370090.76781124, 472220933.99374026, 471625467.37106377, 470994646.51883453, 470182428.9637543, 469348211.5939578, 468570387.4467277, 468540442.7225135, 468672018.90414184, 468994346.9533251, 469138757.58201426, 469553915.95710236, 470134523.38582784, 471082421.62055486, 471962316.51804745, 472939745.1708408, 474250621.5944825, 475773933.43199486, 477465399.71087736, 479218782.61382693, 481752299.7930922, 486608947.8984568, 496119403.2067917, 512730085.5704984, 539048915.2641417, 576285298.3548826, 621610270.2240586, 669308196.4436442, 710656993.5957186, 736344437.3725077, 745481288.0241544, 801121432.9925804};
int count_ = 912592;
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
@ -21,20 +25,62 @@ int main(int argc, char* argv[]) {
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning window -->linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0;
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
new ppspeech::DecibelNormalizer(db_norm_opt));
ppspeech::LinearSpectrogram linear_spectrogram(opt, std::move(base_feature_extractor));
float streaming_chunk = 0.36;
int sample_rate = 16000;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << mean_.size();
for (size_t i = 0; i < mean_.size(); i++) {
mean_[i] /= count_;
variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i];
if (variance_[i] < 1.0e-20) {
variance_[i] = 1.0e-20;
}
variance_[i] = 1.0 / std::sqrt(variance_[i]);
}
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData &wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(), this_channel);
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(waveform);
linear_spectrogram.ReadFeats(&features);
int tot_samples = waveform.Dim();
int sample_offset = 0;
std::vector<kaldi::Matrix<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size = std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(wav_chunk);
linear_spectrogram.ReadFeats(&features);
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.NumRows();
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows, feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features(cur_idx, col_idx) = (feat(row_idx, col_idx) - mean_[col_idx])*variance_[col_idx];
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
if (num_done % 50 == 0 && num_done != 0)

@ -0,0 +1,134 @@
#include "paddle_inference_api.h"
#include <gflags/gflags.h>
#include <iostream>
#include <thread>
#include <fstream>
#include <iterator>
#include <algorithm>
#include <numeric>
#include <functional>
void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test();
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
model_forward_test();
return 0;
}
void model_forward_test() {
std::cout << "1. read the data" << std::endl;
std::vector<std::vector<float>> feats;
produce_data(&feats);
std::cout << "2. load the model" << std::endl;;
std::string model_graph = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel";
std::string model_params = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams";
paddle_infer::Config config;
config.SetModel(model_graph, model_params);
config.SwitchIrOptim(false);
config.DisableFCPadding();
auto predictor = paddle_infer::CreatePredictor(config);
std::cout << "3. feat shape, row=" << feats.size() << ",col=" << feats[0].size() << std::endl;
std::vector<float> paddle_input_feature_matrix;
for(const auto& item : feats) {
paddle_input_feature_matrix.insert(paddle_input_feature_matrix.end(), item.begin(), item.end());
}
std::cout << "4. fead the data to model" << std::endl;
int row = feats.size();
int col = feats[0].size();
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(paddle_input_feature_matrix.data());
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box = predictor->GetInputHandle(input_names[2]);
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
int chunk_state_h_box_size = std::accumulate(chunk_state_h_box_shape.begin(), chunk_state_h_box_shape.end(),
1, std::multiplies<int>());
std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box = predictor->GetInputHandle(input_names[3]);
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
int chunk_state_c_box_size = std::accumulate(chunk_state_c_box_shape.begin(), chunk_state_c_box_shape.end(),
1, std::multiplies<int>());
std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
bool success = predictor->Run();
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
std::vector<int> h_out_shape = h_out->shape();
int h_out_size = std::accumulate(h_out_shape.begin(), h_out_shape.end(),
1, std::multiplies<int>());
std::vector<float> h_out_data(h_out_size);
h_out->CopyToCpu(h_out_data.data());
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
std::vector<int> c_out_shape = c_out->shape();
int c_out_size = std::accumulate(c_out_shape.begin(), c_out_shape.end(),
1, std::multiplies<int>());
std::vector<float> c_out_data(c_out_size);
c_out->CopyToCpu(c_out_data.data());
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
std::vector<float> output_probs;
int output_size = std::accumulate(output_shape.begin(), output_shape.end(),
1, std::multiplies<int>());
output_probs.resize(output_size);
output_tensor->CopyToCpu(output_probs.data());
row = output_shape[1];
col = output_shape[2];
std::vector<std::vector<float>> probs;
probs.reserve(row);
for (int i = 0; i < row; i++) {
probs.push_back(std::vector<float>());
probs.back().reserve(col);
for (int j = 0; j < col; j++) {
probs.back().push_back(output_probs[i * col + j]);
}
}
std::vector<std::vector<float>> log_feat = probs;
std::cout << "probs, row: " << log_feat.size() << " col: " << log_feat[0].size() << std::endl;
for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size(); ++col_idx) {
std::cout << log_feat[row_idx][col_idx] << " ";
}
std::cout << std::endl;
}
}
void produce_data(std::vector<std::vector<float>>* data) {
int chunk_size = 35;
int col_size = 161;
data->reserve(chunk_size);
data->back().reserve(col_size);
for (int row = 0; row < chunk_size; ++row) {
data->push_back(std::vector<float>());
for (int col_idx = 0; col_idx < col_size; ++col_idx) {
data->back().push_back(0.201);
}
}
}

@ -1,10 +1,10 @@
project(decoder)
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
add_library(decoder
add_library(decoder STATIC
ctc_beam_search_decoder.cc
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
)
target_link_libraries(decoder kenlm)
target_link_libraries(decoder PUBLIC kenlm utils fst)

@ -1,8 +1,8 @@
project(frontend)
add_library(frontend
add_library(frontend STATIC
normalizer.cc
linear_spectrogram.cc
)
target_link_libraries(frontend kaldi-matrix)
target_link_libraries(frontend PUBLIC kaldi-matrix)

@ -47,6 +47,7 @@ void CopyStdVector2Vector_(const vector<BaseFloat>& input,
LinearSpectrogram::LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
opts_ = opts;
base_extractor_ = std::move(base_extractor);
int32 window_size = opts.frame_opts.WindowSize();
int32 window_shift = opts.frame_opts.WindowShift();
@ -105,7 +106,7 @@ void LinearSpectrogram::ReadFeats(Matrix<BaseFloat>* feats) {
Compute(feats_vec, result);
feats->Resize(result.size(), result[0].size());
for (int row_idx = 0; row_idx < result.size(); ++row_idx) {
for (int col_idx = 0; col_idx < result.size(); ++col_idx) {
for (int col_idx = 0; col_idx < result[0].size(); ++col_idx) {
(*feats)(row_idx, col_idx) = result[row_idx][col_idx];
}
}
@ -133,7 +134,7 @@ bool LinearSpectrogram::Compute(const vector<float>& wave,
const int& sample_rate = opts_.frame_opts.samp_freq;
const int& frame_shift = opts_.frame_opts.WindowShift();
const int& fft_points = fft_points_;
const float scale = hanning_window_energy_ * frame_shift;
const float scale = hanning_window_energy_ * sample_rate;
if (num_samples < frame_length) {
return true;
@ -153,10 +154,7 @@ bool LinearSpectrogram::Compute(const vector<float>& wave,
fft_img.clear();
fft_real.clear();
v.assign(data.begin(), data.end());
if (NumpyFft(&v, &fft_real, &fft_img)) {
LOG(ERROR)<< i << " fft compute occurs error, please checkout the input data";
return false;
}
NumpyFft(&v, &fft_real, &fft_img);
feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz
for (int j = 0; j < (fft_points / 2 + 1); ++j) {

@ -29,7 +29,7 @@ class DecibelNormalizer : public FeatureExtractorInterface {
explicit DecibelNormalizer(const DecibelNormalizerOptions& opts);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return 0; }
virtual size_t Dim() const { return dim_; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBase<kaldi::BaseFloat>* feat) const;
private:

@ -1,39 +0,0 @@
# Copyright (c) 2020 PeachLab. All Rights Reserved.
# Author : goat.zhou@qq.com (Yang Zhou)
package(default_visibility = ["//visibility:public"])
cc_library(
name = 'kaldi-matrix',
srcs = [
'compressed-matrix.cc',
'kaldi-matrix.cc',
'kaldi-vector.cc',
'matrix-functions.cc',
'optimization.cc',
'packed-matrix.cc',
'qr.cc',
'sparse-matrix.cc',
'sp-matrix.cc',
'srfft.cc',
'tp-matrix.cc',
],
hdrs = glob(["*.h"]),
deps = [
'//base:kaldi-base',
'//common/third_party/openblas:openblas',
],
linkopts=['-lgfortran'],
)
cc_binary(
name = 'matrix-lib-test',
srcs = [
'matrix-lib-test.cc',
],
deps = [
':kaldi-matrix',
'//util:kaldi-util',
],
)

@ -13,4 +13,4 @@ srfft.cc
tp-matrix.cc
)
target_link_libraries(kaldi-matrix gfortran kaldi-base libopenblas.a)
target_link_libraries(kaldi-matrix gfortran kaldi-base ${MATH_LIB})

@ -42,7 +42,7 @@
#define HAVE_OPENBLAS
#define HAVE_MKL
#if (defined(HAVE_CLAPACK) && (defined(HAVE_ATLAS) || defined(HAVE_MKL))) \
|| (defined(HAVE_ATLAS) && defined(HAVE_MKL))

@ -1,2 +1,7 @@
aux_source_directory(. DIR_LIB_SRCS)
add_library(nnet STATIC ${DIR_LIB_SRCS})
project(nnet)
add_library(nnet STATIC
decodable.cc
paddle_nnet.cc
)
target_link_libraries(nnet absl::strings)

@ -114,7 +114,7 @@ class DecodableInterface {
/// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame) = 0;
virtual ~DecodableInterface() {}
};

@ -38,6 +38,15 @@ void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
return ;
}
std::vector<BaseFloat> Decodable::FrameLogLikelihood(int32 frame) {
std::vector<BaseFloat> result;
result.reserve(nnet_cache_.NumCols());
for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) {
result[idx] = nnet_cache_(frame, idx);
}
return result;
}
void Decodable::Reset() {
// frontend_.Reset();
nnet_->Reset();

@ -15,9 +15,9 @@ class Decodable : public kaldi::DecodableInterface {
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const;
virtual int32 NumIndices() const;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& feature); // only for test, todo remove later
std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
void Reset();
void InputFinished() { finished_ = true; }
private:

@ -9,11 +9,11 @@ namespace ppspeech {
class NnetInterface {
public:
virtual ~NnetInterface() {}
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences);
virtual void Reset();
kaldi::Matrix<kaldi::BaseFloat>* inferences)= 0;
virtual void Reset() = 0;
virtual ~NnetInterface() {}
};
} // namespace ppspeech
} // namespace ppspeech

@ -10,14 +10,16 @@ using kaldi::Matrix;
void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
std::vector<std::string> cache_names;
cache_names = absl::StrSplit(opts.cache_names, ", ");
cache_names = absl::StrSplit(opts.cache_names, ",");
std::vector<std::string> cache_shapes;
cache_shapes = absl::StrSplit(opts.cache_shape, ", ");
cache_shapes = absl::StrSplit(opts.cache_shape, ",");
assert(cache_shapes.size() == cache_names.size());
cache_encouts_.clear();
cache_names_idx_.clear();
for (size_t i = 0; i < cache_shapes.size(); i++) {
std::vector<std::string> tmp_shape;
tmp_shape = absl::StrSplit(cache_shapes[i], "- ");
tmp_shape = absl::StrSplit(cache_shapes[i], "-");
std::vector<int> cur_shape;
std::transform(tmp_shape.begin(), tmp_shape.end(),
std::back_inserter(cur_shape),
@ -30,14 +32,14 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
}
}
PaddleNnet::PaddleNnet(const ModelOptions& opts) {
PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
paddle_infer::Config config;
config.SetModel(opts.model_path, opts.params_path);
if (opts.use_gpu) {
config.EnableUseGpu(500, 0);
}
config.SwitchIrOptim(opts.switch_ir_optim);
if (opts.enable_fc_padding) {
if (opts.enable_fc_padding == false) {
config.DisableFCPadding();
}
if (opts.enable_profile) {
@ -54,8 +56,8 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts) {
LOG(INFO) << "start to check the predictor input and output names";
LOG(INFO) << "input names: " << opts.input_names;
LOG(INFO) << "output names: " << opts.output_names;
vector<string> input_names_vec = absl::StrSplit(opts.input_names, ", ");
vector<string> output_names_vec = absl::StrSplit(opts.output_names, ", ");
vector<string> input_names_vec = absl::StrSplit(opts.input_names, ",");
vector<string> output_names_vec = absl::StrSplit(opts.output_names, ",");
paddle_infer::Predictor* predictor = GetPredictor();
std::vector<std::string> model_input_names = predictor->GetInputNames();
@ -70,10 +72,13 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts) {
assert(output_names_vec[i] == model_output_names[i]);
}
ReleasePredictor(predictor);
InitCacheEncouts(opts);
}
void PaddleNnet::Reset() {
InitCacheEncouts(opts_);
}
paddle_infer::Predictor* PaddleNnet::GetPredictor() {
LOG(INFO) << "attempt to get a new predictor instance " << std::endl;
paddle_infer::Predictor* predictor = nullptr;
@ -126,57 +131,71 @@ shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
}
void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>* inferences) {
paddle_infer::Predictor* predictor = GetPredictor();
// 1. 得到所有的 input tensor 的名称
int row = features.NumRows();
int col = features.NumCols();
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
LOG(INFO) << "feat info: row=" << row << ", col=" << col;
std::unique_ptr<paddle_infer::Tensor> input_tensor = predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(features.Data());
// 3. 输入每个音频帧数
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
// 输入流式的缓存数据
std::unique_ptr<paddle_infer::Tensor> h_box = predictor->GetInputHandle(input_names[2]);
shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
h_box->Reshape(h_cache->get_shape());
h_box->CopyFromCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_box = predictor->GetInputHandle(input_names[3]);
shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
c_box->Reshape(c_cache->get_shape());
c_box->CopyFromCpu(c_cache->get_data().data());
bool success = predictor->Run();
if (success == false) {
LOG(INFO) << "predictor run occurs error";
paddle_infer::Predictor* predictor = GetPredictor();
int row = features.NumRows();
int col = features.NumCols();
std::vector<BaseFloat> feed_feature;
// todo refactor feed feature: SmileGoat
feed_feature.reserve(row*col);
for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) {
feed_feature.push_back(features(row_idx, col_idx));
}
}
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
LOG(INFO) << "feat info: row=" << row << ", col= " << col;
std::unique_ptr<paddle_infer::Tensor> input_tensor = predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(feed_feature.data());
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
std::unique_ptr<paddle_infer::Tensor> h_box = predictor->GetInputHandle(input_names[2]);
shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
h_box->Reshape(h_cache->get_shape());
h_box->CopyFromCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_box = predictor->GetInputHandle(input_names[3]);
shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
c_box->Reshape(c_cache->get_shape());
c_box->CopyFromCpu(c_cache->get_data().data());
bool success = predictor->Run();
if (success == false) {
LOG(INFO) << "predictor run occurs error";
}
LOG(INFO) << "get the model success";
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
assert(h_cache->get_shape() == h_out->shape());
h_out->CopyToCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
assert(c_cache->get_shape() == c_out->shape());
c_out->CopyToCpu(c_cache->get_data().data());
// 5. 得到最后的输出结果
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
row = output_shape[1];
col = output_shape[2];
inferences->Resize(row, col);
output_tensor->CopyToCpu(inferences->Data());
ReleasePredictor(predictor);
LOG(INFO) << "get the model success";
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
assert(h_cache->get_shape() == h_out->shape());
h_out->CopyToCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
assert(c_cache->get_shape() == c_out->shape());
c_out->CopyToCpu(c_cache->get_data().data());
// get result
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
row = output_shape[1];
col = output_shape[2];
vector<float> inferences_result;
inferences->Resize(row, col);
inferences_result.resize(row*col);
output_tensor->CopyToCpu(inferences_result.data());
ReleasePredictor(predictor);
for (int row_idx = 0; row_idx < row; ++row_idx) {
for (int col_idx = 0; col_idx < col; ++col_idx) {
(*inferences)(row_idx, col_idx) = inferences_result[col*row_idx + col_idx];
}
}
}
} // namespace ppspeech

@ -25,14 +25,14 @@ struct ModelOptions {
bool enable_fc_padding;
bool enable_profile;
ModelOptions() :
model_path("model/final.zip"),
params_path("model/avg_1.jit.pdmodel"),
model_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel"),
params_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams"),
thread_num(2),
use_gpu(false),
input_names("audio"),
output_names("probs"),
cache_names("enouts"),
cache_shape("1-1-1"),
input_names("audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box"),
output_names("save_infer_model/scale_0.tmp_1,save_infer_model/scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/scale_3.tmp_1"),
cache_names("chunk_state_h_box,chunk_state_c_box"),
cache_shape("3-1-1024,3-1-1024"),
switch_ir_optim(false),
enable_fc_padding(false),
enable_profile(false) {
@ -87,6 +87,7 @@ class PaddleNnet : public NnetInterface {
PaddleNnet(const ModelOptions& opts);
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences);
virtual void Reset();
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name);
void InitCacheEncouts(const ModelOptions& opts);
@ -100,6 +101,7 @@ class PaddleNnet : public NnetInterface {
std::map<paddle_infer::Predictor*, int> predictor_to_thread_id;
std::map<std::string, int> cache_names_idx_;
std::vector<std::shared_ptr<Tensor<kaldi::BaseFloat>>> cache_encouts_;
ModelOptions opts_;
public:
DISALLOW_COPY_AND_ASSIGN(PaddleNnet);

@ -1,5 +1,7 @@
#include "utils/file_utils.h"
namespace ppspeech {
bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* vocabulary) {
std::ifstream file_in(filename);
@ -15,3 +17,5 @@ bool ReadFileToVector(const std::string& filename,
return true;
}
}
Loading…
Cancel
Save