pull/2524/head
Hui Zhang 2 years ago
parent 72c9e973a2
commit fddcd36fa0

@ -21,6 +21,7 @@
#include <iterator>
#include <numeric>
#include <thread>
#include "base/flags.h"
#include "base/log.h"
#include "paddle_inference_api.h"

@ -13,9 +13,10 @@
// limitations under the License.
#include "decoder/ctc_beam_search_decoder.h"
#include "base/common.h"
#include "decoder/ctc_decoders/decoder_utils.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "utils/file_utils.h"
namespace ppspeech {
@ -24,10 +25,7 @@ using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts),
init_ext_scorer_(nullptr),
space_id_(-1),
root_(nullptr) {
: opts_(opts), init_ext_scorer_(nullptr), space_id_(-1), root_(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
LOG(INFO) << "load the dict failed";

@ -16,11 +16,12 @@
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "absl/strings/str_join.h"
#include "base/common.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "utils/math.h"
#include "absl/strings/str_join.h"
#ifdef USE_PROFILING
#include "paddle/fluid/platform/profiler.h"
@ -30,12 +31,11 @@ using paddle::platform::TracerEventType;
namespace ppspeech {
CTCPrefixBeamSearch::CTCPrefixBeamSearch(
const std::string vocab_path,
CTCPrefixBeamSearch::CTCPrefixBeamSearch(const std::string vocab_path,
const CTCBeamSearchOptions& opts)
: opts_(opts) {
unit_table_ = std::shared_ptr<fst::SymbolTable>(fst::SymbolTable::ReadText(vocab_path));
unit_table_ = std::shared_ptr<fst::SymbolTable>(
fst::SymbolTable::ReadText(vocab_path));
CHECK(unit_table_ != nullptr);
Reset();
@ -70,7 +70,6 @@ void CTCPrefixBeamSearch::Reset() {
void CTCPrefixBeamSearch::InitDecoder() { Reset(); }
void CTCPrefixBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (1) {
@ -296,9 +295,7 @@ void CTCPrefixBeamSearch::UpdateOutputs(
outputs_.emplace_back(output);
}
void CTCPrefixBeamSearch::FinalizeSearch() {
UpdateFinalContext();
}
void CTCPrefixBeamSearch::FinalizeSearch() { UpdateFinalContext(); }
void CTCPrefixBeamSearch::UpdateFinalContext() {
if (context_graph_ == nullptr) return;
@ -311,8 +308,8 @@ void CTCPrefixBeamSearch::UpdateFinalContext() {
for (const auto& prefix : hypotheses_) {
PrefixScore& prefix_score = cur_hyps_[prefix];
if (prefix_score.context_score != 0) {
prefix_score.UpdateContext(context_graph_, prefix_score, 0,
prefix.size());
prefix_score.UpdateContext(
context_graph_, prefix_score, 0, prefix.size());
}
}
std::vector<std::pair<std::vector<int>, PrefixScore>> arr(cur_hyps_.begin(),
@ -335,11 +332,10 @@ void CTCPrefixBeamSearch::UpdateFinalContext() {
return sentence;
}
std::string CTCPrefixBeamSearch::GetBestPath() {
return GetBestPath(0);
}
std::string CTCPrefixBeamSearch::GetBestPath() { return GetBestPath(0); }
std::vector<std::pair<double, std::string>> CTCPrefixBeamSearch::GetNBestPath(int n) {
std::vector<std::pair<double, std::string>> CTCPrefixBeamSearch::GetNBestPath(
int n) {
int hyps_size = hypotheses_.size();
CHECK(hyps_size > 0);
@ -354,17 +350,14 @@ void CTCPrefixBeamSearch::UpdateFinalContext() {
return n_best;
}
std::vector<std::pair<double, std::string>> CTCPrefixBeamSearch::GetNBestPath() {
std::vector<std::pair<double, std::string>>
CTCPrefixBeamSearch::GetNBestPath() {
return GetNBestPath(-1);
}
std::string CTCPrefixBeamSearch::GetFinalBestPath() {
return GetBestPath();
}
std::string CTCPrefixBeamSearch::GetFinalBestPath() { return GetBestPath(); }
std::string CTCPrefixBeamSearch::GetPartialResult() {
return GetBestPath();
}
std::string CTCPrefixBeamSearch::GetPartialResult() { return GetBestPath(); }
} // namespace ppspeech

@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/strings/str_split.h"
#include "base/common.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "frontend/audio/data_cache.h"
#include "fst/symbol-table.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/u2_nnet.h"
#include "absl/strings/str_split.h"
#include "fst/symbol-table.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
@ -64,8 +64,7 @@ int main(int argc, char* argv[]) {
// nnet
ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
std::shared_ptr<ppspeech::U2Nnet> nnet(
new ppspeech::U2Nnet(model_opts));
std::shared_ptr<ppspeech::U2Nnet> nnet(new ppspeech::U2Nnet(model_opts));
// decodeable
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
@ -114,9 +113,9 @@ int main(int argc, char* argv[]) {
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (this_chunk_size < receptive_field_length) {
LOG(WARNING) << "utt: " << utt << " skip last "
<< this_chunk_size << " frames, expect is "
<< receptive_field_length;
LOG(WARNING)
<< "utt: " << utt << " skip last " << this_chunk_size
<< " frames, expect is " << receptive_field_length;
break;
}

@ -33,9 +33,7 @@ void TLGDecoder::Reset() {
return;
}
void TLGDecoder::InitDecoder() {
Reset();
}
void TLGDecoder::InitDecoder() { Reset(); }
void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
@ -50,7 +48,6 @@ void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
}
std::string TLGDecoder::GetPartialResult() {
if (num_frame_decoded_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
@ -93,4 +90,4 @@ std::string TLGDecoder::GetFinalBestPath() {
return words;
}
}
} // namespace ppspeech

@ -15,14 +15,12 @@
// todo refactor, repalce with gtest
#include "base/common.h"
#include "decoder/ctc_tlg_decoder.h"
#include "decoder/param.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/ds2_nnet.h"
#include "decoder/param.h"
#include "decoder/ctc_tlg_decoder.h"
#include "kaldi/util/table-types.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
@ -47,7 +45,8 @@ int main(int argc, char* argv[]) {
int32 num_done = 0, num_err = 0;
ppspeech::TLGDecoderOptions opts = ppspeech::TLGDecoderOptions::InitFromFlags();
ppspeech::TLGDecoderOptions opts =
ppspeech::TLGDecoderOptions::InitFromFlags();
opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5;
ppspeech::TLGDecoder decoder(opts);

@ -17,12 +17,12 @@
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr;
using std::vector;
Recognizer::Recognizer(const RecognizerResource& resource) {

@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/recognizer.h"
#include "decoder/param.h"
#include "decoder/recognizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h"
@ -25,7 +25,8 @@ DEFINE_int32(sample_rate, 16000, "sample rate");
ppspeech::RecognizerResource InitRecognizerResoure() {
ppspeech::RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags();
resource.feature_pipeline_opts =
ppspeech::FeaturePipelineOptions::InitFromFlags();
resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags();
return resource;

@ -13,18 +13,20 @@
// limitations under the License.
#include "decoder/u2_recognizer.h"
#include "nnet/u2_nnet.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr;
using std::vector;
U2Recognizer::U2Recognizer(const U2RecognizerResource& resource): opts_(resource) {
U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
: opts_(resource) {
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
feature_pipeline_.reset(new FeaturePipeline(feature_opts));
@ -34,7 +36,8 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource): opts_(resource
decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale));
CHECK(resource.vocab_path != "");
decoder_.reset(new CTCPrefixBeamSearch(resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
decoder_.reset(new CTCPrefixBeamSearch(
resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
unit_table_ = decoder_->VocabTable();
symbol_table_ = unit_table_;
@ -108,9 +111,9 @@ void U2Recognizer::UpdateResult(bool finish) {
// TimeStamp is only supported in final result
// TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to
// various FST operations when building the decoding graph. So here we use
// time stamp of the input(e2e model unit), which is more accurate, and it
// requires the symbol table of the e2e model used in training.
// various FST operations when building the decoding graph. So here we
// use time stamp of the input(e2e model unit), which is more accurate,
// and it requires the symbol table of the e2e model used in training.
if (unit_table_ != nullptr && finish) {
int offset = global_frame_offset_ * FrameShiftInMs();
@ -121,11 +124,13 @@ void U2Recognizer::UpdateResult(bool finish) {
for (size_t j = 0; j < input.size(); j++) {
std::string word = unit_table_->Find(input[j]);
int start = time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0
int start =
time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0
? time_stamp[j] * FrameShiftInMs() - time_stamp_gap_
: 0;
if (j > 0) {
start = (time_stamp[j] - time_stamp[j - 1]) * FrameShiftInMs() <
start =
(time_stamp[j] - time_stamp[j - 1]) * FrameShiftInMs() <
time_stamp_gap_
? (time_stamp[j - 1] + time_stamp[j]) / 2 *
FrameShiftInMs()
@ -134,7 +139,8 @@ void U2Recognizer::UpdateResult(bool finish) {
int end = time_stamp[j] * FrameShiftInMs();
if (j < input.size() - 1) {
end = (time_stamp[j + 1] - time_stamp[j]) * FrameShiftInMs() <
end =
(time_stamp[j + 1] - time_stamp[j]) * FrameShiftInMs() <
time_stamp_gap_
? (time_stamp[j + 1] + time_stamp[j]) / 2 *
FrameShiftInMs()
@ -187,7 +193,8 @@ void U2Recognizer::AttentionRescoring() {
for (size_t i = 0; i < num_hyps; i++) {
VLOG(1) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
<< " ctc_score: " << result_[i].score;
result_[i].score = opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
result_[i].score =
opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
opts_.decoder_opts.ctc_weight * result_[i].score;
}
@ -196,13 +203,9 @@ void U2Recognizer::AttentionRescoring() {
<< " score: " << result_[0].score;
}
std::string U2Recognizer::GetFinalResult() {
return result_[0].sentence;
}
std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; }
std::string U2Recognizer::GetPartialResult() {
return result_[0].sentence;
}
std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; }
void U2Recognizer::SetFinished() {
feature_pipeline_->SetFinished();

@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/u2_recognizer.h"
#include "decoder/param.h"
#include "decoder/u2_recognizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h"
@ -43,7 +43,8 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
ppspeech::U2RecognizerResource resource = ppspeech::U2RecognizerResource::InitFromFlags();
ppspeech::U2RecognizerResource resource =
ppspeech::U2RecognizerResource::InitFromFlags();
ppspeech::U2Recognizer recognizer(resource);
kaldi::Timer timer;

@ -14,17 +14,18 @@
#include "frontend/audio/cmvn.h"
#include "kaldi/feat/cmvn.h"
#include "kaldi/util/kaldi-io.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr;
using std::vector;
CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor)

@ -16,16 +16,15 @@
#include "base/flags.h"
#include "base/log.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/fbank.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
@ -86,18 +85,21 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "chunk size (sec): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
for (; !wav_reader.Done() && !wav_info_reader.Done(); wav_reader.Next(), wav_info_reader.Next()) {
for (; !wav_reader.Done() && !wav_info_reader.Done();
wav_reader.Next(), wav_info_reader.Next()) {
const std::string& utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
const std::string& utt2 = wav_info_reader.Key();
const kaldi::WaveInfo& wave_info = wav_info_reader.Value();
CHECK(utt == utt2) << "wav reader and wav info reader using diff rspecifier!!!";
CHECK(utt == utt2)
<< "wav reader and wav info reader using diff rspecifier!!!";
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "samples: " << wave_info.SampleCount();
LOG(INFO) << "dur: " << wave_info.Duration() << " sec";
CHECK(wave_info.SampFreq() == FLAGS_sample_rate) << "need " << FLAGS_sample_rate << " get " << wave_info.SampFreq();
CHECK(wave_info.SampFreq() == FLAGS_sample_rate)
<< "need " << FLAGS_sample_rate << " get " << wave_info.SampFreq();
// load first channel wav
int32 this_channel = 0;
@ -157,7 +159,8 @@ int main(int argc, char* argv[]) {
++cur_idx;
}
}
LOG(INFO) << "feat shape: " << features.NumRows() << " , " << features.NumCols();
LOG(INFO) << "feat shape: " << features.NumRows() << " , "
<< features.NumCols();
feat_writer.Write(utt, features);
// reset frontend pipeline state

@ -14,16 +14,15 @@
#include "base/flags.h"
#include "base/log.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");

@ -18,7 +18,8 @@ namespace ppspeech {
using std::unique_ptr;
FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) : opts_(opts) {
FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts)
: opts_(opts) {
unique_ptr<FrontendInterface> data_source(
new ppspeech::AudioCache(1000 * kint16max, opts.to_float32));
@ -43,4 +44,4 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) : opts_(opt
new ppspeech::Assembler(opts.assembler_opts, std::move(cache)));
}
} // ppspeech
} // namespace ppspeech

@ -18,8 +18,8 @@ namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
using kaldi::Vector;
using std::vector;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FrontendInterface>& frontend,
@ -56,7 +56,6 @@ int32 Decodable::NumIndices() const { return 0; }
int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; }
bool Decodable::EnsureFrameHaveComputed(int32 frame) {
// decoding frame
if (frame >= frames_ready_) {
@ -92,7 +91,8 @@ bool Decodable::AdvanceChunk() {
return true;
}
bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs, int* vocab_dim) {
bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
int* vocab_dim) {
if (AdvanceChunk() == false) {
return false;
}

@ -242,7 +242,6 @@ void U2Nnet::ForwardEncoderChunkImpl(
const int32& feat_dim,
std::vector<kaldi::BaseFloat>* out_prob,
int32* vocab_dim) {
#ifdef USE_PROFILING
RecordEvent event(
"ForwardEncoderChunkImpl", TracerEventType::UserDefined, 1);
@ -349,7 +348,8 @@ void U2Nnet::ForwardEncoderChunkImpl(
// current offset in decoder frame
// not used in nnet
offset_ += chunk_out.shape()[1];
VLOG(2) << "encoder out chunk size: " << chunk_out.shape()[1] << " total: " << offset_ ;
VLOG(2) << "encoder out chunk size: " << chunk_out.shape()[1]
<< " total: " << offset_;
// collects encoder outs.
@ -706,7 +706,8 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
}
void U2Nnet::EncoderOuts(std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const {
void U2Nnet::EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const {
// list of (B=1,T,D)
int size = encoder_outs_.size();
VLOG(1) << "encoder_outs_ size: " << size;
@ -719,7 +720,8 @@ void U2Nnet::EncoderOuts(std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_o
const int& T = shape[1];
const int& D = shape[2];
CHECK(B == 1) << "Only support batch one.";
VLOG(1) << "encoder out " << i << " shape: (" << B << "," << T << "," << D << ")";
VLOG(1) << "encoder out " << i << " shape: (" << B << "," << T << ","
<< D << ")";
const float* this_tensor_ptr = item.data<float>();
for (int j = 0; j < T; j++) {

@ -14,11 +14,11 @@
#include "base/common.h"
#include "decoder/param.h"
#include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "decoder/param.h"
#include "nnet/u2_nnet.h"
@ -46,14 +46,15 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
LOG(INFO) << "model path: " << FLAGS_model_path;
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_rspecifier);
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer(FLAGS_nnet_encoder_outs_wspecifier);
kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer(
FLAGS_nnet_encoder_outs_wspecifier);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
int32 chunk_size =
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate +
int32 chunk_size = (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate +
FLAGS_receptive_field_length;
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
@ -92,9 +93,9 @@ int main(int argc, char* argv[]) {
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (this_chunk_size < receptive_field_length) {
LOG(WARNING) << "utt: " << utt << " skip last "
<< this_chunk_size << " frames, expect is "
<< receptive_field_length;
LOG(WARNING)
<< "utt: " << utt << " skip last " << this_chunk_size
<< " frames, expect is " << receptive_field_length;
break;
}
@ -123,13 +124,17 @@ int main(int argc, char* argv[]) {
kaldi::Vector<kaldi::BaseFloat> logprobs;
bool isok = decodable->AdvanceChunk(&logprobs, &vocab_dim);
CHECK(isok == true);
for (int row_idx = 0; row_idx < logprobs.Dim() / vocab_dim; row_idx ++) {
for (int row_idx = 0; row_idx < logprobs.Dim() / vocab_dim;
row_idx++) {
kaldi::Vector<kaldi::BaseFloat> vec_tmp(vocab_dim);
std::memcpy(vec_tmp.Data(), logprobs.Data() + row_idx*vocab_dim, sizeof(kaldi::BaseFloat) * vocab_dim);
std::memcpy(vec_tmp.Data(),
logprobs.Data() + row_idx * vocab_dim,
sizeof(kaldi::BaseFloat) * vocab_dim);
prob_vec.push_back(vec_tmp);
}
VLOG(2) << "frame_idx: " << frame_idx << " elapsed: " << timer.Elapsed() << " sec.";
VLOG(2) << "frame_idx: " << frame_idx
<< " elapsed: " << timer.Elapsed() << " sec.";
}
// get encoder out
@ -141,7 +146,8 @@ int main(int argc, char* argv[]) {
if (prob_vec.size() == 0 || encoder_out_vec.size() == 0) {
// the TokenWriter can not write empty string.
++num_err;
LOG(WARNING) << " the nnet prob/encoder_out of " << utt << " is empty";
LOG(WARNING) << " the nnet prob/encoder_out of " << utt
<< " is empty";
continue;
}
@ -168,7 +174,8 @@ int main(int argc, char* argv[]) {
kaldi::Matrix<kaldi::BaseFloat> encoder_outs(nrow, ncol);
for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
encoder_outs(row_idx, col_idx) = encoder_out_vec[row_idx](col_idx);
encoder_outs(row_idx, col_idx) =
encoder_out_vec[row_idx](col_idx);
}
}
nnet_encoder_outs_writer.Write(utt, encoder_outs);

@ -12,15 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "websocket/websocket_server.h"
#include "decoder/param.h"
#include "websocket/websocket_server.h"
DEFINE_int32(port, 8082, "websocket listening port");
ppspeech::RecognizerResource InitRecognizerResoure() {
ppspeech::RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags();
resource.feature_pipeline_opts =
ppspeech::FeaturePipelineOptions::InitFromFlags();
resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags();
return resource;

@ -16,13 +16,13 @@
#include "utils/math.h"
#include "base/common.h"
#include <algorithm>
#include <cmath>
#include <queue>
#include <utility>
#include "base/common.h"
namespace ppspeech {

Loading…
Cancel
Save