add reset of decoder & frontend

pull/1570/head
Yang Zhou 2 years ago
parent b69e2222fc
commit 92c618d428

@@ -63,6 +63,7 @@ int main(int argc, char* argv[]) {
int32 chunk_size = 35;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
@@ -90,6 +91,7 @@ int main(int argc, char* argv[]) {
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
decoder.Reset();
++num_done;
}
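The hunk above wires the new reset calls into the offline decoding loop of the test binary, so that state never leaks from one utterance into the next. A condensed sketch of the per-utterance flow after this commit (the reader/decoder setup and the chunked feeding of features are assumed, as they sit outside the hunk):

for (; !feature_reader.Done(); feature_reader.Next()) {
    std::string utt = feature_reader.Key();
    const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
    // ... feed `feature` chunk by chunk (chunk_size frames) and advance decoding ...
    std::string result = decoder.GetFinalBestPath();
    KALDI_LOG << " the result of " << utt << " is " << result;
    decodable->Reset();  // clears frontend and nnet state (see Decodable::Reset below)
    decoder.Reset();     // new in this commit: rebuilds the CTC prefix trie
    ++num_done;
}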

@@ -26,10 +26,10 @@ using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts),
init_ext_scorer_(nullptr),
blank_id(-1),
space_id(-1),
blank_id_(-1),
space_id_(-1),
num_frame_decoded_(0),
root(nullptr) {
root_(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
LOG(INFO) << "load the dict failed";
@@ -40,37 +40,40 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
LOG(INFO) << "language model path: " << opts_.lm_path;
init_ext_scorer_ = std::make_shared<Scorer>(
opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
}
void CTCBeamSearch::Reset() {
num_frame_decoded_ = 0;
ResetPrefixes();
}
void CTCBeamSearch::InitDecoder() {
blank_id = 0;
blank_id_ = 0;
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
space_id = it - vocabulary_.begin();
space_id_ = it - vocabulary_.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary_.size()) {
space_id = -2;
if ((size_t)space_id_ >= vocabulary_.size()) {
space_id_ = -2;
}
}
void CTCBeamSearch::Reset() {
//num_frame_decoded_ = 0;
//ResetPrefixes();
InitDecoder();
}
ResetPrefixes();
void CTCBeamSearch::InitDecoder() {
num_frame_decoded_ = 0;
//ResetPrefixes();
prefixes_.clear();
root = std::make_shared<PathTrie>();
root->score = root->log_prob_b_prev = 0.0;
prefixes.push_back(root.get());
root_ = std::make_shared<PathTrie>();
root_->score = root_->log_prob_b_prev = 0.0;
prefixes_.push_back(root_.get());
if (init_ext_scorer_ != nullptr &&
!init_ext_scorer_->is_character_based()) {
auto fst_dict =
static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
root->set_dictionary(dict_ptr);
root_->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root->set_matcher(matcher);
root_->set_matcher(matcher);
}
}
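After this hunk, Reset() and InitDecoder() do the same work: clear prefixes_, rebuild root_ as a fresh PathTrie, re-attach the scorer's dictionary FST when a word-based language model is loaded, and zero num_frame_decoded_. Reusing one decoder across utterances therefore looks like the sketch below (the AdvanceDecode call shape is assumed, not taken from this diff):

// Sketch: reuse one CTCBeamSearch instance across two utterances.
CTCBeamSearch decoder(opts);     // opts: a configured CTCBeamSearchOptions
decoder.InitDecoder();
// ... AdvanceDecode(...) over utterance A ...
std::string text_a = decoder.GetBestPath();
decoder.Reset();                 // after this commit: equivalent to InitDecoder()
// ... AdvanceDecode(...) over utterance B, starting from an empty prefix trie ...
std::string text_b = decoder.GetBestPath();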
@@ -96,12 +99,13 @@ void CTCBeamSearch::AdvanceDecode(
}
void CTCBeamSearch::ResetPrefixes() {
for (size_t i = 0; i < prefixes.size(); i++) {
if (prefixes[i] != nullptr) {
delete prefixes[i];
prefixes[i] = nullptr;
for (size_t i = 0; i < prefixes_.size(); i++) {
if (prefixes_[i] != nullptr) {
delete prefixes_[i];
prefixes_[i] = nullptr;
}
}
prefixes_.clear();
}
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
@@ -115,12 +119,12 @@ int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
}
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
return get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size);
}
string CTCBeamSearch::GetBestPath() {
std::vector<std::pair<double, std::string>> result;
result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
result = get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size);
return result[0].second;
}
@@ -153,19 +157,19 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (init_ext_scorer_ != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(),
prefixes.begin() + num_prefixes,
size_t num_prefixes_ = std::min(prefixes_.size(), beam_size);
std::sort(prefixes_.begin(),
prefixes_.begin() + num_prefixes_,
prefix_compare);
if (num_prefixes == 0) {
if (num_prefixes_ == 0) {
continue;
}
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
min_cutoff = prefixes_[num_prefixes_ - 1]->score +
std::log(prob[blank_id_]) -
std::max(0.0, init_ext_scorer_->beta);
full_beam = (num_prefixes == beam_size);
full_beam = (num_prefixes_ == beam_size);
}
vector<std::pair<size_t, float>> log_prob_idx =
@@ -177,18 +181,18 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
}
prefixes.clear();
prefixes_.clear();
// update log probs
root->iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
root_->iterate_to_vec(prefixes_);
// only preserve top beam_size prefixes_
if (prefixes_.size() >= beam_size) {
std::nth_element(prefixes_.begin(),
prefixes_.begin() + beam_size,
prefixes_.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
for (size_t i = beam_size; i < prefixes_.size(); ++i) {
prefixes_[i]->remove();
}
} // if
num_frame_decoded_++;
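The block above is the per-frame beam pruning: std::nth_element partitions the hypotheses (cheaper than a full sort) so the best beam_size prefixes come first, and everything past that point is detached from the trie. The same idiom in isolation, with hedged illustrative names (PruneToBeam and the template parameters are not from the codebase):

#include <algorithm>
#include <vector>

// Keep the `beam_size` best hypotheses (by cmp) and drop the rest.
template <typename Prefix, typename Compare>
void PruneToBeam(std::vector<Prefix*>* prefixes, size_t beam_size, Compare cmp) {
    if (prefixes->size() < beam_size) return;
    std::nth_element(prefixes->begin(),
                     prefixes->begin() + beam_size,
                     prefixes->end(),
                     cmp);                    // best beam_size elements first, unordered
    for (size_t i = beam_size; i < prefixes->size(); ++i) {
        (*prefixes)[i]->remove();             // mirrors PathTrie::remove() in the diff
    }
}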
@@ -202,15 +206,15 @@ int32 CTCBeamSearch::SearchOneChar(
size_t beam_size = opts_.beam_size;
const auto& c = log_prob_idx.first;
const auto& log_prob_c = log_prob_idx.second;
size_t prefixes_len = std::min(prefixes.size(), beam_size);
size_t prefixes__len = std::min(prefixes_.size(), beam_size);
for (size_t i = 0; i < prefixes_len; ++i) {
auto prefix = prefixes[i];
for (size_t i = 0; i < prefixes__len; ++i) {
auto prefix = prefixes_[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
if (c == blank_id) {
if (c == blank_id_) {
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
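log_sum_exp above accumulates the blank-emission probability in log space, i.e. it computes log(exp(a) + exp(b)) without leaving log space. A minimal stand-alone equivalent for reference (not the project's own implementation):

#include <algorithm>
#include <cmath>

// Numerically stable log(exp(a) + exp(b)).
inline double LogSumExp(double a, double b) {
    double m = std::max(a, b);
    if (std::isinf(m) && m < 0) return m;   // both inputs are log(0)
    return m + std::log(std::exp(a - m) + std::exp(b - m));
}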
@@ -238,7 +242,7 @@ int32 CTCBeamSearch::SearchOneChar(
// language model scoring
if (init_ext_scorer_ != nullptr &&
(c == space_id || init_ext_scorer_->is_character_based())) {
(c == space_id_ || init_ext_scorer_->is_character_based())) {
PathTrie* prefix_to_score = nullptr;
// skip scoring the space
if (init_ext_scorer_->is_character_based()) {
@@ -266,17 +270,17 @@ int32 CTCBeamSearch::SearchOneChar(
void CTCBeamSearch::CalculateApproxScore() {
size_t beam_size = opts_.beam_size;
size_t num_prefixes = std::min(prefixes.size(), beam_size);
size_t num_prefixes_ = std::min(prefixes_.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
prefixes_.begin(), prefixes_.begin() + num_prefixes_, prefix_compare);
// compute approximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) {
double approx_ctc = prefixes_[i]->score;
if (init_ext_scorer_ != nullptr) {
vector<int> output;
prefixes[i]->get_path_vec(output);
prefixes_[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = init_ext_scorer_->split_labels(output);
// remove word insert
@@ -285,7 +289,7 @@ void CTCBeamSearch::CalculateApproxScore() {
approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
init_ext_scorer_->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
prefixes_[i]->approx_ctc = approx_ctc;
}
}
@@ -293,9 +297,9 @@ void CTCBeamSearch::LMRescore() {
size_t beam_size = opts_.beam_size;
if (init_ext_scorer_ != nullptr &&
!init_ext_scorer_->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) {
auto prefix = prefixes_[i];
if (!prefix->is_empty() && prefix->character != space_id_) {
float score = 0.0;
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
score = init_ext_scorer_->get_log_cond_prob(ngram) *

@@ -83,10 +83,10 @@ class CTCBeamSearch {
CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id;
int space_id;
std::shared_ptr<PathTrie> root;
std::vector<PathTrie*> prefixes;
size_t blank_id_;
int space_id_;
std::shared_ptr<PathTrie> root_;
std::vector<PathTrie*> prefixes_;
int num_frame_decoded_;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
};

@@ -36,18 +36,23 @@ class FeatureCache : public FeatureExtractorInterface {
Compute();
}
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
base_extractor_->Reset();
while (!cache_.empty()) {
cache_.pop();
}
}
private:
bool Compute();
bool finished_;
std::mutex mutex_;
size_t max_size_;
std::queue<kaldi::Vector<BaseFloat>> cache_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
std::condition_variable ready_feed_condition_;
std::condition_variable ready_read_condition_;
//DISALLOW_COPY_AND_ASSGIN(FeatureCache);
// DISALLOW_COPY_AND_ASSGIN(FeatureCache);
};
} // namespace ppspeech

@@ -33,7 +33,7 @@ class FeatureExtractorInterface {
virtual size_t Dim() const = 0;
virtual void SetFinished() = 0;
virtual bool IsFinished() const = 0;
// virtual void Reset();
virtual void Reset() = 0;
};
} // namespace ppspeech
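Making Reset() pure virtual in FeatureExtractorInterface forces every frontend stage to implement it, and the concrete classes touched by this commit all follow the same chaining pattern: clear local per-utterance state, then forward the reset to the wrapped extractor. A hedged sketch of one such stage (StageName is illustrative, and the remaining pure-virtual methods are elided):

class StageName : public FeatureExtractorInterface {
  public:
    virtual void Reset() {
        local_cache_.clear();       // whatever this stage buffers per utterance
        base_extractor_->Reset();   // propagate down the pipeline
    }
    // ... Dim(), SetFinished(), IsFinished(), etc. as required by the interface ...
  private:
    std::vector<kaldi::BaseFloat> local_cache_;
    std::unique_ptr<FeatureExtractorInterface> base_extractor_;
};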

@@ -25,25 +25,6 @@ using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector;
// todo remove later
void CopyVector2StdVector_(const VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) {
if (input.Dim() == 0) return;
output->resize(input.Dim());
for (size_t idx = 0; idx < input.Dim(); ++idx) {
(*output)[idx] = input(idx);
}
}
void CopyStdVector2Vector_(const vector<BaseFloat>& input,
Vector<BaseFloat>* output) {
if (input.empty()) return;
output->Resize(input.size());
for (size_t idx = 0; idx < input.size(); ++idx) {
(*output)(idx) = input[idx];
}
}
LinearSpectrogram::LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
@@ -76,7 +57,8 @@ bool LinearSpectrogram::Read(Vector<BaseFloat>* feats) {
if (flag == false || input_feats.Dim() == 0) return false;
vector<BaseFloat> input_feats_vec(input_feats.Dim());
CopyVector2StdVector_(input_feats, &input_feats_vec);
std::memcpy(input_feats_vec.data(), input_feats.Data(),
input_feats.Dim()*sizeof(BaseFloat));
vector<vector<BaseFloat>> result;
Compute(input_feats_vec, result);
int32 feat_size = 0;
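The memcpy above replaces the removed CopyVector2StdVector_ helper; it is valid because kaldi::Vector<BaseFloat> stores its elements contiguously and exposes them through Data(). A reusable form of the same copy, under those assumptions (KaldiVectorToStd is an illustrative name, not part of the codebase):

#include <cstring>
#include <vector>
// kaldi::Vector is assumed to come from the usual Kaldi matrix headers.

inline void KaldiVectorToStd(const kaldi::Vector<kaldi::BaseFloat>& in,
                             std::vector<kaldi::BaseFloat>* out) {
    out->resize(in.Dim());
    if (in.Dim() > 0) {
        std::memcpy(out->data(), in.Data(), in.Dim() * sizeof(kaldi::BaseFloat));
    }
}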
@@ -103,9 +85,12 @@ bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v,
vector<BaseFloat>* real,
vector<BaseFloat>* img) const {
Vector<BaseFloat> v_tmp;
CopyStdVector2Vector_(*v, &v_tmp);
v_tmp.Resize(v->size());
std::memcpy(v_tmp.Data(), v->data(), sizeof(BaseFloat)*(v->size()));
RealFft(&v_tmp, true);
CopyVector2StdVector_(v_tmp, v);
v->resize(v_tmp.Dim());
std::memcpy(v->data(), v_tmp.Data(), sizeof(BaseFloat)*(v->size()));
real->push_back(v->at(0));
img->push_back(0);
for (int i = 1; i < v->size() / 2; i++) {

@@ -45,6 +45,9 @@ class LinearSpectrogram : public FeatureExtractorInterface {
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
base_extractor_->Reset();
}
private:
void Hanning(std::vector<kaldi::BaseFloat>* data) const;

@@ -48,25 +48,6 @@ bool DecibelNormalizer::Read(kaldi::Vector<BaseFloat>* waves) {
return true;
}
// todo remove later
void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) {
if (input.Dim() == 0) return;
output->resize(input.Dim());
for (size_t idx = 0; idx < input.Dim(); ++idx) {
(*output)[idx] = input(idx);
}
}
void CopyStdVector2Vector(const vector<BaseFloat>& input,
VectorBase<BaseFloat>* output) {
if (input.empty()) return;
assert(input.size() == output->Dim());
for (size_t idx = 0; idx < input.size(); ++idx) {
(*output)(idx) = input[idx];
}
}
bool DecibelNormalizer::Compute(VectorBase<BaseFloat>* waves) const {
// calculate db rms
BaseFloat rms_db = 0.0;
@@ -107,7 +88,7 @@ bool DecibelNormalizer::Compute(VectorBase<BaseFloat>* waves) const {
item *= std::pow(10.0, gain / 20.0);
}
CopyStdVector2Vector(samples, waves);
std::memcpy(waves->Data(), samples.data(), sizeof(BaseFloat)*samples.size());
return true;
}
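Two notes on this hunk: the gain applied a few lines above converts decibels to a linear amplitude factor of 10^(gain/20), and the new memcpy writes into waves->Data() in place, so waves must already hold samples.size() elements (the removed CopyStdVector2Vector used to assert exactly that). A tiny worked example of the dB conversion:

#include <cmath>
#include <cstdio>

// A gain of +6.02 dB scales the amplitude by 10^(6.02/20) ≈ 2.0.
int main() {
    double gain_db = 6.02;
    double scale = std::pow(10.0, gain_db / 20.0);
    std::printf("linear scale for %.2f dB: %.4f\n", gain_db, scale);
    return 0;
}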

@@ -52,6 +52,9 @@ class DecibelNormalizer : public FeatureExtractorInterface {
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
base_extractor_->Reset();
}
private:
bool Compute(kaldi::VectorBase<kaldi::BaseFloat>* waves) const;
@@ -76,6 +79,9 @@ class CMVN : public FeatureExtractorInterface {
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
base_extractor_->Reset();
}
private:
void Compute(kaldi::VectorBase<kaldi::BaseFloat>* feats) const;

@@ -34,6 +34,11 @@ class RawAudioCache : public FeatureExtractorInterface {
finished_ = true;
}
virtual bool IsFinished() const { return finished_; }
virtual void Reset() {
start_ = 0;
data_length_ = 0;
finished_ = false;
}
private:
std::vector<kaldi::BaseFloat> ring_buffer_;
@@ -67,6 +72,9 @@ class RawDataCache : public FeatureExtractorInterface {
virtual void SetFinished() { finished_ = true; }
virtual bool IsFinished() const { return finished_; }
void SetDim(int32 dim) { dim_ = dim; }
virtual void Reset() {
finished_ = true;
}
private:
kaldi::Vector<kaldi::BaseFloat> data_;

@@ -25,7 +25,6 @@ Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FeatureExtractorInterface>& frontend)
: frontend_(frontend),
nnet_(nnet),
finished_(false),
frame_offset_(0),
frames_ready_(0) {}
@@ -81,8 +80,10 @@ bool Decodable::FrameLogLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
}
void Decodable::Reset() {
// frontend_.Reset();
frontend_->Reset();
nnet_->Reset();
frame_offset_ = 0;
frames_ready_ = 0;
}
} // namespace ppspeech
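Decodable::Reset() is the single call the test program makes between utterances, and after this commit it fans out to every stateful component: the whole FeatureExtractorInterface chain via frontend_->Reset(), the nnet cache via nnet_->Reset(), and the local frame counters. A hedged helper showing the pattern a caller ends up with (ResetForNextUtterance is illustrative, not a function in the codebase):

// Reset everything that carries per-utterance state before decoding the next key.
// (Headers for Decodable, CTCBeamSearch and <memory> assumed to be included.)
void ResetForNextUtterance(const std::shared_ptr<Decodable>& decodable,
                           CTCBeamSearch* decoder) {
    decodable->Reset();   // frontend chain + nnet + frame offsets
    decoder->Reset();     // rebuilds the CTC prefix trie (InitDecoder)
}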

@@ -45,7 +45,6 @@ class Decodable : public kaldi::DecodableInterface {
std::shared_ptr<NnetInterface> nnet_;
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
// std::vector<std::vector<kaldi::BaseFloat>> nnet_cache_;
bool finished_;
int32 frame_offset_;
int32 frames_ready_;
// todo: feature frame mismatch with nnet inference frame
