parent
b584b9690f
commit
406b4fc7d4
@ -0,0 +1 @@
|
|||||||
|
exclude_files=.*
|
@ -0,0 +1,228 @@
|
|||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// See www.openfst.org for extensive documentation on this weighted
|
||||||
|
// finite-state transducer library.
|
||||||
|
//
|
||||||
|
// Google-style flag handling declarations and inline definitions.
|
||||||
|
|
||||||
|
#ifndef FST_LIB_FLAGS_H_
|
||||||
|
#define FST_LIB_FLAGS_H_
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include <fst/types.h>
|
||||||
|
#include <fst/lock.h>
|
||||||
|
|
||||||
|
#include "gflags/gflags.h"
|
||||||
|
#include "glog/logging.h"
|
||||||
|
|
||||||
|
using std::string;
|
||||||
|
|
||||||
|
// FLAGS USAGE:
|
||||||
|
//
|
||||||
|
// Definition example:
|
||||||
|
//
|
||||||
|
// DEFINE_int32(length, 0, "length");
|
||||||
|
//
|
||||||
|
// This defines variable FLAGS_length, initialized to 0.
|
||||||
|
//
|
||||||
|
// Declaration example:
|
||||||
|
//
|
||||||
|
// DECLARE_int32(length);
|
||||||
|
//
|
||||||
|
// SET_FLAGS() can be used to set flags from the command line
|
||||||
|
// using, for example, '--length=2'.
|
||||||
|
//
|
||||||
|
// ShowUsage() can be used to print out command and flag usage.
|
||||||
|
|
||||||
|
// #define DECLARE_bool(name) extern bool FLAGS_ ## name
|
||||||
|
// #define DECLARE_string(name) extern string FLAGS_ ## name
|
||||||
|
// #define DECLARE_int32(name) extern int32 FLAGS_ ## name
|
||||||
|
// #define DECLARE_int64(name) extern int64 FLAGS_ ## name
|
||||||
|
// #define DECLARE_double(name) extern double FLAGS_ ## name
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct FlagDescription {
|
||||||
|
FlagDescription(T *addr, const char *doc, const char *type,
|
||||||
|
const char *file, const T val)
|
||||||
|
: address(addr),
|
||||||
|
doc_string(doc),
|
||||||
|
type_name(type),
|
||||||
|
file_name(file),
|
||||||
|
default_value(val) {}
|
||||||
|
|
||||||
|
T *address;
|
||||||
|
const char *doc_string;
|
||||||
|
const char *type_name;
|
||||||
|
const char *file_name;
|
||||||
|
const T default_value;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class FlagRegister {
|
||||||
|
public:
|
||||||
|
static FlagRegister<T> *GetRegister() {
|
||||||
|
static auto reg = new FlagRegister<T>;
|
||||||
|
return reg;
|
||||||
|
}
|
||||||
|
|
||||||
|
const FlagDescription<T> &GetFlagDescription(const string &name) const {
|
||||||
|
fst::MutexLock l(&flag_lock_);
|
||||||
|
auto it = flag_table_.find(name);
|
||||||
|
return it != flag_table_.end() ? it->second : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetDescription(const string &name,
|
||||||
|
const FlagDescription<T> &desc) {
|
||||||
|
fst::MutexLock l(&flag_lock_);
|
||||||
|
flag_table_.insert(make_pair(name, desc));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SetFlag(const string &val, bool *address) const {
|
||||||
|
if (val == "true" || val == "1" || val.empty()) {
|
||||||
|
*address = true;
|
||||||
|
return true;
|
||||||
|
} else if (val == "false" || val == "0") {
|
||||||
|
*address = false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SetFlag(const string &val, string *address) const {
|
||||||
|
*address = val;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SetFlag(const string &val, int32 *address) const {
|
||||||
|
char *p = 0;
|
||||||
|
*address = strtol(val.c_str(), &p, 0);
|
||||||
|
return !val.empty() && *p == '\0';
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SetFlag(const string &val, int64 *address) const {
|
||||||
|
char *p = 0;
|
||||||
|
*address = strtoll(val.c_str(), &p, 0);
|
||||||
|
return !val.empty() && *p == '\0';
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SetFlag(const string &val, double *address) const {
|
||||||
|
char *p = 0;
|
||||||
|
*address = strtod(val.c_str(), &p);
|
||||||
|
return !val.empty() && *p == '\0';
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SetFlag(const string &arg, const string &val) const {
|
||||||
|
for (typename std::map< string, FlagDescription<T> >::const_iterator it =
|
||||||
|
flag_table_.begin();
|
||||||
|
it != flag_table_.end();
|
||||||
|
++it) {
|
||||||
|
const string &name = it->first;
|
||||||
|
const FlagDescription<T> &desc = it->second;
|
||||||
|
if (arg == name)
|
||||||
|
return SetFlag(val, desc.address);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GetUsage(std::set<std::pair<string, string>> *usage_set) const {
|
||||||
|
for (auto it = flag_table_.begin(); it != flag_table_.end(); ++it) {
|
||||||
|
const string &name = it->first;
|
||||||
|
const FlagDescription<T> &desc = it->second;
|
||||||
|
string usage = " --" + name;
|
||||||
|
usage += ": type = ";
|
||||||
|
usage += desc.type_name;
|
||||||
|
usage += ", default = ";
|
||||||
|
usage += GetDefault(desc.default_value) + "\n ";
|
||||||
|
usage += desc.doc_string;
|
||||||
|
usage_set->insert(make_pair(desc.file_name, usage));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
string GetDefault(bool default_value) const {
|
||||||
|
return default_value ? "true" : "false";
|
||||||
|
}
|
||||||
|
|
||||||
|
string GetDefault(const string &default_value) const {
|
||||||
|
return "\"" + default_value + "\"";
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class V>
|
||||||
|
string GetDefault(const V &default_value) const {
|
||||||
|
std::ostringstream strm;
|
||||||
|
strm << default_value;
|
||||||
|
return strm.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
mutable fst::Mutex flag_lock_; // Multithreading lock.
|
||||||
|
std::map<string, FlagDescription<T>> flag_table_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class FlagRegisterer {
|
||||||
|
public:
|
||||||
|
FlagRegisterer(const string &name, const FlagDescription<T> &desc) {
|
||||||
|
auto registr = FlagRegister<T>::GetRegister();
|
||||||
|
registr->SetDescription(name, desc);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
FlagRegisterer(const FlagRegisterer &) = delete;
|
||||||
|
FlagRegisterer &operator=(const FlagRegisterer &) = delete;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#define DEFINE_VAR(type, name, value, doc) \
|
||||||
|
type FLAGS_ ## name = value; \
|
||||||
|
static FlagRegisterer<type> \
|
||||||
|
name ## _flags_registerer(#name, FlagDescription<type>(&FLAGS_ ## name, \
|
||||||
|
doc, \
|
||||||
|
#type, \
|
||||||
|
__FILE__, \
|
||||||
|
value))
|
||||||
|
|
||||||
|
// #define DEFINE_bool(name, value, doc) DEFINE_VAR(bool, name, value, doc)
|
||||||
|
// #define DEFINE_string(name, value, doc) \
|
||||||
|
// DEFINE_VAR(string, name, value, doc)
|
||||||
|
// #define DEFINE_int32(name, value, doc) DEFINE_VAR(int32, name, value, doc)
|
||||||
|
// #define DEFINE_int64(name, value, doc) DEFINE_VAR(int64, name, value, doc)
|
||||||
|
// #define DEFINE_double(name, value, doc) DEFINE_VAR(double, name, value, doc)
|
||||||
|
|
||||||
|
|
||||||
|
// Temporary directory.
|
||||||
|
DECLARE_string(tmpdir);
|
||||||
|
|
||||||
|
void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags,
|
||||||
|
const char *src = "");
|
||||||
|
|
||||||
|
#define SET_FLAGS(usage, argc, argv, rmflags) \
|
||||||
|
gflags::ParseCommandLineFlags(argc, argv, true)
|
||||||
|
// SetFlags(usage, argc, argv, rmflags, __FILE__)
|
||||||
|
|
||||||
|
// Deprecated; for backward compatibility.
|
||||||
|
inline void InitFst(const char *usage, int *argc, char ***argv, bool rmflags) {
|
||||||
|
return SetFlags(usage, argc, argv, rmflags);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ShowUsage(bool long_usage = true);
|
||||||
|
|
||||||
|
#endif // FST_LIB_FLAGS_H_
|
@ -0,0 +1,82 @@
|
|||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// See www.openfst.org for extensive documentation on this weighted
|
||||||
|
// finite-state transducer library.
|
||||||
|
//
|
||||||
|
// Google-style logging declarations and inline definitions.
|
||||||
|
|
||||||
|
#ifndef FST_LIB_LOG_H_
|
||||||
|
#define FST_LIB_LOG_H_
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include <fst/types.h>
|
||||||
|
#include <fst/flags.h>
|
||||||
|
|
||||||
|
using std::string;
|
||||||
|
|
||||||
|
DECLARE_int32(v);
|
||||||
|
|
||||||
|
class LogMessage {
|
||||||
|
public:
|
||||||
|
LogMessage(const string &type) : fatal_(type == "FATAL") {
|
||||||
|
std::cerr << type << ": ";
|
||||||
|
}
|
||||||
|
~LogMessage() {
|
||||||
|
std::cerr << std::endl;
|
||||||
|
if(fatal_)
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
std::ostream &stream() { return std::cerr; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool fatal_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// #define LOG(type) LogMessage(#type).stream()
|
||||||
|
// #define VLOG(level) if ((level) <= FLAGS_v) LOG(INFO)
|
||||||
|
|
||||||
|
// Checks
|
||||||
|
inline void FstCheck(bool x, const char* expr,
|
||||||
|
const char *file, int line) {
|
||||||
|
if (!x) {
|
||||||
|
LOG(FATAL) << "Check failed: \"" << expr
|
||||||
|
<< "\" file: " << file
|
||||||
|
<< " line: " << line;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// #define CHECK(x) FstCheck(static_cast<bool>(x), #x, __FILE__, __LINE__)
|
||||||
|
// #define CHECK_EQ(x, y) CHECK((x) == (y))
|
||||||
|
// #define CHECK_LT(x, y) CHECK((x) < (y))
|
||||||
|
// #define CHECK_GT(x, y) CHECK((x) > (y))
|
||||||
|
// #define CHECK_LE(x, y) CHECK((x) <= (y))
|
||||||
|
// #define CHECK_GE(x, y) CHECK((x) >= (y))
|
||||||
|
// #define CHECK_NE(x, y) CHECK((x) != (y))
|
||||||
|
|
||||||
|
// Debug checks
|
||||||
|
// #define DCHECK(x) assert(x)
|
||||||
|
// #define DCHECK_EQ(x, y) DCHECK((x) == (y))
|
||||||
|
// #define DCHECK_LT(x, y) DCHECK((x) < (y))
|
||||||
|
// #define DCHECK_GT(x, y) DCHECK((x) > (y))
|
||||||
|
// #define DCHECK_LE(x, y) DCHECK((x) <= (y))
|
||||||
|
// #define DCHECK_GE(x, y) DCHECK((x) >= (y))
|
||||||
|
// #define DCHECK_NE(x, y) DCHECK((x) != (y))
|
||||||
|
|
||||||
|
|
||||||
|
// Ports
|
||||||
|
#define ATTRIBUTE_DEPRECATED __attribute__((deprecated))
|
||||||
|
|
||||||
|
#endif // FST_LIB_LOG_H_
|
@ -0,0 +1,166 @@
|
|||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// Google-style flag handling definitions.
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
|
#if _MSC_VER
|
||||||
|
#include <io.h>
|
||||||
|
#include <fcntl.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <fst/compat.h>
|
||||||
|
#include <fst/flags.h>
|
||||||
|
|
||||||
|
static const char *private_tmpdir = getenv("TMPDIR");
|
||||||
|
|
||||||
|
// DEFINE_int32(v, 0, "verbosity level");
|
||||||
|
// DEFINE_bool(help, false, "show usage information");
|
||||||
|
// DEFINE_bool(helpshort, false, "show brief usage information");
|
||||||
|
#ifndef _MSC_VER
|
||||||
|
DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : "/tmp",
|
||||||
|
"temporary directory");
|
||||||
|
#else
|
||||||
|
DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : getenv("TEMP"),
|
||||||
|
"temporary directory");
|
||||||
|
#endif // !_MSC_VER
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
static string flag_usage;
|
||||||
|
static string prog_src;
|
||||||
|
|
||||||
|
// Sets prog_src to src.
|
||||||
|
static void SetProgSrc(const char *src) {
|
||||||
|
prog_src = src;
|
||||||
|
#if _MSC_VER
|
||||||
|
// This common code is invoked by all FST binaries, and only by them. Switch
|
||||||
|
// stdin and stdout into "binary" mode, so that 0x0A won't be translated into
|
||||||
|
// a 0x0D 0x0A byte pair in a pipe or a shell redirect. Other streams are
|
||||||
|
// already using ios::binary where binary files are read or written.
|
||||||
|
// Kudos to @daanzu for the suggested fix.
|
||||||
|
// https://github.com/kkm000/openfst/issues/20
|
||||||
|
// https://github.com/kkm000/openfst/pull/23
|
||||||
|
// https://github.com/kkm000/openfst/pull/32
|
||||||
|
_setmode(_fileno(stdin), O_BINARY);
|
||||||
|
_setmode(_fileno(stdout), O_BINARY);
|
||||||
|
#endif
|
||||||
|
// Remove "-main" in src filename. Flags are defined in fstx.cc but SetFlags()
|
||||||
|
// is called in fstx-main.cc, which results in a filename mismatch in
|
||||||
|
// ShowUsageRestrict() below.
|
||||||
|
static constexpr char kMainSuffix[] = "-main.cc";
|
||||||
|
const int prefix_length = prog_src.size() - strlen(kMainSuffix);
|
||||||
|
if (prefix_length > 0 && prog_src.substr(prefix_length) == kMainSuffix) {
|
||||||
|
prog_src.erase(prefix_length, strlen("-main"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetFlags(const char *usage, int *argc, char ***argv,
|
||||||
|
bool remove_flags, const char *src) {
|
||||||
|
flag_usage = usage;
|
||||||
|
SetProgSrc(src);
|
||||||
|
|
||||||
|
int index = 1;
|
||||||
|
for (; index < *argc; ++index) {
|
||||||
|
string argval = (*argv)[index];
|
||||||
|
if (argval[0] != '-' || argval == "-") break;
|
||||||
|
while (argval[0] == '-') argval = argval.substr(1); // Removes initial '-'.
|
||||||
|
string arg = argval;
|
||||||
|
string val = "";
|
||||||
|
// Splits argval (arg=val) into arg and val.
|
||||||
|
auto pos = argval.find("=");
|
||||||
|
if (pos != string::npos) {
|
||||||
|
arg = argval.substr(0, pos);
|
||||||
|
val = argval.substr(pos + 1);
|
||||||
|
}
|
||||||
|
auto bool_register = FlagRegister<bool>::GetRegister();
|
||||||
|
if (bool_register->SetFlag(arg, val))
|
||||||
|
continue;
|
||||||
|
auto string_register = FlagRegister<string>::GetRegister();
|
||||||
|
if (string_register->SetFlag(arg, val))
|
||||||
|
continue;
|
||||||
|
auto int32_register = FlagRegister<int32>::GetRegister();
|
||||||
|
if (int32_register->SetFlag(arg, val))
|
||||||
|
continue;
|
||||||
|
auto int64_register = FlagRegister<int64>::GetRegister();
|
||||||
|
if (int64_register->SetFlag(arg, val))
|
||||||
|
continue;
|
||||||
|
auto double_register = FlagRegister<double>::GetRegister();
|
||||||
|
if (double_register->SetFlag(arg, val))
|
||||||
|
continue;
|
||||||
|
LOG(FATAL) << "SetFlags: Bad option: " << (*argv)[index];
|
||||||
|
}
|
||||||
|
if (remove_flags) {
|
||||||
|
for (auto i = 0; i < *argc - index; ++i) {
|
||||||
|
(*argv)[i + 1] = (*argv)[i + index];
|
||||||
|
}
|
||||||
|
*argc -= index - 1;
|
||||||
|
}
|
||||||
|
// if (FLAGS_help) {
|
||||||
|
// ShowUsage(true);
|
||||||
|
// exit(1);
|
||||||
|
// }
|
||||||
|
// if (FLAGS_helpshort) {
|
||||||
|
// ShowUsage(false);
|
||||||
|
// exit(1);
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
// If flag is defined in file 'src' and 'in_src' true or is not
|
||||||
|
// defined in file 'src' and 'in_src' is false, then print usage.
|
||||||
|
static void
|
||||||
|
ShowUsageRestrict(const std::set<pair<string, string>> &usage_set,
|
||||||
|
const string &src, bool in_src, bool show_file) {
|
||||||
|
string old_file;
|
||||||
|
bool file_out = false;
|
||||||
|
bool usage_out = false;
|
||||||
|
for (const auto &pair : usage_set) {
|
||||||
|
const auto &file = pair.first;
|
||||||
|
const auto &usage = pair.second;
|
||||||
|
bool match = file == src;
|
||||||
|
if ((match && !in_src) || (!match && in_src)) continue;
|
||||||
|
if (file != old_file) {
|
||||||
|
if (show_file) {
|
||||||
|
if (file_out) cout << "\n";
|
||||||
|
cout << "Flags from: " << file << "\n";
|
||||||
|
file_out = true;
|
||||||
|
}
|
||||||
|
old_file = file;
|
||||||
|
}
|
||||||
|
cout << usage << "\n";
|
||||||
|
usage_out = true;
|
||||||
|
}
|
||||||
|
if (usage_out) cout << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void ShowUsage(bool long_usage) {
|
||||||
|
std::set<pair<string, string>> usage_set;
|
||||||
|
cout << flag_usage << "\n";
|
||||||
|
auto bool_register = FlagRegister<bool>::GetRegister();
|
||||||
|
bool_register->GetUsage(&usage_set);
|
||||||
|
auto string_register = FlagRegister<string>::GetRegister();
|
||||||
|
string_register->GetUsage(&usage_set);
|
||||||
|
auto int32_register = FlagRegister<int32>::GetRegister();
|
||||||
|
int32_register->GetUsage(&usage_set);
|
||||||
|
auto int64_register = FlagRegister<int64>::GetRegister();
|
||||||
|
int64_register->GetUsage(&usage_set);
|
||||||
|
auto double_register = FlagRegister<double>::GetRegister();
|
||||||
|
double_register->GetUsage(&usage_set);
|
||||||
|
if (!prog_src.empty()) {
|
||||||
|
cout << "PROGRAM FLAGS:\n\n";
|
||||||
|
ShowUsageRestrict(usage_set, prog_src, true, false);
|
||||||
|
}
|
||||||
|
if (!long_usage) return;
|
||||||
|
if (!prog_src.empty()) cout << "LIBRARY FLAGS:\n\n";
|
||||||
|
ShowUsageRestrict(usage_set, prog_src, false, true);
|
||||||
|
}
|
@ -0,0 +1,134 @@
|
|||||||
|
#include "paddle_inference_api.h"
|
||||||
|
#include <gflags/gflags.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <thread>
|
||||||
|
#include <fstream>
|
||||||
|
#include <iterator>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <numeric>
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
void produce_data(std::vector<std::vector<float>>* data);
|
||||||
|
void model_forward_test();
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
|
model_forward_test();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void model_forward_test() {
|
||||||
|
std::cout << "1. read the data" << std::endl;
|
||||||
|
std::vector<std::vector<float>> feats;
|
||||||
|
produce_data(&feats);
|
||||||
|
|
||||||
|
std::cout << "2. load the model" << std::endl;;
|
||||||
|
std::string model_graph = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel";
|
||||||
|
std::string model_params = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams";
|
||||||
|
paddle_infer::Config config;
|
||||||
|
config.SetModel(model_graph, model_params);
|
||||||
|
config.SwitchIrOptim(false);
|
||||||
|
config.DisableFCPadding();
|
||||||
|
auto predictor = paddle_infer::CreatePredictor(config);
|
||||||
|
|
||||||
|
std::cout << "3. feat shape, row=" << feats.size() << ",col=" << feats[0].size() << std::endl;
|
||||||
|
std::vector<float> paddle_input_feature_matrix;
|
||||||
|
for(const auto& item : feats) {
|
||||||
|
paddle_input_feature_matrix.insert(paddle_input_feature_matrix.end(), item.begin(), item.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << "4. fead the data to model" << std::endl;
|
||||||
|
int row = feats.size();
|
||||||
|
int col = feats[0].size();
|
||||||
|
std::vector<std::string> input_names = predictor->GetInputNames();
|
||||||
|
std::vector<std::string> output_names = predictor->GetOutputNames();
|
||||||
|
|
||||||
|
std::unique_ptr<paddle_infer::Tensor> input_tensor =
|
||||||
|
predictor->GetInputHandle(input_names[0]);
|
||||||
|
std::vector<int> INPUT_SHAPE = {1, row, col};
|
||||||
|
input_tensor->Reshape(INPUT_SHAPE);
|
||||||
|
input_tensor->CopyFromCpu(paddle_input_feature_matrix.data());
|
||||||
|
|
||||||
|
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]);
|
||||||
|
std::vector<int> input_len_size = {1};
|
||||||
|
input_len->Reshape(input_len_size);
|
||||||
|
std::vector<int64_t> audio_len;
|
||||||
|
audio_len.push_back(row);
|
||||||
|
input_len->CopyFromCpu(audio_len.data());
|
||||||
|
|
||||||
|
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box = predictor->GetInputHandle(input_names[2]);
|
||||||
|
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
|
||||||
|
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
|
||||||
|
int chunk_state_h_box_size = std::accumulate(chunk_state_h_box_shape.begin(), chunk_state_h_box_shape.end(),
|
||||||
|
1, std::multiplies<int>());
|
||||||
|
std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
|
||||||
|
chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
|
||||||
|
|
||||||
|
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box = predictor->GetInputHandle(input_names[3]);
|
||||||
|
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
|
||||||
|
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
|
||||||
|
int chunk_state_c_box_size = std::accumulate(chunk_state_c_box_shape.begin(), chunk_state_c_box_shape.end(),
|
||||||
|
1, std::multiplies<int>());
|
||||||
|
std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
|
||||||
|
chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
|
||||||
|
|
||||||
|
bool success = predictor->Run();
|
||||||
|
|
||||||
|
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
|
||||||
|
std::vector<int> h_out_shape = h_out->shape();
|
||||||
|
int h_out_size = std::accumulate(h_out_shape.begin(), h_out_shape.end(),
|
||||||
|
1, std::multiplies<int>());
|
||||||
|
std::vector<float> h_out_data(h_out_size);
|
||||||
|
h_out->CopyToCpu(h_out_data.data());
|
||||||
|
|
||||||
|
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
|
||||||
|
std::vector<int> c_out_shape = c_out->shape();
|
||||||
|
int c_out_size = std::accumulate(c_out_shape.begin(), c_out_shape.end(),
|
||||||
|
1, std::multiplies<int>());
|
||||||
|
std::vector<float> c_out_data(c_out_size);
|
||||||
|
c_out->CopyToCpu(c_out_data.data());
|
||||||
|
|
||||||
|
std::unique_ptr<paddle_infer::Tensor> output_tensor =
|
||||||
|
predictor->GetOutputHandle(output_names[0]);
|
||||||
|
std::vector<int> output_shape = output_tensor->shape();
|
||||||
|
std::vector<float> output_probs;
|
||||||
|
int output_size = std::accumulate(output_shape.begin(), output_shape.end(),
|
||||||
|
1, std::multiplies<int>());
|
||||||
|
output_probs.resize(output_size);
|
||||||
|
output_tensor->CopyToCpu(output_probs.data());
|
||||||
|
row = output_shape[1];
|
||||||
|
col = output_shape[2];
|
||||||
|
|
||||||
|
std::vector<std::vector<float>> probs;
|
||||||
|
probs.reserve(row);
|
||||||
|
for (int i = 0; i < row; i++) {
|
||||||
|
probs.push_back(std::vector<float>());
|
||||||
|
probs.back().reserve(col);
|
||||||
|
|
||||||
|
for (int j = 0; j < col; j++) {
|
||||||
|
probs.back().push_back(output_probs[i * col + j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<float>> log_feat = probs;
|
||||||
|
std::cout << "probs, row: " << log_feat.size() << " col: " << log_feat[0].size() << std::endl;
|
||||||
|
for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
|
||||||
|
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size(); ++col_idx) {
|
||||||
|
std::cout << log_feat[row_idx][col_idx] << " ";
|
||||||
|
}
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void produce_data(std::vector<std::vector<float>>* data) {
|
||||||
|
int chunk_size = 35;
|
||||||
|
int col_size = 161;
|
||||||
|
data->reserve(chunk_size);
|
||||||
|
data->back().reserve(col_size);
|
||||||
|
for (int row = 0; row < chunk_size; ++row) {
|
||||||
|
data->push_back(std::vector<float>());
|
||||||
|
for (int col_idx = 0; col_idx < col_size; ++col_idx) {
|
||||||
|
data->back().push_back(0.201);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,10 +1,10 @@
|
|||||||
project(decoder)
|
project(decoder)
|
||||||
|
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
|
||||||
add_library(decoder
|
add_library(decoder STATIC
|
||||||
ctc_beam_search_decoder.cc
|
ctc_beam_search_decoder.cc
|
||||||
ctc_decoders/decoder_utils.cpp
|
ctc_decoders/decoder_utils.cpp
|
||||||
ctc_decoders/path_trie.cpp
|
ctc_decoders/path_trie.cpp
|
||||||
ctc_decoders/scorer.cpp
|
ctc_decoders/scorer.cpp
|
||||||
)
|
)
|
||||||
target_link_libraries(decoder kenlm)
|
target_link_libraries(decoder PUBLIC kenlm utils fst)
|
@ -1,8 +1,8 @@
|
|||||||
project(frontend)
|
project(frontend)
|
||||||
|
|
||||||
add_library(frontend
|
add_library(frontend STATIC
|
||||||
normalizer.cc
|
normalizer.cc
|
||||||
linear_spectrogram.cc
|
linear_spectrogram.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
target_link_libraries(frontend kaldi-matrix)
|
target_link_libraries(frontend PUBLIC kaldi-matrix)
|
@ -1,39 +0,0 @@
|
|||||||
# Copyright (c) 2020 PeachLab. All Rights Reserved.
|
|
||||||
# Author : goat.zhou@qq.com (Yang Zhou)
|
|
||||||
|
|
||||||
package(default_visibility = ["//visibility:public"])
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = 'kaldi-matrix',
|
|
||||||
srcs = [
|
|
||||||
'compressed-matrix.cc',
|
|
||||||
'kaldi-matrix.cc',
|
|
||||||
'kaldi-vector.cc',
|
|
||||||
'matrix-functions.cc',
|
|
||||||
'optimization.cc',
|
|
||||||
'packed-matrix.cc',
|
|
||||||
'qr.cc',
|
|
||||||
'sparse-matrix.cc',
|
|
||||||
'sp-matrix.cc',
|
|
||||||
'srfft.cc',
|
|
||||||
'tp-matrix.cc',
|
|
||||||
],
|
|
||||||
hdrs = glob(["*.h"]),
|
|
||||||
deps = [
|
|
||||||
'//base:kaldi-base',
|
|
||||||
'//common/third_party/openblas:openblas',
|
|
||||||
],
|
|
||||||
linkopts=['-lgfortran'],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_binary(
|
|
||||||
name = 'matrix-lib-test',
|
|
||||||
srcs = [
|
|
||||||
'matrix-lib-test.cc',
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
':kaldi-matrix',
|
|
||||||
'//util:kaldi-util',
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
@ -1,2 +1,7 @@
|
|||||||
aux_source_directory(. DIR_LIB_SRCS)
|
project(nnet)
|
||||||
add_library(nnet STATIC ${DIR_LIB_SRCS})
|
|
||||||
|
add_library(nnet STATIC
|
||||||
|
decodable.cc
|
||||||
|
paddle_nnet.cc
|
||||||
|
)
|
||||||
|
target_link_libraries(nnet absl::strings)
|
Loading…
Reference in new issue