pull/2524/head
Hui Zhang 2 years ago
parent 8271fcfb0a
commit a6b2a0a697

@@ -50,13 +50,20 @@ repos:
       entry: bash .pre-commit-hooks/clang-format.hook -i
       language: system
       files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
-      exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
+      exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.py)$
 #- id: copyright_checker
 #  name: copyright_checker
 #  entry: python .pre-commit-hooks/copyright-check.hook
 #  language: system
 #  files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
 #  exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
+    - id: cpplint
+      name: cpplint
+      description: Static code analysis of C/C++ files
+      language: python
+      files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$
+      exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.py)$
+      entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent
 - repo: https://github.com/asottile/reorder_python_imports
   rev: v2.4.0
   hooks:
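With this hook registered, the same filter set can be exercised locally in the standard pre-commit way, for example "pre-commit run cpplint --all-files", or by calling "cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent <file>" directly; the exact invocation depends on the local pre-commit install.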

@@ -22,39 +22,39 @@ typedef float BaseFloat;
 typedef double double64;
 typedef signed char int8;
-typedef short int16;
-typedef int int32;
+typedef short int16;  // NOLINT
+typedef int int32;    // NOLINT
 #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
-typedef long int64;
+typedef long int64;  // NOLINT
 #else
-typedef long long int64;
+typedef long long int64;  // NOLINT
 #endif
-typedef unsigned char uint8;
-typedef unsigned short uint16;
-typedef unsigned int uint32;
+typedef unsigned char uint8;    // NOLINT
+typedef unsigned short uint16;  // NOLINT
+typedef unsigned int uint32;    // NOLINT
 #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
-typedef unsigned long uint64;
+typedef unsigned long uint64;  // NOLINT
 #else
-typedef unsigned long long uint64;
+typedef unsigned long long uint64;  // NOLINT
 #endif
 typedef signed int char32;
-const uint8 kuint8max = ((uint8)0xFF);
-const uint16 kuint16max = ((uint16)0xFFFF);
-const uint32 kuint32max = ((uint32)0xFFFFFFFF);
-const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL));
-const int8 kint8min = ((int8)0x80);
-const int8 kint8max = ((int8)0x7F);
-const int16 kint16min = ((int16)0x8000);
-const int16 kint16max = ((int16)0x7FFF);
-const int32 kint32min = ((int32)0x80000000);
-const int32 kint32max = ((int32)0x7FFFFFFF);
-const int64 kint64min = ((int64)(0x8000000000000000LL));
-const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL));
+const uint8 kuint8max = static_cast<uint8>(0xFF);
+const uint16 kuint16max = static_cast<uint16>(0xFFFF);
+const uint32 kuint32max = static_cast<uint32>(0xFFFFFFFF);
+const uint64 kuint64max = static_cast<uint64>(0xFFFFFFFFFFFFFFFFLL);
+const int8 kint8min = static_cast<int8>(0x80);
+const int8 kint8max = static_cast<int8>(0x7F);
+const int16 kint16min = static_cast<int16>(0x8000);
+const int16 kint16max = static_cast<int16>(0x7FFF);
+const int32 kint32min = static_cast<int32>(0x80000000);
+const int32 kint32max = static_cast<int32>(0x7FFFFFFF);
+const int64 kint64min = static_cast<int64>(0x8000000000000000LL);
+const int64 kint64max = static_cast<int64>(0x7FFFFFFFFFFFFFFFLL);
 const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
 const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();
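For readers skimming the cast changes above, here is a minimal, standalone sketch (not part of the patch; type and constant names only mirror the header) of the pattern the diff moves to: fixed-width typedefs marked // NOLINT for cpplint's runtime/int check, and limits built with static_cast<T>(...) instead of C-style casts.

// cast_style_demo.cc -- illustrative only, not a file in the repository.
#include <cstdio>

typedef unsigned short uint16;  // NOLINT
typedef long long int64;        // NOLINT

const uint16 kuint16max = static_cast<uint16>(0xFFFF);
const int64 kint64max = static_cast<int64>(0x7FFFFFFFFFFFFFFFLL);

int main() {
    // static_cast keeps the conversion explicit and easy to grep for,
    // which is what cpplint's readability/casting warning asks for.
    std::printf("kuint16max = %u\n", static_cast<unsigned>(kuint16max));
    std::printf("kint64max  = %lld\n", kint64max);
    return 0;
}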

@@ -26,6 +26,6 @@ namespace ppspeech {
 #endif
 // kSpaceSymbol in UTF-8 is: ▁
-const std::string kSpaceSymbol = "\xe2\x96\x81";
+const char kSpaceSymbol[] = "\xe2\x96\x81";
 }  // namespace ppspeech
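For context on kSpaceSymbol (the sentence-piece word-boundary marker ▁, U+2581), a small self-contained sketch of how such a marker is typically mapped back to a plain space when decoded tokens are joined into text; the JoinTokens helper below is illustrative and not taken from the repository.

// space_symbol_demo.cc -- illustrative only.
#include <iostream>
#include <string>
#include <vector>

const char kSpaceSymbol[] = "\xe2\x96\x81";  // UTF-8 bytes of ▁ (U+2581)

// Joins decoded tokens, replacing every marker occurrence with a space.
std::string JoinTokens(const std::vector<std::string>& tokens) {
    std::string sentence;
    for (const std::string& tok : tokens) {
        std::string piece = tok;
        for (size_t pos = piece.find(kSpaceSymbol); pos != std::string::npos;
             pos = piece.find(kSpaceSymbol, pos + 1)) {
            piece.replace(pos, sizeof(kSpaceSymbol) - 1, " ");
        }
        sentence += piece;
    }
    return sentence;
}

int main() {
    std::cout << JoinTokens({"\xe2\x96\x81hello", "\xe2\x96\x81world"})
              << std::endl;
    return 0;
}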

@@ -35,7 +35,7 @@
 class ThreadPool {
 public:
-ThreadPool(size_t);
+explicit ThreadPool(size_t);
 template <class F, class... Args>
 auto enqueue(F&& f, Args&&... args)
 -> std::future<typename std::result_of<F(Args...)>::type>;
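Aside on the explicit ThreadPool(size_t) change: marking single-argument constructors explicit is cpplint's runtime/explicit rule, and it stops integers from silently converting into pool objects. A minimal sketch, independent of the real ThreadPool header:

// explicit_demo.cc -- standalone sketch, not from the repository.
#include <cstddef>
#include <iostream>

class ThreadPool {
 public:
    explicit ThreadPool(size_t n) : n_(n) {}
    size_t size() const { return n_; }

 private:
    size_t n_;
};

void Run(const ThreadPool& pool) {
    std::cout << pool.size() << " workers" << std::endl;
}

int main() {
    ThreadPool pool(4);
    Run(pool);
    // Run(8);  // Without `explicit`, this would compile by silently
    //          // constructing a temporary ThreadPool(8).
    return 0;
}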

@@ -64,8 +64,8 @@ void model_forward_test() {
 ;
 std::string model_graph = FLAGS_model_path;
 std::string model_params = FLAGS_param_path;
-CHECK(model_graph != "");
-CHECK(model_params != "");
+CHECK_NE(model_graph, "");
+CHECK_NE(model_params, "");
 cout << "model path: " << model_graph << endl;
 cout << "model param path : " << model_params << endl;

@@ -39,12 +39,12 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
 opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
 }
-CHECK(opts_.blank == 0);
+CHECK_EQ(opts_.blank, 0);
 auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
 space_id_ = it - vocabulary_.begin();
 // if no space in vocabulary
-if ((size_t)space_id_ >= vocabulary_.size()) {
+if (static_cast<size_t>(space_id_) >= vocabulary_.size()) {
 space_id_ = -2;
 }
 }
@@ -104,7 +104,7 @@ void CTCBeamSearch::ResetPrefixes() {
 }
 int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
-vector<string>& nbest_words) {
+const vector<string>& nbest_words) {
 kaldi::Timer timer;
 AdvanceDecoding(probs);
 LOG(INFO) << "ctc decoding elapsed time(s) "

@@ -48,7 +48,7 @@ class CTCBeamSearch : public DecoderBase {
 }
 int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
-std::vector<std::string>& nbest_words);
+const std::vector<std::string>& nbest_words);
 private:
 void ResetPrefixes();

@@ -59,8 +59,8 @@ int main(int argc, char* argv[]) {
 google::InstallFailureSignalHandler();
 FLAGS_logtostderr = 1;
-CHECK(FLAGS_result_wspecifier != "");
-CHECK(FLAGS_feature_rspecifier != "");
+CHECK_NE(FLAGS_result_wspecifier, "");
+CHECK_NE(FLAGS_feature_rspecifier, "");
 kaldi::SequentialBaseFloatMatrixReader feature_reader(
 FLAGS_feature_rspecifier);

@@ -36,7 +36,7 @@ struct CTCBeamSearchOptions {
 // u2
 int first_beam_size;
 int second_beam_size;
-explicit CTCBeamSearchOptions()
+CTCBeamSearchOptions()
 : blank(0),
 dict_file("vocab.txt"),
 lm_path(""),

@@ -329,8 +329,8 @@ void CTCPrefixBeamSearch::UpdateFinalContext() {
 std::string CTCPrefixBeamSearch::GetBestPath(int index) {
 int n_hyps = Outputs().size();
-CHECK(n_hyps > 0);
-CHECK(index < n_hyps);
+CHECK_GT(n_hyps, 0);
+CHECK_LT(index, n_hyps);
 std::vector<int> one = Outputs()[index];
 std::string sentence;
 for (int i = 0; i < one.size(); i++) {
@@ -344,7 +344,7 @@ std::string CTCPrefixBeamSearch::GetBestPath() { return GetBestPath(0); }
 std::vector<std::pair<double, std::string>> CTCPrefixBeamSearch::GetNBestPath(
 int n) {
 int hyps_size = hypotheses_.size();
-CHECK(hyps_size > 0);
+CHECK_GT(hyps_size, 0);
 int min_n = n == -1 ? hypotheses_.size() : std::min(n, hyps_size);

@@ -28,7 +28,7 @@ class ContextGraph;
 class CTCPrefixBeamSearch : public DecoderBase {
 public:
 CTCPrefixBeamSearch(const std::string& vocab_path,
 const CTCBeamSearchOptions& opts);
 ~CTCPrefixBeamSearch() {}
 SearchType Type() const { return SearchType::kPrefixBeamSearch; }

@@ -50,10 +50,10 @@ int main(int argc, char* argv[]) {
 int32 num_done = 0, num_err = 0;
-CHECK(FLAGS_result_wspecifier != "");
-CHECK(FLAGS_feature_rspecifier != "");
-CHECK(FLAGS_vocab_path != "");
-CHECK(FLAGS_model_path != "");
+CHECK_NE(FLAGS_result_wspecifier, "");
+CHECK_NE(FLAGS_feature_rspecifier, "");
+CHECK_NE(FLAGS_vocab_path, "");
+CHECK_NE(FLAGS_model_path, "");
 LOG(INFO) << "model path: " << FLAGS_model_path;
 LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path;
@@ -64,11 +64,14 @@ int main(int argc, char* argv[]) {
 // nnet
 ppspeech::ModelOptions model_opts;
 model_opts.model_path = FLAGS_model_path;
-std::shared_ptr<ppspeech::U2Nnet> nnet = std::make_shared<ppspeech::U2Nnet>(model_opts);
+std::shared_ptr<ppspeech::U2Nnet> nnet =
+    std::make_shared<ppspeech::U2Nnet>(model_opts);
 // decodeable
-std::shared_ptr<ppspeech::DataCache> raw_data = std::make_shared<ppspeech::DataCache>();
-std::shared_ptr<ppspeech::Decodable> decodable = std::make_shared<ppspeech::Decodable>(nnet, raw_data);
+std::shared_ptr<ppspeech::DataCache> raw_data =
+    std::make_shared<ppspeech::DataCache>();
+std::shared_ptr<ppspeech::Decodable> decodable =
+    std::make_shared<ppspeech::Decodable>(nnet, raw_data);
 // decoder
 ppspeech::CTCBeamSearchOptions opts;

@@ -71,7 +71,7 @@ class TLGDecoder : public DecoderBase {
 std::string GetPartialResult() override;
 int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
-std::vector<std::string>& nbest_words);
+const std::vector<std::string>& nbest_words);
 protected:
 std::string GetBestPath() override {

@@ -30,7 +30,7 @@ using std::vector;
 CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor)
 : var_norm_(true) {
-CHECK(cmvn_file != "");
+CHECK_NE(cmvn_file, "");
 base_extractor_ = std::move(base_extractor);
 bool binary;

@@ -40,8 +40,8 @@ int main(int argc, char* argv[]) {
 google::InstallFailureSignalHandler();
 FLAGS_logtostderr = 1;
-CHECK(FLAGS_wav_rspecifier.size() > 0);
-CHECK(FLAGS_feature_wspecifier.size() > 0);
+CHECK_GT(FLAGS_wav_rspecifier.size(), 0);
+CHECK_GT(FLAGS_feature_wspecifier.size(), 0);
 kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
 FLAGS_wav_rspecifier);
 kaldi::SequentialTableReader<kaldi::WaveInfoHolder> wav_info_reader(

@@ -27,7 +27,7 @@ namespace ppspeech {
 // pre-recorded audio/feature
 class DataCache : public FrontendInterface {
 public:
-explicit DataCache() { finished_ = false; }
+DataCache() { finished_ = false; }
 // accept waves/feats
 virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {

@@ -14,17 +14,18 @@
 #include "frontend/audio/db_norm.h"
+
 #include "kaldi/feat/cmvn.h"
 #include "kaldi/util/kaldi-io.h"
 namespace ppspeech {
-using kaldi::Vector;
-using kaldi::VectorBase;
 using kaldi::BaseFloat;
-using std::vector;
 using kaldi::SubVector;
+using kaldi::Vector;
+using kaldi::VectorBase;
 using std::unique_ptr;
+using std::vector;
 DecibelNormalizer::DecibelNormalizer(
 const DecibelNormalizerOptions& opts,

@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "frontend/audio/fbank.h"
+
 #include "kaldi/base/kaldi-math.h"
 #include "kaldi/feat/feature-common.h"
 #include "kaldi/feat/feature-functions.h"
@@ -20,12 +21,12 @@
 namespace ppspeech {
-using kaldi::int32;
 using kaldi::BaseFloat;
-using kaldi::Vector;
+using kaldi::int32;
+using kaldi::Matrix;
 using kaldi::SubVector;
+using kaldi::Vector;
 using kaldi::VectorBase;
-using kaldi::Matrix;
 using std::vector;
 FbankComputer::FbankComputer(const Options& opts)

@@ -33,7 +33,7 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts)
 opts.linear_spectrogram_opts, std::move(data_source)));
 }
-CHECK(opts.cmvn_file != "");
+CHECK_NE(opts.cmvn_file, "");
 unique_ptr<FrontendInterface> cmvn(
 new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature)));

@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "frontend/audio/linear_spectrogram.h"
+
 #include "kaldi/base/kaldi-math.h"
 #include "kaldi/feat/feature-common.h"
 #include "kaldi/feat/feature-functions.h"
@@ -20,12 +21,12 @@
 namespace ppspeech {
-using kaldi::int32;
 using kaldi::BaseFloat;
-using kaldi::Vector;
+using kaldi::int32;
+using kaldi::Matrix;
 using kaldi::SubVector;
+using kaldi::Vector;
 using kaldi::VectorBase;
-using kaldi::Matrix;
 using std::vector;
 LinearSpectrogramComputer::LinearSpectrogramComputer(const Options& opts)

@@ -14,6 +14,7 @@
 #include "frontend/audio/mfcc.h"
+
 #include "kaldi/base/kaldi-math.h"
 #include "kaldi/feat/feature-common.h"
 #include "kaldi/feat/feature-functions.h"
@@ -21,12 +22,12 @@
 namespace ppspeech {
-using kaldi::int32;
 using kaldi::BaseFloat;
-using kaldi::Vector;
+using kaldi::int32;
+using kaldi::Matrix;
 using kaldi::SubVector;
+using kaldi::Vector;
 using kaldi::VectorBase;
-using kaldi::Matrix;
 using std::vector;
 Mfcc::Mfcc(const MfccOptions& opts,

@@ -13,15 +13,16 @@
 // limitations under the License.
 #include "nnet/ds2_nnet.h"
+
 #include "absl/strings/str_split.h"
 namespace ppspeech {
-using std::vector;
-using std::string;
-using std::shared_ptr;
 using kaldi::Matrix;
 using kaldi::Vector;
+using std::shared_ptr;
+using std::string;
+using std::vector;
 void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
 std::vector<std::string> cache_names;
@@ -207,7 +208,7 @@ void PaddleNnet::FeedForward(const Vector<BaseFloat>& features,
 // inferences->Resize(row * col);
 // *inference_dim = col;
-out->logprobs.Resize(row*col);
+out->logprobs.Resize(row * col);
 out->vocab_dim = col;
 output_tensor->CopyToCpu(out->logprobs.Data());

@@ -26,7 +26,7 @@ template <typename T>
 class Tensor {
 public:
 Tensor() {}
-Tensor(const std::vector<int>& shape) : _shape(shape) {
+explicit Tensor(const std::vector<int>& shape) : _shape(shape) {
 int neml = std::accumulate(
 _shape.begin(), _shape.end(), 1, std::multiplies<int>());
 LOG(INFO) << "Tensor neml: " << neml;
@@ -50,7 +50,7 @@ class Tensor {
 class PaddleNnet : public NnetBase {
 public:
-PaddleNnet(const ModelOptions& opts);
+explicit PaddleNnet(const ModelOptions& opts);
 void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
 const int32& feature_dim,

@@ -12,13 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "nnet/ds2_nnet.h"
 #include "base/common.h"
 #include "decoder/param.h"
 #include "frontend/audio/assembler.h"
 #include "frontend/audio/data_cache.h"
 #include "kaldi/util/table-types.h"
 #include "nnet/decodable.h"
+#include "nnet/ds2_nnet.h"
 DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
 DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
@@ -44,7 +44,7 @@ int main(int argc, char* argv[]) {
 int32 num_done = 0, num_err = 0;
 ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
 std::shared_ptr<ppspeech::PaddleNnet> nnet(
 new ppspeech::PaddleNnet(model_opts));

@@ -158,7 +158,7 @@ void U2Nnet::Reset() {
 }
 // Debug API
-void U2Nnet::FeedEncoderOuts(paddle::Tensor& encoder_out) {
+void U2Nnet::FeedEncoderOuts(const paddle::Tensor& encoder_out) {
 // encoder_out (T,D)
 encoder_outs_.clear();
 encoder_outs_.push_back(encoder_out);
@@ -206,7 +206,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
 float* feats_ptr = feats.mutable_data<float>();
 // not cache feature in nnet
-CHECK(cached_feats_.size() == 0);
+CHECK_EQ(cached_feats_.size(), 0);
 // CHECK_EQ(std::is_same<float, kaldi::BaseFloat>::value, true);
 std::memcpy(feats_ptr,
 chunk_feats.data(),
@@ -247,9 +247,9 @@
 // call.
 std::vector<paddle::Tensor> inputs = {
 feats, offset, /*required_cache_size, */ att_cache_, cnn_cache_};
-CHECK(inputs.size() == 4);
+CHECK_EQ(inputs.size(), 4);
 std::vector<paddle::Tensor> outputs = forward_encoder_chunk_(inputs);
-CHECK(outputs.size() == 3);
+CHECK_EQ(outputs.size(), 3);
 #ifdef USE_GPU
 paddle::Tensor chunk_out = outputs[0].copy_to(paddle::CPUPlace());
@@ -319,9 +319,9 @@ void U2Nnet::ForwardEncoderChunkImpl(
 inputs.clear();
 outputs.clear();
 inputs.push_back(chunk_out);
-CHECK(inputs.size() == 1);
+CHECK_EQ(inputs.size(), 1);
 outputs = ctc_activation_(inputs);
-CHECK(outputs.size() == 1);
+CHECK_EQ(outputs.size(), 1);
 paddle::Tensor ctc_log_probs = outputs[0];
 #ifdef TEST_DEBUG
@@ -350,9 +350,9 @@ void U2Nnet::ForwardEncoderChunkImpl(
 // Copy to output, (B=1,T,D)
 std::vector<int64_t> ctc_log_probs_shape = ctc_log_probs.shape();
-CHECK(ctc_log_probs_shape.size() == 3);
+CHECK_EQ(ctc_log_probs_shape.size(), 3);
 int B = ctc_log_probs_shape[0];
-CHECK(B == 1);
+CHECK_EQ(B, 1);
 int T = ctc_log_probs_shape[1];
 int D = ctc_log_probs_shape[2];
 *vocab_dim = D;
@@ -393,9 +393,9 @@ float U2Nnet::ComputePathScore(const paddle::Tensor& prob,
 // hyp (U,)
 float score = 0.0f;
 std::vector<int64_t> dims = prob.shape();
-CHECK(dims.size() == 3);
+CHECK_EQ(dims.size(), 3);
 VLOG(2) << "prob shape: " << dims[0] << ", " << dims[1] << ", " << dims[2];
-CHECK(dims[0] == 1);
+CHECK_EQ(dims[0], 1);
 int vocab_dim = static_cast<int>(dims[2]);
 const float* prob_ptr = prob.data<float>();
@@ -520,14 +520,14 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
 std::vector<paddle::experimental::Tensor> inputs{
 hyps_tensor, hyps_lens, encoder_out};
 std::vector<paddle::Tensor> outputs = forward_attention_decoder_(inputs);
-CHECK(outputs.size() == 2);
+CHECK_EQ(outputs.size(), 2);
 // (B, Umax, V)
 paddle::Tensor probs = outputs[0];
 std::vector<int64_t> probs_shape = probs.shape();
-CHECK(probs_shape.size() == 3);
-CHECK(probs_shape[0] == num_hyps);
-CHECK(probs_shape[1] == max_hyps_len);
+CHECK_EQ(probs_shape.size(), 3);
+CHECK_EQ(probs_shape[0], num_hyps);
+CHECK_EQ(probs_shape[1], max_hyps_len);
 #ifdef TEST_DEBUG
 {
@@ -582,13 +582,13 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
 paddle::Tensor r_probs = outputs[1];
 std::vector<int64_t> r_probs_shape = r_probs.shape();
 if (is_bidecoder_ && reverse_weight > 0) {
-CHECK(r_probs_shape.size() == 3);
-CHECK(r_probs_shape[0] == num_hyps);
-CHECK(r_probs_shape[1] == max_hyps_len);
+CHECK_EQ(r_probs_shape.size(), 3);
+CHECK_EQ(r_probs_shape[0], num_hyps);
+CHECK_EQ(r_probs_shape[1], max_hyps_len);
 } else {
 // dump r_probs
-CHECK(r_probs_shape.size() == 1);
-CHECK(r_probs_shape[0] == 1) << r_probs_shape[0];
+CHECK_EQ(r_probs_shape.size(), 1);
+CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0];
 }
 // compute rescoring score
@@ -644,7 +644,7 @@ void U2Nnet::EncoderOuts(
 for (int i = 0; i < size; i++) {
 const paddle::Tensor& item = encoder_outs_[i];
 const std::vector<int64_t> shape = item.shape();
-CHECK(shape.size() == 3);
+CHECK_EQ(shape.size(), 3);
 const int& B = shape[0];
 const int& T = shape[1];
 const int& D = shape[2];

@@ -73,7 +73,7 @@ class U2NnetBase : public NnetBase {
 class U2Nnet : public U2NnetBase {
 public:
-U2Nnet(const ModelOptions& opts);
+explicit U2Nnet(const ModelOptions& opts);
 U2Nnet(const U2Nnet& other);
 void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
@@ -108,7 +108,7 @@ class U2Nnet : public U2NnetBase {
 std::vector<float>* rescoring_score) override;
 // debug
-void FeedEncoderOuts(paddle::Tensor& encoder_out);
+void FeedEncoderOuts(const paddle::Tensor& encoder_out);
 void EncoderOuts(
 std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const;

@@ -39,9 +39,9 @@ int main(int argc, char* argv[]) {
 int32 num_done = 0, num_err = 0;
-CHECK(FLAGS_feature_rspecifier.size() > 0);
-CHECK(FLAGS_nnet_prob_wspecifier.size() > 0);
-CHECK(FLAGS_model_path.size() > 0);
+CHECK_GT(FLAGS_feature_rspecifier.size(), 0);
+CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0);
+CHECK_GT(FLAGS_model_path.size(), 0);
 LOG(INFO) << "input rspecifier: " << FLAGS_feature_rspecifier;
 LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
 LOG(INFO) << "model path: " << FLAGS_model_path;

@@ -12,10 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "websocket/websocket_client.h"
 #include "kaldi/feat/wave-reader.h"
 #include "kaldi/util/kaldi-io.h"
 #include "kaldi/util/table-types.h"
+#include "websocket/websocket_client.h"
 DEFINE_string(host, "127.0.0.1", "host of websocket server");
 DEFINE_int32(port, 8082, "port of websocket server");

@@ -39,7 +39,8 @@ struct RecognizerResource {
 resource.feature_pipeline_opts =
 FeaturePipelineOptions::InitFromFlags();
 resource.feature_pipeline_opts.assembler_opts.fill_zero = true;
-LOG(INFO) << "ds2 need fill zero be true: " << resource.feature_pipeline_opts.assembler_opts.fill_zero;
+LOG(INFO) << "ds2 need fill zero be true: "
+    << resource.feature_pipeline_opts.assembler_opts.fill_zero;
 resource.model_opts = ModelOptions::InitFromFlags();
 resource.tlg_opts = TLGDecoderOptions::InitFromFlags();
 return resource;

@@ -13,9 +13,9 @@
 // limitations under the License.
 #include "decoder/param.h"
-#include "recognizer/recognizer.h"
 #include "kaldi/feat/wave-reader.h"
 #include "kaldi/util/table-types.h"
+#include "recognizer/recognizer.h"
 DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
 DEFINE_string(result_wspecifier, "", "test result wspecifier");
@@ -30,7 +30,8 @@ int main(int argc, char* argv[]) {
 google::InstallFailureSignalHandler();
 FLAGS_logtostderr = 1;
-ppspeech::RecognizerResource resource = ppspeech::RecognizerResource::InitFromFlags();
+ppspeech::RecognizerResource resource =
+    ppspeech::RecognizerResource::InitFromFlags();
 ppspeech::Recognizer recognizer(resource);
 kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(

@@ -35,7 +35,7 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
 BaseFloat am_scale = resource.acoustic_scale;
 decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale));
-CHECK(resource.vocab_path != "");
+CHECK_NE(resource.vocab_path, "");
 decoder_.reset(new CTCPrefixBeamSearch(
 resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));

@@ -1,5 +1,3 @@
 // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");

@@ -40,4 +40,4 @@ std::string ReadFile2String(const std::string& path) {
 return std::string((std::istreambuf_iterator<char>(input_file)),
 std::istreambuf_iterator<char>());
 }
-}
+}  // namespace ppspeech
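The ReadFile2String body shown above uses the standard istreambuf_iterator idiom for reading an entire file into a string. A self-contained sketch of the same idiom (the file name, function name, and error handling are invented for the example):

// read_file_demo.cc -- illustrative only, not from the repository.
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>

std::string ReadFileToString(const std::string& path) {
    std::ifstream input_file(path);
    if (!input_file) {
        std::cerr << "cannot open: " << path << std::endl;
        return "";
    }
    // The extra parentheses around the first iterator avoid the
    // most-vexing-parse; the default-constructed iterator is end-of-stream.
    return std::string((std::istreambuf_iterator<char>(input_file)),
                       std::istreambuf_iterator<char>());
}

int main() {
    std::cout << ReadFileToString("example.txt").size() << " bytes" << std::endl;
    return 0;
}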
