parent
b584b9690f
commit
406b4fc7d4
@ -0,0 +1 @@
|
||||
exclude_files=.*
|
@ -0,0 +1,228 @@
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// See www.openfst.org for extensive documentation on this weighted
|
||||
// finite-state transducer library.
|
||||
//
|
||||
// Google-style flag handling declarations and inline definitions.
|
||||
|
||||
#ifndef FST_LIB_FLAGS_H_
|
||||
#define FST_LIB_FLAGS_H_
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include <fst/types.h>
|
||||
#include <fst/lock.h>
|
||||
|
||||
#include "gflags/gflags.h"
|
||||
#include "glog/logging.h"
|
||||
|
||||
using std::string;
|
||||
|
||||
// FLAGS USAGE:
|
||||
//
|
||||
// Definition example:
|
||||
//
|
||||
// DEFINE_int32(length, 0, "length");
|
||||
//
|
||||
// This defines variable FLAGS_length, initialized to 0.
|
||||
//
|
||||
// Declaration example:
|
||||
//
|
||||
// DECLARE_int32(length);
|
||||
//
|
||||
// SET_FLAGS() can be used to set flags from the command line
|
||||
// using, for example, '--length=2'.
|
||||
//
|
||||
// ShowUsage() can be used to print out command and flag usage.
|
||||
|
||||
// #define DECLARE_bool(name) extern bool FLAGS_ ## name
|
||||
// #define DECLARE_string(name) extern string FLAGS_ ## name
|
||||
// #define DECLARE_int32(name) extern int32 FLAGS_ ## name
|
||||
// #define DECLARE_int64(name) extern int64 FLAGS_ ## name
|
||||
// #define DECLARE_double(name) extern double FLAGS_ ## name
|
||||
|
||||
template <typename T>
|
||||
struct FlagDescription {
|
||||
FlagDescription(T *addr, const char *doc, const char *type,
|
||||
const char *file, const T val)
|
||||
: address(addr),
|
||||
doc_string(doc),
|
||||
type_name(type),
|
||||
file_name(file),
|
||||
default_value(val) {}
|
||||
|
||||
T *address;
|
||||
const char *doc_string;
|
||||
const char *type_name;
|
||||
const char *file_name;
|
||||
const T default_value;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class FlagRegister {
|
||||
public:
|
||||
static FlagRegister<T> *GetRegister() {
|
||||
static auto reg = new FlagRegister<T>;
|
||||
return reg;
|
||||
}
|
||||
|
||||
const FlagDescription<T> &GetFlagDescription(const string &name) const {
|
||||
fst::MutexLock l(&flag_lock_);
|
||||
auto it = flag_table_.find(name);
|
||||
return it != flag_table_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
void SetDescription(const string &name,
|
||||
const FlagDescription<T> &desc) {
|
||||
fst::MutexLock l(&flag_lock_);
|
||||
flag_table_.insert(make_pair(name, desc));
|
||||
}
|
||||
|
||||
bool SetFlag(const string &val, bool *address) const {
|
||||
if (val == "true" || val == "1" || val.empty()) {
|
||||
*address = true;
|
||||
return true;
|
||||
} else if (val == "false" || val == "0") {
|
||||
*address = false;
|
||||
return true;
|
||||
}
|
||||
else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool SetFlag(const string &val, string *address) const {
|
||||
*address = val;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SetFlag(const string &val, int32 *address) const {
|
||||
char *p = 0;
|
||||
*address = strtol(val.c_str(), &p, 0);
|
||||
return !val.empty() && *p == '\0';
|
||||
}
|
||||
|
||||
bool SetFlag(const string &val, int64 *address) const {
|
||||
char *p = 0;
|
||||
*address = strtoll(val.c_str(), &p, 0);
|
||||
return !val.empty() && *p == '\0';
|
||||
}
|
||||
|
||||
bool SetFlag(const string &val, double *address) const {
|
||||
char *p = 0;
|
||||
*address = strtod(val.c_str(), &p);
|
||||
return !val.empty() && *p == '\0';
|
||||
}
|
||||
|
||||
bool SetFlag(const string &arg, const string &val) const {
|
||||
for (typename std::map< string, FlagDescription<T> >::const_iterator it =
|
||||
flag_table_.begin();
|
||||
it != flag_table_.end();
|
||||
++it) {
|
||||
const string &name = it->first;
|
||||
const FlagDescription<T> &desc = it->second;
|
||||
if (arg == name)
|
||||
return SetFlag(val, desc.address);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void GetUsage(std::set<std::pair<string, string>> *usage_set) const {
|
||||
for (auto it = flag_table_.begin(); it != flag_table_.end(); ++it) {
|
||||
const string &name = it->first;
|
||||
const FlagDescription<T> &desc = it->second;
|
||||
string usage = " --" + name;
|
||||
usage += ": type = ";
|
||||
usage += desc.type_name;
|
||||
usage += ", default = ";
|
||||
usage += GetDefault(desc.default_value) + "\n ";
|
||||
usage += desc.doc_string;
|
||||
usage_set->insert(make_pair(desc.file_name, usage));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
string GetDefault(bool default_value) const {
|
||||
return default_value ? "true" : "false";
|
||||
}
|
||||
|
||||
string GetDefault(const string &default_value) const {
|
||||
return "\"" + default_value + "\"";
|
||||
}
|
||||
|
||||
template <class V>
|
||||
string GetDefault(const V &default_value) const {
|
||||
std::ostringstream strm;
|
||||
strm << default_value;
|
||||
return strm.str();
|
||||
}
|
||||
|
||||
mutable fst::Mutex flag_lock_; // Multithreading lock.
|
||||
std::map<string, FlagDescription<T>> flag_table_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class FlagRegisterer {
|
||||
public:
|
||||
FlagRegisterer(const string &name, const FlagDescription<T> &desc) {
|
||||
auto registr = FlagRegister<T>::GetRegister();
|
||||
registr->SetDescription(name, desc);
|
||||
}
|
||||
|
||||
private:
|
||||
FlagRegisterer(const FlagRegisterer &) = delete;
|
||||
FlagRegisterer &operator=(const FlagRegisterer &) = delete;
|
||||
};
|
||||
|
||||
|
||||
#define DEFINE_VAR(type, name, value, doc) \
|
||||
type FLAGS_ ## name = value; \
|
||||
static FlagRegisterer<type> \
|
||||
name ## _flags_registerer(#name, FlagDescription<type>(&FLAGS_ ## name, \
|
||||
doc, \
|
||||
#type, \
|
||||
__FILE__, \
|
||||
value))
|
||||
|
||||
// #define DEFINE_bool(name, value, doc) DEFINE_VAR(bool, name, value, doc)
|
||||
// #define DEFINE_string(name, value, doc) \
|
||||
// DEFINE_VAR(string, name, value, doc)
|
||||
// #define DEFINE_int32(name, value, doc) DEFINE_VAR(int32, name, value, doc)
|
||||
// #define DEFINE_int64(name, value, doc) DEFINE_VAR(int64, name, value, doc)
|
||||
// #define DEFINE_double(name, value, doc) DEFINE_VAR(double, name, value, doc)
|
||||
|
||||
|
||||
// Temporary directory.
|
||||
DECLARE_string(tmpdir);
|
||||
|
||||
void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags,
|
||||
const char *src = "");
|
||||
|
||||
#define SET_FLAGS(usage, argc, argv, rmflags) \
|
||||
gflags::ParseCommandLineFlags(argc, argv, true)
|
||||
// SetFlags(usage, argc, argv, rmflags, __FILE__)
|
||||
|
||||
// Deprecated; for backward compatibility.
|
||||
inline void InitFst(const char *usage, int *argc, char ***argv, bool rmflags) {
|
||||
return SetFlags(usage, argc, argv, rmflags);
|
||||
}
|
||||
|
||||
void ShowUsage(bool long_usage = true);
|
||||
|
||||
#endif // FST_LIB_FLAGS_H_
|
@ -0,0 +1,82 @@
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// See www.openfst.org for extensive documentation on this weighted
|
||||
// finite-state transducer library.
|
||||
//
|
||||
// Google-style logging declarations and inline definitions.
|
||||
|
||||
#ifndef FST_LIB_LOG_H_
|
||||
#define FST_LIB_LOG_H_
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include <fst/types.h>
|
||||
#include <fst/flags.h>
|
||||
|
||||
using std::string;
|
||||
|
||||
DECLARE_int32(v);
|
||||
|
||||
class LogMessage {
|
||||
public:
|
||||
LogMessage(const string &type) : fatal_(type == "FATAL") {
|
||||
std::cerr << type << ": ";
|
||||
}
|
||||
~LogMessage() {
|
||||
std::cerr << std::endl;
|
||||
if(fatal_)
|
||||
exit(1);
|
||||
}
|
||||
std::ostream &stream() { return std::cerr; }
|
||||
|
||||
private:
|
||||
bool fatal_;
|
||||
};
|
||||
|
||||
// #define LOG(type) LogMessage(#type).stream()
|
||||
// #define VLOG(level) if ((level) <= FLAGS_v) LOG(INFO)
|
||||
|
||||
// Checks
|
||||
inline void FstCheck(bool x, const char* expr,
|
||||
const char *file, int line) {
|
||||
if (!x) {
|
||||
LOG(FATAL) << "Check failed: \"" << expr
|
||||
<< "\" file: " << file
|
||||
<< " line: " << line;
|
||||
}
|
||||
}
|
||||
|
||||
// #define CHECK(x) FstCheck(static_cast<bool>(x), #x, __FILE__, __LINE__)
|
||||
// #define CHECK_EQ(x, y) CHECK((x) == (y))
|
||||
// #define CHECK_LT(x, y) CHECK((x) < (y))
|
||||
// #define CHECK_GT(x, y) CHECK((x) > (y))
|
||||
// #define CHECK_LE(x, y) CHECK((x) <= (y))
|
||||
// #define CHECK_GE(x, y) CHECK((x) >= (y))
|
||||
// #define CHECK_NE(x, y) CHECK((x) != (y))
|
||||
|
||||
// Debug checks
|
||||
// #define DCHECK(x) assert(x)
|
||||
// #define DCHECK_EQ(x, y) DCHECK((x) == (y))
|
||||
// #define DCHECK_LT(x, y) DCHECK((x) < (y))
|
||||
// #define DCHECK_GT(x, y) DCHECK((x) > (y))
|
||||
// #define DCHECK_LE(x, y) DCHECK((x) <= (y))
|
||||
// #define DCHECK_GE(x, y) DCHECK((x) >= (y))
|
||||
// #define DCHECK_NE(x, y) DCHECK((x) != (y))
|
||||
|
||||
|
||||
// Ports
|
||||
#define ATTRIBUTE_DEPRECATED __attribute__((deprecated))
|
||||
|
||||
#endif // FST_LIB_LOG_H_
|
@ -0,0 +1,166 @@
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Google-style flag handling definitions.
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#if _MSC_VER
|
||||
#include <io.h>
|
||||
#include <fcntl.h>
|
||||
#endif
|
||||
|
||||
#include <fst/compat.h>
|
||||
#include <fst/flags.h>
|
||||
|
||||
static const char *private_tmpdir = getenv("TMPDIR");
|
||||
|
||||
// DEFINE_int32(v, 0, "verbosity level");
|
||||
// DEFINE_bool(help, false, "show usage information");
|
||||
// DEFINE_bool(helpshort, false, "show brief usage information");
|
||||
#ifndef _MSC_VER
|
||||
DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : "/tmp",
|
||||
"temporary directory");
|
||||
#else
|
||||
DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : getenv("TEMP"),
|
||||
"temporary directory");
|
||||
#endif // !_MSC_VER
|
||||
|
||||
using namespace std;
|
||||
|
||||
static string flag_usage;
|
||||
static string prog_src;
|
||||
|
||||
// Sets prog_src to src.
|
||||
static void SetProgSrc(const char *src) {
|
||||
prog_src = src;
|
||||
#if _MSC_VER
|
||||
// This common code is invoked by all FST binaries, and only by them. Switch
|
||||
// stdin and stdout into "binary" mode, so that 0x0A won't be translated into
|
||||
// a 0x0D 0x0A byte pair in a pipe or a shell redirect. Other streams are
|
||||
// already using ios::binary where binary files are read or written.
|
||||
// Kudos to @daanzu for the suggested fix.
|
||||
// https://github.com/kkm000/openfst/issues/20
|
||||
// https://github.com/kkm000/openfst/pull/23
|
||||
// https://github.com/kkm000/openfst/pull/32
|
||||
_setmode(_fileno(stdin), O_BINARY);
|
||||
_setmode(_fileno(stdout), O_BINARY);
|
||||
#endif
|
||||
// Remove "-main" in src filename. Flags are defined in fstx.cc but SetFlags()
|
||||
// is called in fstx-main.cc, which results in a filename mismatch in
|
||||
// ShowUsageRestrict() below.
|
||||
static constexpr char kMainSuffix[] = "-main.cc";
|
||||
const int prefix_length = prog_src.size() - strlen(kMainSuffix);
|
||||
if (prefix_length > 0 && prog_src.substr(prefix_length) == kMainSuffix) {
|
||||
prog_src.erase(prefix_length, strlen("-main"));
|
||||
}
|
||||
}
|
||||
|
||||
void SetFlags(const char *usage, int *argc, char ***argv,
|
||||
bool remove_flags, const char *src) {
|
||||
flag_usage = usage;
|
||||
SetProgSrc(src);
|
||||
|
||||
int index = 1;
|
||||
for (; index < *argc; ++index) {
|
||||
string argval = (*argv)[index];
|
||||
if (argval[0] != '-' || argval == "-") break;
|
||||
while (argval[0] == '-') argval = argval.substr(1); // Removes initial '-'.
|
||||
string arg = argval;
|
||||
string val = "";
|
||||
// Splits argval (arg=val) into arg and val.
|
||||
auto pos = argval.find("=");
|
||||
if (pos != string::npos) {
|
||||
arg = argval.substr(0, pos);
|
||||
val = argval.substr(pos + 1);
|
||||
}
|
||||
auto bool_register = FlagRegister<bool>::GetRegister();
|
||||
if (bool_register->SetFlag(arg, val))
|
||||
continue;
|
||||
auto string_register = FlagRegister<string>::GetRegister();
|
||||
if (string_register->SetFlag(arg, val))
|
||||
continue;
|
||||
auto int32_register = FlagRegister<int32>::GetRegister();
|
||||
if (int32_register->SetFlag(arg, val))
|
||||
continue;
|
||||
auto int64_register = FlagRegister<int64>::GetRegister();
|
||||
if (int64_register->SetFlag(arg, val))
|
||||
continue;
|
||||
auto double_register = FlagRegister<double>::GetRegister();
|
||||
if (double_register->SetFlag(arg, val))
|
||||
continue;
|
||||
LOG(FATAL) << "SetFlags: Bad option: " << (*argv)[index];
|
||||
}
|
||||
if (remove_flags) {
|
||||
for (auto i = 0; i < *argc - index; ++i) {
|
||||
(*argv)[i + 1] = (*argv)[i + index];
|
||||
}
|
||||
*argc -= index - 1;
|
||||
}
|
||||
// if (FLAGS_help) {
|
||||
// ShowUsage(true);
|
||||
// exit(1);
|
||||
// }
|
||||
// if (FLAGS_helpshort) {
|
||||
// ShowUsage(false);
|
||||
// exit(1);
|
||||
// }
|
||||
}
|
||||
|
||||
// If flag is defined in file 'src' and 'in_src' true or is not
|
||||
// defined in file 'src' and 'in_src' is false, then print usage.
|
||||
static void
|
||||
ShowUsageRestrict(const std::set<pair<string, string>> &usage_set,
|
||||
const string &src, bool in_src, bool show_file) {
|
||||
string old_file;
|
||||
bool file_out = false;
|
||||
bool usage_out = false;
|
||||
for (const auto &pair : usage_set) {
|
||||
const auto &file = pair.first;
|
||||
const auto &usage = pair.second;
|
||||
bool match = file == src;
|
||||
if ((match && !in_src) || (!match && in_src)) continue;
|
||||
if (file != old_file) {
|
||||
if (show_file) {
|
||||
if (file_out) cout << "\n";
|
||||
cout << "Flags from: " << file << "\n";
|
||||
file_out = true;
|
||||
}
|
||||
old_file = file;
|
||||
}
|
||||
cout << usage << "\n";
|
||||
usage_out = true;
|
||||
}
|
||||
if (usage_out) cout << "\n";
|
||||
}
|
||||
|
||||
void ShowUsage(bool long_usage) {
|
||||
std::set<pair<string, string>> usage_set;
|
||||
cout << flag_usage << "\n";
|
||||
auto bool_register = FlagRegister<bool>::GetRegister();
|
||||
bool_register->GetUsage(&usage_set);
|
||||
auto string_register = FlagRegister<string>::GetRegister();
|
||||
string_register->GetUsage(&usage_set);
|
||||
auto int32_register = FlagRegister<int32>::GetRegister();
|
||||
int32_register->GetUsage(&usage_set);
|
||||
auto int64_register = FlagRegister<int64>::GetRegister();
|
||||
int64_register->GetUsage(&usage_set);
|
||||
auto double_register = FlagRegister<double>::GetRegister();
|
||||
double_register->GetUsage(&usage_set);
|
||||
if (!prog_src.empty()) {
|
||||
cout << "PROGRAM FLAGS:\n\n";
|
||||
ShowUsageRestrict(usage_set, prog_src, true, false);
|
||||
}
|
||||
if (!long_usage) return;
|
||||
if (!prog_src.empty()) cout << "LIBRARY FLAGS:\n\n";
|
||||
ShowUsageRestrict(usage_set, prog_src, false, true);
|
||||
}
|
@ -0,0 +1,134 @@
|
||||
#include "paddle_inference_api.h"
|
||||
#include <gflags/gflags.h>
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
|
||||
void produce_data(std::vector<std::vector<float>>* data);
|
||||
void model_forward_test();
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
model_forward_test();
|
||||
return 0;
|
||||
}
|
||||
|
||||
void model_forward_test() {
|
||||
std::cout << "1. read the data" << std::endl;
|
||||
std::vector<std::vector<float>> feats;
|
||||
produce_data(&feats);
|
||||
|
||||
std::cout << "2. load the model" << std::endl;;
|
||||
std::string model_graph = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel";
|
||||
std::string model_params = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams";
|
||||
paddle_infer::Config config;
|
||||
config.SetModel(model_graph, model_params);
|
||||
config.SwitchIrOptim(false);
|
||||
config.DisableFCPadding();
|
||||
auto predictor = paddle_infer::CreatePredictor(config);
|
||||
|
||||
std::cout << "3. feat shape, row=" << feats.size() << ",col=" << feats[0].size() << std::endl;
|
||||
std::vector<float> paddle_input_feature_matrix;
|
||||
for(const auto& item : feats) {
|
||||
paddle_input_feature_matrix.insert(paddle_input_feature_matrix.end(), item.begin(), item.end());
|
||||
}
|
||||
|
||||
std::cout << "4. fead the data to model" << std::endl;
|
||||
int row = feats.size();
|
||||
int col = feats[0].size();
|
||||
std::vector<std::string> input_names = predictor->GetInputNames();
|
||||
std::vector<std::string> output_names = predictor->GetOutputNames();
|
||||
|
||||
std::unique_ptr<paddle_infer::Tensor> input_tensor =
|
||||
predictor->GetInputHandle(input_names[0]);
|
||||
std::vector<int> INPUT_SHAPE = {1, row, col};
|
||||
input_tensor->Reshape(INPUT_SHAPE);
|
||||
input_tensor->CopyFromCpu(paddle_input_feature_matrix.data());
|
||||
|
||||
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]);
|
||||
std::vector<int> input_len_size = {1};
|
||||
input_len->Reshape(input_len_size);
|
||||
std::vector<int64_t> audio_len;
|
||||
audio_len.push_back(row);
|
||||
input_len->CopyFromCpu(audio_len.data());
|
||||
|
||||
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box = predictor->GetInputHandle(input_names[2]);
|
||||
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
|
||||
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
|
||||
int chunk_state_h_box_size = std::accumulate(chunk_state_h_box_shape.begin(), chunk_state_h_box_shape.end(),
|
||||
1, std::multiplies<int>());
|
||||
std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
|
||||
chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
|
||||
|
||||
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box = predictor->GetInputHandle(input_names[3]);
|
||||
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
|
||||
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
|
||||
int chunk_state_c_box_size = std::accumulate(chunk_state_c_box_shape.begin(), chunk_state_c_box_shape.end(),
|
||||
1, std::multiplies<int>());
|
||||
std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
|
||||
chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
|
||||
|
||||
bool success = predictor->Run();
|
||||
|
||||
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
|
||||
std::vector<int> h_out_shape = h_out->shape();
|
||||
int h_out_size = std::accumulate(h_out_shape.begin(), h_out_shape.end(),
|
||||
1, std::multiplies<int>());
|
||||
std::vector<float> h_out_data(h_out_size);
|
||||
h_out->CopyToCpu(h_out_data.data());
|
||||
|
||||
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
|
||||
std::vector<int> c_out_shape = c_out->shape();
|
||||
int c_out_size = std::accumulate(c_out_shape.begin(), c_out_shape.end(),
|
||||
1, std::multiplies<int>());
|
||||
std::vector<float> c_out_data(c_out_size);
|
||||
c_out->CopyToCpu(c_out_data.data());
|
||||
|
||||
std::unique_ptr<paddle_infer::Tensor> output_tensor =
|
||||
predictor->GetOutputHandle(output_names[0]);
|
||||
std::vector<int> output_shape = output_tensor->shape();
|
||||
std::vector<float> output_probs;
|
||||
int output_size = std::accumulate(output_shape.begin(), output_shape.end(),
|
||||
1, std::multiplies<int>());
|
||||
output_probs.resize(output_size);
|
||||
output_tensor->CopyToCpu(output_probs.data());
|
||||
row = output_shape[1];
|
||||
col = output_shape[2];
|
||||
|
||||
std::vector<std::vector<float>> probs;
|
||||
probs.reserve(row);
|
||||
for (int i = 0; i < row; i++) {
|
||||
probs.push_back(std::vector<float>());
|
||||
probs.back().reserve(col);
|
||||
|
||||
for (int j = 0; j < col; j++) {
|
||||
probs.back().push_back(output_probs[i * col + j]);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> log_feat = probs;
|
||||
std::cout << "probs, row: " << log_feat.size() << " col: " << log_feat[0].size() << std::endl;
|
||||
for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
|
||||
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size(); ++col_idx) {
|
||||
std::cout << log_feat[row_idx][col_idx] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void produce_data(std::vector<std::vector<float>>* data) {
|
||||
int chunk_size = 35;
|
||||
int col_size = 161;
|
||||
data->reserve(chunk_size);
|
||||
data->back().reserve(col_size);
|
||||
for (int row = 0; row < chunk_size; ++row) {
|
||||
data->push_back(std::vector<float>());
|
||||
for (int col_idx = 0; col_idx < col_size; ++col_idx) {
|
||||
data->back().push_back(0.201);
|
||||
}
|
||||
}
|
||||
}
|
@ -1,10 +1,10 @@
|
||||
project(decoder)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
|
||||
add_library(decoder
|
||||
add_library(decoder STATIC
|
||||
ctc_beam_search_decoder.cc
|
||||
ctc_decoders/decoder_utils.cpp
|
||||
ctc_decoders/path_trie.cpp
|
||||
ctc_decoders/scorer.cpp
|
||||
)
|
||||
target_link_libraries(decoder kenlm)
|
||||
target_link_libraries(decoder PUBLIC kenlm utils fst)
|
@ -1,8 +1,8 @@
|
||||
project(frontend)
|
||||
|
||||
add_library(frontend
|
||||
add_library(frontend STATIC
|
||||
normalizer.cc
|
||||
linear_spectrogram.cc
|
||||
)
|
||||
|
||||
target_link_libraries(frontend kaldi-matrix)
|
||||
target_link_libraries(frontend PUBLIC kaldi-matrix)
|
@ -1,39 +0,0 @@
|
||||
# Copyright (c) 2020 PeachLab. All Rights Reserved.
|
||||
# Author : goat.zhou@qq.com (Yang Zhou)
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
cc_library(
|
||||
name = 'kaldi-matrix',
|
||||
srcs = [
|
||||
'compressed-matrix.cc',
|
||||
'kaldi-matrix.cc',
|
||||
'kaldi-vector.cc',
|
||||
'matrix-functions.cc',
|
||||
'optimization.cc',
|
||||
'packed-matrix.cc',
|
||||
'qr.cc',
|
||||
'sparse-matrix.cc',
|
||||
'sp-matrix.cc',
|
||||
'srfft.cc',
|
||||
'tp-matrix.cc',
|
||||
],
|
||||
hdrs = glob(["*.h"]),
|
||||
deps = [
|
||||
'//base:kaldi-base',
|
||||
'//common/third_party/openblas:openblas',
|
||||
],
|
||||
linkopts=['-lgfortran'],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = 'matrix-lib-test',
|
||||
srcs = [
|
||||
'matrix-lib-test.cc',
|
||||
],
|
||||
deps = [
|
||||
':kaldi-matrix',
|
||||
'//util:kaldi-util',
|
||||
],
|
||||
)
|
||||
|
@ -1,2 +1,7 @@
|
||||
aux_source_directory(. DIR_LIB_SRCS)
|
||||
add_library(nnet STATIC ${DIR_LIB_SRCS})
|
||||
project(nnet)
|
||||
|
||||
add_library(nnet STATIC
|
||||
decodable.cc
|
||||
paddle_nnet.cc
|
||||
)
|
||||
target_link_libraries(nnet absl::strings)
|
Loading…
Reference in new issue