pull/2524/head
Hui Zhang 2 years ago
parent 72c9e973a2
commit fddcd36fa0

@ -25,4 +25,4 @@ run.sh --stop_stage 0
``` ```
./run.sh --stage 3 --stop_stage 3 ./run.sh --stage 3 --stop_stage 3
``` ```

@ -21,6 +21,7 @@
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include <thread> #include <thread>
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h"

@ -13,9 +13,10 @@
// limitations under the License. // limitations under the License.
#include "decoder/ctc_beam_search_decoder.h"
#include "base/common.h" #include "base/common.h"
#include "decoder/ctc_decoders/decoder_utils.h" #include "decoder/ctc_decoders/decoder_utils.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "utils/file_utils.h" #include "utils/file_utils.h"
namespace ppspeech { namespace ppspeech {
@ -24,10 +25,7 @@ using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts), : opts_(opts), init_ext_scorer_(nullptr), space_id_(-1), root_(nullptr) {
init_ext_scorer_(nullptr),
space_id_(-1),
root_(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file; LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
LOG(INFO) << "load the dict failed"; LOG(INFO) << "load the dict failed";
@ -41,7 +39,7 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
} }
CHECK(opts_.blank==0); CHECK(opts_.blank == 0);
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
space_id_ = it - vocabulary_.begin(); space_id_ = it - vocabulary_.begin();
@ -115,7 +113,7 @@ int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
} }
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath(int n) { vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath(int n) {
int beam_size = n == -1 ? opts_.beam_size: std::min(n, opts_.beam_size); int beam_size = n == -1 ? opts_.beam_size : std::min(n, opts_.beam_size);
return get_beam_search_result(prefixes_, vocabulary_, beam_size); return get_beam_search_result(prefixes_, vocabulary_, beam_size);
} }

