diff --git a/speechx/examples/decoder/offline_decoder_main.cc b/speechx/examples/decoder/offline_decoder_main.cc index cffca39a..44127c73 100644 --- a/speechx/examples/decoder/offline_decoder_main.cc +++ b/speechx/examples/decoder/offline_decoder_main.cc @@ -63,6 +63,7 @@ int main(int argc, char* argv[]) { int32 chunk_size = 35; decoder.InitDecoder(); + for (; !feature_reader.Done(); feature_reader.Next()) { string utt = feature_reader.Key(); const kaldi::Matrix feature = feature_reader.Value(); @@ -90,6 +91,7 @@ int main(int argc, char* argv[]) { result = decoder.GetFinalBestPath(); KALDI_LOG << " the result of " << utt << " is " << result; decodable->Reset(); + decoder.Reset(); ++num_done; } diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index 8106b710..7bbb9506 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -26,10 +26,10 @@ using FSTMATCH = fst::SortedMatcher; CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) : opts_(opts), init_ext_scorer_(nullptr), - blank_id(-1), - space_id(-1), + blank_id_(-1), + space_id_(-1), num_frame_decoded_(0), - root(nullptr) { + root_(nullptr) { LOG(INFO) << "dict path: " << opts_.dict_file; if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { LOG(INFO) << "load the dict failed"; @@ -40,37 +40,40 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) LOG(INFO) << "language model path: " << opts_.lm_path; init_ext_scorer_ = std::make_shared( opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); -} - -void CTCBeamSearch::Reset() { - num_frame_decoded_ = 0; - ResetPrefixes(); -} -void CTCBeamSearch::InitDecoder() { - blank_id = 0; + blank_id_ = 0; auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); - space_id = it - vocabulary_.begin(); + space_id_ = it - vocabulary_.begin(); // if no space in vocabulary - if ((size_t)space_id >= vocabulary_.size()) { - space_id = -2; + if ((size_t)space_id_ >= vocabulary_.size()) { + space_id_ = -2; } +} + +void CTCBeamSearch::Reset() { + //num_frame_decoded_ = 0; + //ResetPrefixes(); + InitDecoder(); +} - ResetPrefixes(); +void CTCBeamSearch::InitDecoder() { + num_frame_decoded_ = 0; + //ResetPrefixes(); + prefixes_.clear(); - root = std::make_shared(); - root->score = root->log_prob_b_prev = 0.0; - prefixes.push_back(root.get()); + root_ = std::make_shared(); + root_->score = root_->log_prob_b_prev = 0.0; + prefixes_.push_back(root_.get()); if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) { auto fst_dict = static_cast(init_ext_scorer_->dictionary); fst::StdVectorFst* dict_ptr = fst_dict->Copy(true); - root->set_dictionary(dict_ptr); + root_->set_dictionary(dict_ptr); auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); - root->set_matcher(matcher); + root_->set_matcher(matcher); } } @@ -96,12 +99,13 @@ void CTCBeamSearch::AdvanceDecode( } void CTCBeamSearch::ResetPrefixes() { - for (size_t i = 0; i < prefixes.size(); i++) { - if (prefixes[i] != nullptr) { - delete prefixes[i]; - prefixes[i] = nullptr; + for (size_t i = 0; i < prefixes_.size(); i++) { + if (prefixes_[i] != nullptr) { + delete prefixes_[i]; + prefixes_[i] = nullptr; } } + prefixes_.clear(); } int CTCBeamSearch::DecodeLikelihoods(const vector>& probs, @@ -115,12 +119,12 @@ int CTCBeamSearch::DecodeLikelihoods(const vector>& probs, } vector> CTCBeamSearch::GetNBestPath() { - return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size); + return get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size); } string CTCBeamSearch::GetBestPath() { std::vector> result; - result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size); + result = get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size); return result[0].second; } @@ -153,19 +157,19 @@ void CTCBeamSearch::AdvanceDecoding(const vector>& probs) { float min_cutoff = -NUM_FLT_INF; bool full_beam = false; if (init_ext_scorer_ != nullptr) { - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort(prefixes.begin(), - prefixes.begin() + num_prefixes, + size_t num_prefixes_ = std::min(prefixes_.size(), beam_size); + std::sort(prefixes_.begin(), + prefixes_.begin() + num_prefixes_, prefix_compare); - if (num_prefixes == 0) { + if (num_prefixes_ == 0) { continue; } - min_cutoff = prefixes[num_prefixes - 1]->score + - std::log(prob[blank_id]) - + min_cutoff = prefixes_[num_prefixes_ - 1]->score + + std::log(prob[blank_id_]) - std::max(0.0, init_ext_scorer_->beta); - full_beam = (num_prefixes == beam_size); + full_beam = (num_prefixes_ == beam_size); } vector> log_prob_idx = @@ -177,18 +181,18 @@ void CTCBeamSearch::AdvanceDecoding(const vector>& probs) { SearchOneChar(full_beam, log_prob_idx[index], min_cutoff); } - prefixes.clear(); + prefixes_.clear(); // update log probs - root->iterate_to_vec(prefixes); - // only preserve top beam_size prefixes - if (prefixes.size() >= beam_size) { - std::nth_element(prefixes.begin(), - prefixes.begin() + beam_size, - prefixes.end(), + root_->iterate_to_vec(prefixes_); + // only preserve top beam_size prefixes_ + if (prefixes_.size() >= beam_size) { + std::nth_element(prefixes_.begin(), + prefixes_.begin() + beam_size, + prefixes_.end(), prefix_compare); - for (size_t i = beam_size; i < prefixes.size(); ++i) { - prefixes[i]->remove(); + for (size_t i = beam_size; i < prefixes_.size(); ++i) { + prefixes_[i]->remove(); } } // if num_frame_decoded_++; @@ -202,15 +206,15 @@ int32 CTCBeamSearch::SearchOneChar( size_t beam_size = opts_.beam_size; const auto& c = log_prob_idx.first; const auto& log_prob_c = log_prob_idx.second; - size_t prefixes_len = std::min(prefixes.size(), beam_size); + size_t prefixes__len = std::min(prefixes_.size(), beam_size); - for (size_t i = 0; i < prefixes_len; ++i) { - auto prefix = prefixes[i]; + for (size_t i = 0; i < prefixes__len; ++i) { + auto prefix = prefixes_[i]; if (full_beam && log_prob_c + prefix->score < min_cutoff) { break; } - if (c == blank_id) { + if (c == blank_id_) { prefix->log_prob_b_cur = log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); continue; @@ -238,7 +242,7 @@ int32 CTCBeamSearch::SearchOneChar( // language model scoring if (init_ext_scorer_ != nullptr && - (c == space_id || init_ext_scorer_->is_character_based())) { + (c == space_id_ || init_ext_scorer_->is_character_based())) { PathTrie* prefix_to_score = nullptr; // skip scoring the space if (init_ext_scorer_->is_character_based()) { @@ -266,17 +270,17 @@ int32 CTCBeamSearch::SearchOneChar( void CTCBeamSearch::CalculateApproxScore() { size_t beam_size = opts_.beam_size; - size_t num_prefixes = std::min(prefixes.size(), beam_size); + size_t num_prefixes_ = std::min(prefixes_.size(), beam_size); std::sort( - prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); + prefixes_.begin(), prefixes_.begin() + num_prefixes_, prefix_compare); // compute aproximate ctc score as the return score, without affecting the // return order of decoding result. To delete when decoder gets stable. - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - double approx_ctc = prefixes[i]->score; + for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) { + double approx_ctc = prefixes_[i]->score; if (init_ext_scorer_ != nullptr) { vector output; - prefixes[i]->get_path_vec(output); + prefixes_[i]->get_path_vec(output); auto prefix_length = output.size(); auto words = init_ext_scorer_->split_labels(output); // remove word insert @@ -285,7 +289,7 @@ void CTCBeamSearch::CalculateApproxScore() { approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) * init_ext_scorer_->alpha; } - prefixes[i]->approx_ctc = approx_ctc; + prefixes_[i]->approx_ctc = approx_ctc; } } @@ -293,9 +297,9 @@ void CTCBeamSearch::LMRescore() { size_t beam_size = opts_.beam_size; if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - auto prefix = prefixes[i]; - if (!prefix->is_empty() && prefix->character != space_id) { + for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) { + auto prefix = prefixes_[i]; + if (!prefix->is_empty() && prefix->character != space_id_) { float score = 0.0; vector ngram = init_ext_scorer_->make_ngram(prefix); score = init_ext_scorer_->get_log_cond_prob(ngram) * diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h index 53700e27..451f33c0 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.h @@ -83,10 +83,10 @@ class CTCBeamSearch { CTCBeamSearchOptions opts_; std::shared_ptr init_ext_scorer_; // todo separate later std::vector vocabulary_; // todo remove later - size_t blank_id; - int space_id; - std::shared_ptr root; - std::vector prefixes; + size_t blank_id_; + int space_id_; + std::shared_ptr root_; + std::vector prefixes_; int num_frame_decoded_; DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); }; diff --git a/speechx/speechx/frontend/feature_cache.h b/speechx/speechx/frontend/feature_cache.h index 03b11f57..459134ee 100644 --- a/speechx/speechx/frontend/feature_cache.h +++ b/speechx/speechx/frontend/feature_cache.h @@ -36,18 +36,23 @@ class FeatureCache : public FeatureExtractorInterface { Compute(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); } + virtual void Reset() { + base_extractor_->Reset(); + while (!cache_.empty()) { + cache_.pop(); + } + } private: bool Compute(); - bool finished_; std::mutex mutex_; size_t max_size_; std::queue> cache_; std::unique_ptr base_extractor_; std::condition_variable ready_feed_condition_; std::condition_variable ready_read_condition_; - //DISALLOW_COPY_AND_ASSGIN(FeatureCache); + // DISALLOW_COPY_AND_ASSGIN(FeatureCache); }; } // namespace ppspeech diff --git a/speechx/speechx/frontend/feature_extractor_interface.h b/speechx/speechx/frontend/feature_extractor_interface.h index 64cc67f3..cb6fec1b 100644 --- a/speechx/speechx/frontend/feature_extractor_interface.h +++ b/speechx/speechx/frontend/feature_extractor_interface.h @@ -33,7 +33,7 @@ class FeatureExtractorInterface { virtual size_t Dim() const = 0; virtual void SetFinished() = 0; virtual bool IsFinished() const = 0; - // virtual void Reset(); + virtual void Reset() = 0; }; } // namespace ppspeech diff --git a/speechx/speechx/frontend/linear_spectrogram.cc b/speechx/speechx/frontend/linear_spectrogram.cc index 7491716c..b8a18e02 100644 --- a/speechx/speechx/frontend/linear_spectrogram.cc +++ b/speechx/speechx/frontend/linear_spectrogram.cc @@ -25,25 +25,6 @@ using kaldi::VectorBase; using kaldi::Matrix; using std::vector; -// todo remove later -void CopyVector2StdVector_(const VectorBase& input, - vector* output) { - if (input.Dim() == 0) return; - output->resize(input.Dim()); - for (size_t idx = 0; idx < input.Dim(); ++idx) { - (*output)[idx] = input(idx); - } -} - -void CopyStdVector2Vector_(const vector& input, - Vector* output) { - if (input.empty()) return; - output->Resize(input.size()); - for (size_t idx = 0; idx < input.size(); ++idx) { - (*output)(idx) = input[idx]; - } -} - LinearSpectrogram::LinearSpectrogram( const LinearSpectrogramOptions& opts, std::unique_ptr base_extractor) { @@ -76,7 +57,8 @@ bool LinearSpectrogram::Read(Vector* feats) { if (flag == false || input_feats.Dim() == 0) return false; vector input_feats_vec(input_feats.Dim()); - CopyVector2StdVector_(input_feats, &input_feats_vec); + std::memcpy(input_feats_vec.data(), input_feats.Data(), + input_feats.Dim()*sizeof(BaseFloat)); vector> result; Compute(input_feats_vec, result); int32 feat_size = 0; @@ -103,9 +85,12 @@ bool LinearSpectrogram::NumpyFft(vector* v, vector* real, vector* img) const { Vector v_tmp; - CopyStdVector2Vector_(*v, &v_tmp); + v_tmp.Resize(v->size()); + std::memcpy(v_tmp.Data(), v->data(), sizeof(BaseFloat)*(v->size())); RealFft(&v_tmp, true); - CopyVector2StdVector_(v_tmp, v); + v->resize(v_tmp.Dim()); + std::memcpy(v->data(), v_tmp.Data(), sizeof(BaseFloat)*(v->size())); + real->push_back(v->at(0)); img->push_back(0); for (int i = 1; i < v->size() / 2; i++) { diff --git a/speechx/speechx/frontend/linear_spectrogram.h b/speechx/speechx/frontend/linear_spectrogram.h index 790263d9..b2bb414d 100644 --- a/speechx/speechx/frontend/linear_spectrogram.h +++ b/speechx/speechx/frontend/linear_spectrogram.h @@ -45,6 +45,9 @@ class LinearSpectrogram : public FeatureExtractorInterface { virtual size_t Dim() const { return dim_; } virtual void SetFinished() { base_extractor_->SetFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); } + virtual void Reset() { + base_extractor_->Reset(); + } private: void Hanning(std::vector* data) const; diff --git a/speechx/speechx/frontend/normalizer.cc b/speechx/speechx/frontend/normalizer.cc index fbb2b645..285c8e03 100644 --- a/speechx/speechx/frontend/normalizer.cc +++ b/speechx/speechx/frontend/normalizer.cc @@ -48,25 +48,6 @@ bool DecibelNormalizer::Read(kaldi::Vector* waves) { return true; } -// todo remove later -void CopyVector2StdVector(const kaldi::VectorBase& input, - vector* output) { - if (input.Dim() == 0) return; - output->resize(input.Dim()); - for (size_t idx = 0; idx < input.Dim(); ++idx) { - (*output)[idx] = input(idx); - } -} - -void CopyStdVector2Vector(const vector& input, - VectorBase* output) { - if (input.empty()) return; - assert(input.size() == output->Dim()); - for (size_t idx = 0; idx < input.size(); ++idx) { - (*output)(idx) = input[idx]; - } -} - bool DecibelNormalizer::Compute(VectorBase* waves) const { // calculate db rms BaseFloat rms_db = 0.0; @@ -107,7 +88,7 @@ bool DecibelNormalizer::Compute(VectorBase* waves) const { item *= std::pow(10.0, gain / 20.0); } - CopyStdVector2Vector(samples, waves); + std::memcpy(waves->Data(), samples.data(), sizeof(BaseFloat)*samples.size()); return true; } diff --git a/speechx/speechx/frontend/normalizer.h b/speechx/speechx/frontend/normalizer.h index b9daa853..24542eba 100644 --- a/speechx/speechx/frontend/normalizer.h +++ b/speechx/speechx/frontend/normalizer.h @@ -52,6 +52,9 @@ class DecibelNormalizer : public FeatureExtractorInterface { virtual size_t Dim() const { return dim_; } virtual void SetFinished() { base_extractor_->SetFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); } + virtual void Reset() { + base_extractor_->Reset(); + } private: bool Compute(kaldi::VectorBase* waves) const; @@ -76,6 +79,9 @@ class CMVN : public FeatureExtractorInterface { virtual size_t Dim() const { return dim_; } virtual void SetFinished() { base_extractor_->SetFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); } + virtual void Reset() { + base_extractor_->Reset(); + } private: void Compute(kaldi::VectorBase* feats) const; diff --git a/speechx/speechx/frontend/raw_audio.h b/speechx/speechx/frontend/raw_audio.h index 1a326b3c..7726f825 100644 --- a/speechx/speechx/frontend/raw_audio.h +++ b/speechx/speechx/frontend/raw_audio.h @@ -34,6 +34,11 @@ class RawAudioCache : public FeatureExtractorInterface { finished_ = true; } virtual bool IsFinished() const { return finished_; } + virtual void Reset() { + start_ = 0; + data_length_ = 0; + finished_ = false; + } private: std::vector ring_buffer_; @@ -67,6 +72,9 @@ class RawDataCache : public FeatureExtractorInterface { virtual void SetFinished() { finished_ = true; } virtual bool IsFinished() const { return finished_; } void SetDim(int32 dim) { dim_ = dim; } + virtual void Reset() { + finished_ = true; + } private: kaldi::Vector data_; diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index 79c896aa..3cc07f38 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -25,7 +25,6 @@ Decodable::Decodable(const std::shared_ptr& nnet, const std::shared_ptr& frontend) : frontend_(frontend), nnet_(nnet), - finished_(false), frame_offset_(0), frames_ready_(0) {} @@ -81,8 +80,10 @@ bool Decodable::FrameLogLikelihood(int32 frame, vector* likelihood) { } void Decodable::Reset() { - // frontend_.Reset(); + frontend_->Reset(); nnet_->Reset(); + frame_offset_ = 0; + frames_ready_ = 0; } } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h index 72d194b9..7938b582 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/speechx/speechx/nnet/decodable.h @@ -45,7 +45,6 @@ class Decodable : public kaldi::DecodableInterface { std::shared_ptr nnet_; kaldi::Matrix nnet_cache_; // std::vector> nnet_cache_; - bool finished_; int32 frame_offset_; int32 frames_ready_; // todo: feature frame mismatch with nnet inference frame