@ -16,11 +16,12 @@
#include "decoder/ctc_prefix_beam_search_decoder.h" #include "decoder/ctc_prefix_beam_search_decoder.h"
#include "absl/strings/str_join.h"
#include "base/common.h" #include "base/common.h"
#include "decoder/ctc_beam_search_opt.h" #include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_score.h" #include "decoder/ctc_prefix_beam_search_score.h"
#include "utils/math.h" #include "utils/math.h"
#include "absl/strings/str_join.h"
#ifdef USE_PROFILING #ifdef USE_PROFILING
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
@ -30,18 +31,17 @@ using paddle::platform::TracerEventType;
namespace ppspeech { namespace ppspeech {
CTCPrefixBeamSearch::CTCPrefixBeamSearch( CTCPrefixBeamSearch::CTCPrefixBeamSearch(const std::string vocab_path,
const std::string vocab_path, const CTCBeamSearchOptions& opts)
const CTCBeamSearchOptions& opts)
: opts_(opts) { : opts_(opts) {
unit_table_ = std::shared_ptr<fst::SymbolTable>(
unit_table_ = std::shared_ptr<fst::SymbolTable>(fst::SymbolTable::ReadText(vocab_path)); fst::SymbolTable::ReadText(vocab_path));
CHECK(unit_table_ != nullptr); CHECK(unit_table_ != nullptr);
Reset(); Reset();
} }
void CTCPrefixBeamSearch::Reset() { void CTCPrefixBeamSearch::Reset() {
num_frame_decoded_ = 0; num_frame_decoded_ = 0;
cur_hyps_.clear(); cur_hyps_.clear();
@ -65,10 +65,9 @@ void CTCPrefixBeamSearch::Reset() {
hypotheses_.emplace_back(empty); hypotheses_.emplace_back(empty);
likelihood_.emplace_back(prefix_score.TotalScore()); likelihood_.emplace_back(prefix_score.TotalScore());
times_.emplace_back(empty); times_.emplace_back(empty);
} }
void CTCPrefixBeamSearch::InitDecoder() { Reset(); }
void CTCPrefixBeamSearch::InitDecoder() { Reset(); }
void CTCPrefixBeamSearch::AdvanceDecode( void CTCPrefixBeamSearch::AdvanceDecode(
@ -296,9 +295,7 @@ void CTCPrefixBeamSearch::UpdateOutputs(
outputs_.emplace_back(output); outputs_.emplace_back(output);
} }
void CTCPrefixBeamSearch::FinalizeSearch() { void CTCPrefixBeamSearch::FinalizeSearch() { UpdateFinalContext(); }
UpdateFinalContext();
}
void CTCPrefixBeamSearch::UpdateFinalContext() { void CTCPrefixBeamSearch::UpdateFinalContext() {
if (context_graph_ == nullptr) return; if (context_graph_ == nullptr) return;
@ -311,8 +308,8 @@ void CTCPrefixBeamSearch::UpdateFinalContext() {
for (const auto& prefix : hypotheses_) { for (const auto& prefix : hypotheses_) {
PrefixScore& prefix_score = cur_hyps_[prefix]; PrefixScore& prefix_score = cur_hyps_[prefix];
if (prefix_score.context_score != 0) { if (prefix_score.context_score != 0) {
prefix_score.UpdateContext(context_graph_, prefix_score, 0, prefix_score.UpdateContext(
prefix.size()); context_graph_, prefix_score, 0, prefix.size());
} }
} }
std::vector<std::pair<std::vector<int>, PrefixScore>> arr(cur_hyps_.begin(), std::vector<std::pair<std::vector<int>, PrefixScore>> arr(cur_hyps_.begin(),
@ -323,48 +320,44 @@ void CTCPrefixBeamSearch::UpdateFinalContext() {
UpdateHypotheses(arr); UpdateHypotheses(arr);
} }
std::string CTCPrefixBeamSearch::GetBestPath(int index) { std::string CTCPrefixBeamSearch::GetBestPath(int index) {
int n_hyps = Outputs().size(); int n_hyps = Outputs().size();
CHECK(n_hyps > 0); CHECK(n_hyps > 0);
CHECK(index < n_hyps); CHECK(index < n_hyps);
std::vector<int> one = Outputs()[index]; std::vector<int> one = Outputs()[index];
std::string sentence; std::string sentence;
for (int i = 0; i < one.size(); i++){ for (int i = 0; i < one.size(); i++) {
sentence += unit_table_->Find(one[i]); sentence += unit_table_->Find(one[i]);
} }
return sentence; return sentence;
} }
std::string CTCPrefixBeamSearch::GetBestPath() { std::string CTCPrefixBeamSearch::GetBestPath() { return GetBestPath(0); }
return GetBestPath(0);
}
std::vector<std::pair<double, std::string>> CTCPrefixBeamSearch::GetNBestPath(int n) { std::vector<std::pair<double, std::string>> CTCPrefixBeamSearch::GetNBestPath(
int hyps_size = hypotheses_.size(); int n) {
CHECK(hyps_size > 0); int hyps_size = hypotheses_.size();
CHECK(hyps_size > 0);
int min_n = n == -1 ? hypotheses_.size() : std::min(n, hyps_size); int min_n = n == -1 ? hypotheses_.size() : std::min(n, hyps_size);
std::vector<std::pair<double, std::string>> n_best; std::vector<std::pair<double, std::string>> n_best;
n_best.reserve(min_n); n_best.reserve(min_n);
for (int i = 0; i < min_n; i++){ for (int i = 0; i < min_n; i++) {
n_best.emplace_back(Likelihood()[i], GetBestPath(i) ); n_best.emplace_back(Likelihood()[i], GetBestPath(i));
} }
return n_best; return n_best;
} }
std::vector<std::pair<double, std::string>> CTCPrefixBeamSearch::GetNBestPath() { std::vector<std::pair<double, std::string>>
CTCPrefixBeamSearch::GetNBestPath() {
return GetNBestPath(-1); return GetNBestPath(-1);
}
std::string CTCPrefixBeamSearch::GetFinalBestPath() {
return GetBestPath();
} }
std::string CTCPrefixBeamSearch::GetPartialResult() { std::string CTCPrefixBeamSearch::GetFinalBestPath() { return GetBestPath(); }
return GetBestPath();
} std::string CTCPrefixBeamSearch::GetPartialResult() { return GetBestPath(); }
} // namespace ppspeech } // namespace ppspeech

@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/strings/str_split.h"
#include "base/common.h" #include "base/common.h"
#include "decoder/ctc_prefix_beam_search_decoder.h" #include "decoder/ctc_prefix_beam_search_decoder.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "fst/symbol-table.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/u2_nnet.h" #include "nnet/u2_nnet.h"
#include "absl/strings/str_split.h"
#include "fst/symbol-table.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier");
@ -64,8 +64,7 @@ int main(int argc, char* argv[]) {
// nnet // nnet
ppspeech::ModelOptions model_opts; ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path; model_opts.model_path = FLAGS_model_path;
std::shared_ptr<ppspeech::U2Nnet> nnet( std::shared_ptr<ppspeech::U2Nnet> nnet(new ppspeech::U2Nnet(model_opts));
new ppspeech::U2Nnet(model_opts));
// decodeable // decodeable
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache()); std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
@ -114,9 +113,9 @@ int main(int argc, char* argv[]) {
ori_feature_len - chunk_idx * chunk_stride, chunk_size); ori_feature_len - chunk_idx * chunk_stride, chunk_size);
} }
if (this_chunk_size < receptive_field_length) { if (this_chunk_size < receptive_field_length) {
LOG(WARNING) << "utt: " << utt << " skip last " LOG(WARNING)
<< this_chunk_size << " frames, expect is " << "utt: " << utt << " skip last " << this_chunk_size
<< receptive_field_length; << " frames, expect is " << receptive_field_length;
break; break;
} }
@ -127,7 +126,7 @@ int main(int argc, char* argv[]) {
for (int row_id = 0; row_id < this_chunk_size; ++row_id) { for (int row_id = 0; row_id < this_chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> feat_row(feature, start); kaldi::SubVector<kaldi::BaseFloat> feat_row(feature, start);
kaldi::SubVector<kaldi::BaseFloat> feature_chunk_row( kaldi::SubVector<kaldi::BaseFloat> feature_chunk_row(
feature_chunk.Data() + row_id * feat_dim, feat_dim); feature_chunk.Data() + row_id * feat_dim, feat_dim);
feature_chunk_row.CopyFromVec(feat_row); feature_chunk_row.CopyFromVec(feat_row);
++start; ++start;
@ -151,7 +150,7 @@ int main(int argc, char* argv[]) {
// get 1-best result // get 1-best result
std::string result = decoder.GetFinalBestPath(); std::string result = decoder.GetFinalBestPath();
// after process one utt, then reset state. // after process one utt, then reset state.
decodable->Reset(); decodable->Reset();
decoder.Reset(); decoder.Reset();

@ -33,9 +33,7 @@ void TLGDecoder::Reset() {
return; return;
} }
void TLGDecoder::InitDecoder() { void TLGDecoder::InitDecoder() { Reset(); }
Reset();
}
void TLGDecoder::AdvanceDecode( void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) { const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
@ -50,7 +48,6 @@ void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
} }
std::string TLGDecoder::GetPartialResult() { std::string TLGDecoder::GetPartialResult() {
if (num_frame_decoded_ == 0) { if (num_frame_decoded_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
@ -93,4 +90,4 @@ std::string TLGDecoder::GetFinalBestPath() {
return words; return words;
} }
} } // namespace ppspeech

@ -15,14 +15,12 @@
// todo refactor, repalce with gtest // todo refactor, repalce with gtest
#include "base/common.h" #include "base/common.h"
#include "decoder/ctc_tlg_decoder.h"
#include "decoder/param.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/ds2_nnet.h" #include "nnet/ds2_nnet.h"
#include "decoder/param.h"
#include "decoder/ctc_tlg_decoder.h"
#include "kaldi/util/table-types.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
@ -47,12 +45,13 @@ int main(int argc, char* argv[]) {
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
ppspeech::TLGDecoderOptions opts = ppspeech::TLGDecoderOptions::InitFromFlags(); ppspeech::TLGDecoderOptions opts =
ppspeech::TLGDecoderOptions::InitFromFlags();
opts.opts.beam = 15.0; opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5; opts.opts.lattice_beam = 7.5;
ppspeech::TLGDecoder decoder(opts); ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));
@ -67,7 +66,7 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride; LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length; LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder(); decoder.InitDecoder();
kaldi::Timer timer; kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) { for (; !feature_reader.Done(); feature_reader.Next()) {

@ -17,12 +17,12 @@
namespace ppspeech { namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr; using std::unique_ptr;
using std::vector;
Recognizer::Recognizer(const RecognizerResource& resource) { Recognizer::Recognizer(const RecognizerResource& resource) {

@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "decoder/recognizer.h"
#include "decoder/param.h" #include "decoder/param.h"
#include "decoder/recognizer.h"
#include "kaldi/feat/wave-reader.h" #include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
@ -25,8 +25,9 @@ DEFINE_int32(sample_rate, 16000, "sample rate");
ppspeech::RecognizerResource InitRecognizerResoure() { ppspeech::RecognizerResource InitRecognizerResoure() {
ppspeech::RecognizerResource resource; ppspeech::RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale; resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags(); resource.feature_pipeline_opts =
resource.model_opts = ppspeech::ModelOptions::InitFromFlags(); ppspeech::FeaturePipelineOptions::InitFromFlags();
resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags(); resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags();
return resource; return resource;
} }

@ -13,18 +13,20 @@
// limitations under the License. // limitations under the License.
#include "decoder/u2_recognizer.h" #include "decoder/u2_recognizer.h"
#include "nnet/u2_nnet.h" #include "nnet/u2_nnet.h"
namespace ppspeech { namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr; using std::unique_ptr;
using std::vector;
U2Recognizer::U2Recognizer(const U2RecognizerResource& resource): opts_(resource) { U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
: opts_(resource) {
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
feature_pipeline_.reset(new FeaturePipeline(feature_opts)); feature_pipeline_.reset(new FeaturePipeline(feature_opts));
@ -34,7 +36,8 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource): opts_(resource
decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale)); decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale));
CHECK(resource.vocab_path != ""); CHECK(resource.vocab_path != "");
decoder_.reset(new CTCPrefixBeamSearch(resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); decoder_.reset(new CTCPrefixBeamSearch(
resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
unit_table_ = decoder_->VocabTable(); unit_table_ = decoder_->VocabTable();
symbol_table_ = unit_table_; symbol_table_ = unit_table_;
@ -70,140 +73,140 @@ void U2Recognizer::Accept(const VectorBase<BaseFloat>& waves) {
} }
void U2Recognizer::Decode() { void U2Recognizer::Decode() {
decoder_->AdvanceDecode(decodable_); decoder_->AdvanceDecode(decodable_);
UpdateResult(false); UpdateResult(false);
} }
void U2Recognizer::Rescoring() { void U2Recognizer::Rescoring() {
// Do attention Rescoring // Do attention Rescoring
kaldi::Timer timer; kaldi::Timer timer;
AttentionRescoring(); AttentionRescoring();
VLOG(1) << "Rescoring cost latency: " << timer.Elapsed() << " sec."; VLOG(1) << "Rescoring cost latency: " << timer.Elapsed() << " sec.";
} }
void U2Recognizer::UpdateResult(bool finish) { void U2Recognizer::UpdateResult(bool finish) {
const auto& hypotheses = decoder_->Outputs(); const auto& hypotheses = decoder_->Outputs();
const auto& inputs = decoder_->Inputs(); const auto& inputs = decoder_->Inputs();
const auto& likelihood = decoder_->Likelihood(); const auto& likelihood = decoder_->Likelihood();
const auto& times = decoder_->Times(); const auto& times = decoder_->Times();
result_.clear(); result_.clear();
CHECK_EQ(hypotheses.size(), likelihood.size());
for (size_t i = 0; i < hypotheses.size(); i++) {
const std::vector<int>& hypothesis = hypotheses[i];
DecodeResult path;
path.score = likelihood[i];
for (size_t j = 0; j < hypothesis.size(); j++) {
std::string word = symbol_table_->Find(hypothesis[j]);
// A detailed explanation of this if-else branch can be found in
// https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
if (decoder_->Type() == kWfstBeamSearch) {
path.sentence += (" " + word);
} else {
path.sentence += (word);
}
}
// TimeStamp is only supported in final result CHECK_EQ(hypotheses.size(), likelihood.size());
// TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to for (size_t i = 0; i < hypotheses.size(); i++) {
// various FST operations when building the decoding graph. So here we use const std::vector<int>& hypothesis = hypotheses[i];
// time stamp of the input(e2e model unit), which is more accurate, and it
// requires the symbol table of the e2e model used in training. DecodeResult path;
if (unit_table_ != nullptr && finish) { path.score = likelihood[i];
int offset = global_frame_offset_ * FrameShiftInMs(); for (size_t j = 0; j < hypothesis.size(); j++) {
std::string word = symbol_table_->Find(hypothesis[j]);
// A detailed explanation of this if-else branch can be found in
// https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
if (decoder_->Type() == kWfstBeamSearch) {
path.sentence += (" " + word);
} else {
path.sentence += (word);
}
}
// TimeStamp is only supported in final result
// TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to
// various FST operations when building the decoding graph. So here we
// use time stamp of the input(e2e model unit), which is more accurate,
// and it requires the symbol table of the e2e model used in training.
if (unit_table_ != nullptr && finish) {
int offset = global_frame_offset_ * FrameShiftInMs();
const std::vector<int>& input = inputs[i]; const std::vector<int>& input = inputs[i];
const std::vector<int> time_stamp = times[i]; const std::vector<int> time_stamp = times[i];
CHECK_EQ(input.size(), time_stamp.size()); CHECK_EQ(input.size(), time_stamp.size());
for (size_t j = 0; j < input.size(); j++) { for (size_t j = 0; j < input.size(); j++) {
std::string word = unit_table_->Find(input[j]); std::string word = unit_table_->Find(input[j]);
int start = time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0 int start =
time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0
? time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ ? time_stamp[j] * FrameShiftInMs() - time_stamp_gap_
: 0; : 0;
if (j > 0) { if (j > 0) {
start = (time_stamp[j] - time_stamp[j - 1]) * FrameShiftInMs() < start =
time_stamp_gap_ (time_stamp[j] - time_stamp[j - 1]) * FrameShiftInMs() <
? (time_stamp[j - 1] + time_stamp[j]) / 2 * time_stamp_gap_
FrameShiftInMs() ? (time_stamp[j - 1] + time_stamp[j]) / 2 *
: start; FrameShiftInMs()
: start;
}
int end = time_stamp[j] * FrameShiftInMs();
if (j < input.size() - 1) {
end =
(time_stamp[j + 1] - time_stamp[j]) * FrameShiftInMs() <
time_stamp_gap_
? (time_stamp[j + 1] + time_stamp[j]) / 2 *
FrameShiftInMs()
: end;
}
WordPiece word_piece(word, offset + start, offset + end);
path.word_pieces.emplace_back(word_piece);
}
} }
int end = time_stamp[j] * FrameShiftInMs(); // if (post_processor_ != nullptr) {
if (j < input.size() - 1) { // path.sentence = post_processor_->Process(path.sentence, finish);
end = (time_stamp[j + 1] - time_stamp[j]) * FrameShiftInMs() < // }
time_stamp_gap_
? (time_stamp[j + 1] + time_stamp[j]) / 2 *
FrameShiftInMs()
: end;
}
WordPiece word_piece(word, offset + start, offset + end); result_.emplace_back(path);
path.word_pieces.emplace_back(word_piece);
}
} }
// if (post_processor_ != nullptr) { if (DecodedSomething()) {
// path.sentence = post_processor_->Process(path.sentence, finish); VLOG(1) << "Partial CTC result " << result_[0].sentence;
// } }
result_.emplace_back(path);
}
if (DecodedSomething()) {
VLOG(1) << "Partial CTC result " << result_[0].sentence;
}
} }
void U2Recognizer::AttentionRescoring() { void U2Recognizer::AttentionRescoring() {
decoder_->FinalizeSearch(); decoder_->FinalizeSearch();
UpdateResult(true); UpdateResult(true);
// No need to do rescoring
if (0.0 == opts_.decoder_opts.rescoring_weight) {
LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
return;
}
LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";
// Inputs() returns N-best input ids, which is the basic unit for rescoring
// In CtcPrefixBeamSearch, inputs are the same to outputs
const auto& hypotheses = decoder_->Inputs();
int num_hyps = hypotheses.size();
if (num_hyps <= 0) {
return;
}
kaldi::Timer timer;
std::vector<float> rescoring_score;
decodable_->AttentionRescoring(
hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
VLOG(1) << "Attention Rescoring takes " << timer.Elapsed() << " sec.";
// combine ctc score and rescoring score
for (size_t i = 0; i < num_hyps; i++) {
VLOG(1) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
<< " ctc_score: " << result_[i].score;
result_[i].score = opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
opts_.decoder_opts.ctc_weight * result_[i].score;
}
std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
VLOG(1) << "result: " << result_[0].sentence
<< " score: " << result_[0].score;
}
std::string U2Recognizer::GetFinalResult() { // No need to do rescoring
return result_[0].sentence; if (0.0 == opts_.decoder_opts.rescoring_weight) {
} LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
return;
}
LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";
// Inputs() returns N-best input ids, which is the basic unit for rescoring
// In CtcPrefixBeamSearch, inputs are the same to outputs
const auto& hypotheses = decoder_->Inputs();
int num_hyps = hypotheses.size();
if (num_hyps <= 0) {
return;
}
kaldi::Timer timer;
std::vector<float> rescoring_score;
decodable_->AttentionRescoring(
hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
VLOG(1) << "Attention Rescoring takes " << timer.Elapsed() << " sec.";
// combine ctc score and rescoring score
for (size_t i = 0; i < num_hyps; i++) {
VLOG(1) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
<< " ctc_score: " << result_[i].score;
result_[i].score =
opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
opts_.decoder_opts.ctc_weight * result_[i].score;
}
std::string U2Recognizer::GetPartialResult() { std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
return result_[0].sentence; VLOG(1) << "result: " << result_[0].sentence
<< " score: " << result_[0].score;
} }
std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; }
std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; }
void U2Recognizer::SetFinished() { void U2Recognizer::SetFinished() {
feature_pipeline_->SetFinished(); feature_pipeline_->SetFinished();
input_finished_ = true; input_finished_ = true;

@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "decoder/u2_recognizer.h"
#include "decoder/param.h" #include "decoder/param.h"
#include "decoder/u2_recognizer.h"
#include "kaldi/feat/wave-reader.h" #include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
@ -43,7 +43,8 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "chunk size (s): " << streaming_chunk; LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size; LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
ppspeech::U2RecognizerResource resource = ppspeech::U2RecognizerResource::InitFromFlags(); ppspeech::U2RecognizerResource resource =
ppspeech::U2RecognizerResource::InitFromFlags();
ppspeech::U2Recognizer recognizer(resource); ppspeech::U2Recognizer recognizer(resource);
kaldi::Timer timer; kaldi::Timer timer;
@ -103,7 +104,7 @@ int main(int argc, char* argv[]) {
} }
double elapsed = timer.Elapsed(); double elapsed = timer.Elapsed();
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "cost:" << elapsed << " sec"; LOG(INFO) << "cost:" << elapsed << " sec";
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";

@ -14,17 +14,18 @@
#include "frontend/audio/cmvn.h" #include "frontend/audio/cmvn.h"
#include "kaldi/feat/cmvn.h" #include "kaldi/feat/cmvn.h"
#include "kaldi/util/kaldi-io.h" #include "kaldi/util/kaldi-io.h"
namespace ppspeech { namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr; using std::unique_ptr;
using std::vector;
CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor) CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor)
@ -57,7 +58,7 @@ bool CMVN::Read(kaldi::Vector<BaseFloat>* feats) {
// feats contain num_frames feature. // feats contain num_frames feature.
void CMVN::Compute(VectorBase<BaseFloat>* feats) const { void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
KALDI_ASSERT(feats != NULL); KALDI_ASSERT(feats != NULL);
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 || if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
feats->Dim() % dim_ != 0) { feats->Dim() % dim_ != 0) {
KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << ',' KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << ','

@ -16,16 +16,15 @@
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
#include "frontend/audio/audio_cache.h" #include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "frontend/audio/fbank.h" #include "frontend/audio/fbank.h"
#include "frontend/audio/feature_cache.h" #include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h" #include "frontend/audio/frontend_itf.h"
#include "frontend/audio/normalizer.h" #include "frontend/audio/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav scp path"); DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
@ -86,24 +85,27 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "chunk size (sec): " << streaming_chunk; LOG(INFO) << "chunk size (sec): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size; LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
for (; !wav_reader.Done() && !wav_info_reader.Done(); wav_reader.Next(), wav_info_reader.Next()) { for (; !wav_reader.Done() && !wav_info_reader.Done();
wav_reader.Next(), wav_info_reader.Next()) {
const std::string& utt = wav_reader.Key(); const std::string& utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value(); const kaldi::WaveData& wave_data = wav_reader.Value();
const std::string& utt2 = wav_info_reader.Key(); const std::string& utt2 = wav_info_reader.Key();
const kaldi::WaveInfo& wave_info = wav_info_reader.Value(); const kaldi::WaveInfo& wave_info = wav_info_reader.Value();
CHECK(utt == utt2) << "wav reader and wav info reader using diff rspecifier!!!"; CHECK(utt == utt2)
<< "wav reader and wav info reader using diff rspecifier!!!";
LOG(INFO) << "utt: " << utt; LOG(INFO) << "utt: " << utt;
LOG(INFO) << "samples: " << wave_info.SampleCount(); LOG(INFO) << "samples: " << wave_info.SampleCount();
LOG(INFO) << "dur: " << wave_info.Duration() << " sec"; LOG(INFO) << "dur: " << wave_info.Duration() << " sec";
CHECK(wave_info.SampFreq() == FLAGS_sample_rate) << "need " << FLAGS_sample_rate << " get " << wave_info.SampFreq(); CHECK(wave_info.SampFreq() == FLAGS_sample_rate)
<< "need " << FLAGS_sample_rate << " get " << wave_info.SampFreq();
// load first channel wav // load first channel wav
int32 this_channel = 0; int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(), kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel); this_channel);
// compute feat chunk by chunk // compute feat chunk by chunk
int tot_samples = waveform.Dim(); int tot_samples = waveform.Dim();
int sample_offset = 0; int sample_offset = 0;
@ -157,7 +159,8 @@ int main(int argc, char* argv[]) {
++cur_idx; ++cur_idx;
} }
} }
LOG(INFO) << "feat shape: " << features.NumRows() << " , " << features.NumCols(); LOG(INFO) << "feat shape: " << features.NumRows() << " , "
<< features.NumCols();
feat_writer.Write(utt, features); feat_writer.Write(utt, features);
// reset frontend pipeline state // reset frontend pipeline state

@ -14,16 +14,15 @@
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
#include "frontend/audio/audio_cache.h" #include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h" #include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h" #include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h" #include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h" #include "frontend/audio/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav scp path"); DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); DEFINE_string(feature_wspecifier, "", "output feats wspecifier");

@ -18,7 +18,8 @@ namespace ppspeech {
using std::unique_ptr; using std::unique_ptr;
FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) : opts_(opts) { FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts)
: opts_(opts) {
unique_ptr<FrontendInterface> data_source( unique_ptr<FrontendInterface> data_source(
new ppspeech::AudioCache(1000 * kint16max, opts.to_float32)); new ppspeech::AudioCache(1000 * kint16max, opts.to_float32));
@ -43,4 +44,4 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) : opts_(opt
new ppspeech::Assembler(opts.assembler_opts, std::move(cache))); new ppspeech::Assembler(opts.assembler_opts, std::move(cache)));
} }
} // ppspeech } // namespace ppspeech

@ -18,8 +18,8 @@ namespace ppspeech {
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector;
using kaldi::Vector; using kaldi::Vector;
using std::vector;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet, Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FrontendInterface>& frontend, const std::shared_ptr<FrontendInterface>& frontend,
@ -56,7 +56,6 @@ int32 Decodable::NumIndices() const { return 0; }
int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; } int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; }
bool Decodable::EnsureFrameHaveComputed(int32 frame) { bool Decodable::EnsureFrameHaveComputed(int32 frame) {
// decoding frame // decoding frame
if (frame >= frames_ready_) { if (frame >= frames_ready_) {
@ -92,14 +91,15 @@ bool Decodable::AdvanceChunk() {
return true; return true;
} }
bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs, int* vocab_dim) { bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
int* vocab_dim) {
if (AdvanceChunk() == false) { if (AdvanceChunk() == false) {
return false; return false;
} }
int nrows = nnet_out_cache_.NumRows(); int nrows = nnet_out_cache_.NumRows();
CHECK(nrows == (frames_ready_ - frame_offset_)); CHECK(nrows == (frames_ready_ - frame_offset_));
if (nrows <= 0){ if (nrows <= 0) {
LOG(WARNING) << "No new nnet out in cache."; LOG(WARNING) << "No new nnet out in cache.";
return false; return false;
} }
@ -107,7 +107,7 @@ bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs, int* voc
logprobs->Resize(nnet_out_cache_.NumRows() * nnet_out_cache_.NumCols()); logprobs->Resize(nnet_out_cache_.NumRows() * nnet_out_cache_.NumCols());
logprobs->CopyRowsFromMat(nnet_out_cache_); logprobs->CopyRowsFromMat(nnet_out_cache_);
*vocab_dim = nnet_out_cache_.NumCols(); *vocab_dim = nnet_out_cache_.NumCols();
return true; return true;
} }
@ -140,7 +140,7 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
BaseFloat logprob = 0.0; BaseFloat logprob = 0.0;
int32 frame_idx = frame - frame_offset_; int32 frame_idx = frame - frame_offset_;
BaseFloat nnet_out = nnet_out_cache_(frame_idx, TokenId2NnetId(index)); BaseFloat nnet_out = nnet_out_cache_(frame_idx, TokenId2NnetId(index));
if (nnet_->IsLogProb()){ if (nnet_->IsLogProb()) {
logprob = nnet_out; logprob = nnet_out;
} else { } else {
logprob = std::log(nnet_out + std::numeric_limits<float>::epsilon()); logprob = std::log(nnet_out + std::numeric_limits<float>::epsilon());
@ -158,8 +158,8 @@ void Decodable::Reset() {
} }
void Decodable::AttentionRescoring(const std::vector<std::vector<int>>& hyps, void Decodable::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight, float reverse_weight,
std::vector<float>* rescoring_score){ std::vector<float>* rescoring_score) {
nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score); nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
} }

@ -242,7 +242,6 @@ void U2Nnet::ForwardEncoderChunkImpl(
const int32& feat_dim, const int32& feat_dim,
std::vector<kaldi::BaseFloat>* out_prob, std::vector<kaldi::BaseFloat>* out_prob,
int32* vocab_dim) { int32* vocab_dim) {
#ifdef USE_PROFILING #ifdef USE_PROFILING
RecordEvent event( RecordEvent event(
"ForwardEncoderChunkImpl", TracerEventType::UserDefined, 1); "ForwardEncoderChunkImpl", TracerEventType::UserDefined, 1);
@ -349,8 +348,9 @@ void U2Nnet::ForwardEncoderChunkImpl(
// current offset in decoder frame // current offset in decoder frame
// not used in nnet // not used in nnet
offset_ += chunk_out.shape()[1]; offset_ += chunk_out.shape()[1];
VLOG(2) << "encoder out chunk size: " << chunk_out.shape()[1] << " total: " << offset_ ; VLOG(2) << "encoder out chunk size: " << chunk_out.shape()[1]
<< " total: " << offset_;
// collects encoder outs. // collects encoder outs.
encoder_outs_.push_back(chunk_out); encoder_outs_.push_back(chunk_out);
@ -706,12 +706,13 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
} }
void U2Nnet::EncoderOuts(std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const { void U2Nnet::EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const {
// list of (B=1,T,D) // list of (B=1,T,D)
int size = encoder_outs_.size(); int size = encoder_outs_.size();
VLOG(1) << "encoder_outs_ size: " << size; VLOG(1) << "encoder_outs_ size: " << size;
for (int i = 0; i < size; i++){ for (int i = 0; i < size; i++) {
const paddle::Tensor& item = encoder_outs_[i]; const paddle::Tensor& item = encoder_outs_[i];
const std::vector<int64_t> shape = item.shape(); const std::vector<int64_t> shape = item.shape();
CHECK(shape.size() == 3); CHECK(shape.size() == 3);
@ -719,16 +720,17 @@ void U2Nnet::EncoderOuts(std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_o
const int& T = shape[1]; const int& T = shape[1];
const int& D = shape[2]; const int& D = shape[2];
CHECK(B == 1) << "Only support batch one."; CHECK(B == 1) << "Only support batch one.";
VLOG(1) << "encoder out " << i << " shape: (" << B << "," << T << "," << D << ")"; VLOG(1) << "encoder out " << i << " shape: (" << B << "," << T << ","
<< D << ")";
const float *this_tensor_ptr = item.data<float>(); const float* this_tensor_ptr = item.data<float>();
for (int j = 0; j < T; j++){ for (int j = 0; j < T; j++) {
const float* cur = this_tensor_ptr + j * D; const float* cur = this_tensor_ptr + j * D;
kaldi::Vector<kaldi::BaseFloat> out(D); kaldi::Vector<kaldi::BaseFloat> out(D);
std::memcpy(out.Data(), cur, D * sizeof(kaldi::BaseFloat)); std::memcpy(out.Data(), cur, D * sizeof(kaldi::BaseFloat));
encoder_out->emplace_back(out); encoder_out->emplace_back(out);
} }
} }
} }
} // namespace ppspeech } // namespace ppspeech

@ -14,11 +14,11 @@
#include "base/common.h" #include "base/common.h"
#include "decoder/param.h"
#include "frontend/audio/assembler.h" #include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "decoder/param.h"
#include "nnet/u2_nnet.h" #include "nnet/u2_nnet.h"
@ -46,15 +46,16 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier; LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
LOG(INFO) << "model path: " << FLAGS_model_path; LOG(INFO) << "model path: " << FLAGS_model_path;
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_rspecifier); kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier); kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer(FLAGS_nnet_encoder_outs_wspecifier); kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer(
FLAGS_nnet_encoder_outs_wspecifier);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
int32 chunk_size = int32 chunk_size = (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate + FLAGS_receptive_field_length;
FLAGS_receptive_field_length;
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length; int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk size (frame): " << chunk_size;
@ -92,9 +93,9 @@ int main(int argc, char* argv[]) {
ori_feature_len - chunk_idx * chunk_stride, chunk_size); ori_feature_len - chunk_idx * chunk_stride, chunk_size);
} }
if (this_chunk_size < receptive_field_length) { if (this_chunk_size < receptive_field_length) {
LOG(WARNING) << "utt: " << utt << " skip last " LOG(WARNING)
<< this_chunk_size << " frames, expect is " << "utt: " << utt << " skip last " << this_chunk_size
<< receptive_field_length; << " frames, expect is " << receptive_field_length;
break; break;
} }
@ -123,13 +124,17 @@ int main(int argc, char* argv[]) {
kaldi::Vector<kaldi::BaseFloat> logprobs; kaldi::Vector<kaldi::BaseFloat> logprobs;
bool isok = decodable->AdvanceChunk(&logprobs, &vocab_dim); bool isok = decodable->AdvanceChunk(&logprobs, &vocab_dim);
CHECK(isok == true); CHECK(isok == true);
for (int row_idx = 0; row_idx < logprobs.Dim() / vocab_dim; row_idx ++) { for (int row_idx = 0; row_idx < logprobs.Dim() / vocab_dim;
row_idx++) {
kaldi::Vector<kaldi::BaseFloat> vec_tmp(vocab_dim); kaldi::Vector<kaldi::BaseFloat> vec_tmp(vocab_dim);
std::memcpy(vec_tmp.Data(), logprobs.Data() + row_idx*vocab_dim, sizeof(kaldi::BaseFloat) * vocab_dim); std::memcpy(vec_tmp.Data(),
logprobs.Data() + row_idx * vocab_dim,
sizeof(kaldi::BaseFloat) * vocab_dim);
prob_vec.push_back(vec_tmp); prob_vec.push_back(vec_tmp);
} }
VLOG(2) << "frame_idx: " << frame_idx << " elapsed: " << timer.Elapsed() << " sec."; VLOG(2) << "frame_idx: " << frame_idx
<< " elapsed: " << timer.Elapsed() << " sec.";
} }
// get encoder out // get encoder out
@ -141,7 +146,8 @@ int main(int argc, char* argv[]) {
if (prob_vec.size() == 0 || encoder_out_vec.size() == 0) { if (prob_vec.size() == 0 || encoder_out_vec.size() == 0) {
// the TokenWriter can not write empty string. // the TokenWriter can not write empty string.
++num_err; ++num_err;
LOG(WARNING) << " the nnet prob/encoder_out of " << utt << " is empty"; LOG(WARNING) << " the nnet prob/encoder_out of " << utt
<< " is empty";
continue; continue;
} }
@ -168,7 +174,8 @@ int main(int argc, char* argv[]) {
kaldi::Matrix<kaldi::BaseFloat> encoder_outs(nrow, ncol); kaldi::Matrix<kaldi::BaseFloat> encoder_outs(nrow, ncol);
for (int32 row_idx = 0; row_idx < nrow; ++row_idx) { for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
for (int32 col_idx = 0; col_idx < ncol; ++col_idx) { for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
encoder_outs(row_idx, col_idx) = encoder_out_vec[row_idx](col_idx); encoder_outs(row_idx, col_idx) =
encoder_out_vec[row_idx](col_idx);
} }
} }
nnet_encoder_outs_writer.Write(utt, encoder_outs); nnet_encoder_outs_writer.Write(utt, encoder_outs);

@ -12,17 +12,18 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "websocket/websocket_server.h"
#include "decoder/param.h" #include "decoder/param.h"
#include "websocket/websocket_server.h"
DEFINE_int32(port, 8082, "websocket listening port"); DEFINE_int32(port, 8082, "websocket listening port");
ppspeech::RecognizerResource InitRecognizerResoure() { ppspeech::RecognizerResource InitRecognizerResoure() {
ppspeech::RecognizerResource resource; ppspeech::RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale; resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags(); resource.feature_pipeline_opts =
ppspeech::FeaturePipelineOptions::InitFromFlags();
resource.model_opts = ppspeech::ModelOptions::InitFromFlags(); resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags(); resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags();
return resource; return resource;
} }

@ -16,13 +16,13 @@
#include "utils/math.h" #include "utils/math.h"
#include "base/common.h"
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <queue> #include <queue>
#include <utility> #include <utility>
#include "base/common.h"
namespace ppspeech { namespace ppspeech {
@ -89,8 +89,8 @@ void TopK(const std::vector<T>& data,
} }
template void TopK<float>(const std::vector<float>& data, template void TopK<float>(const std::vector<float>& data,
int32_t k, int32_t k,
std::vector<float>* values, std::vector<float>* values,
std::vector<int>* indices) ; std::vector<int>* indices);
} // namespace ppspeech } // namespace ppspeech
Loading…
Cancel
Save