parent
5c8725e8cd
commit
bc1b6c2e7c
@ -0,0 +1,78 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "base/common.h"
|
||||
#include "util/parse-options.h"
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
struct CTCBeamSearchOptions {
|
||||
// common
|
||||
int blank;
|
||||
|
||||
// ds2
|
||||
std::string dict_file;
|
||||
std::string lm_path;
|
||||
int beam_size;
|
||||
BaseFloat alpha;
|
||||
BaseFloat beta;
|
||||
BaseFloat cutoff_prob;
|
||||
int cutoff_top_n;
|
||||
int num_proc_bsearch;
|
||||
|
||||
// u2
|
||||
int first_beam_size;
|
||||
int second_beam_size;
|
||||
CTCBeamSearchOptions()
|
||||
: blank(0),
|
||||
dict_file("vocab.txt"),
|
||||
lm_path(""),
|
||||
alpha(1.9f),
|
||||
beta(5.0),
|
||||
beam_size(300),
|
||||
cutoff_prob(0.99f),
|
||||
cutoff_top_n(40),
|
||||
num_proc_bsearch(10),
|
||||
first_beam_size(10),
|
||||
second_beam_size(10) {}
|
||||
|
||||
void Register(kaldi::OptionsItf* opts) {
|
||||
std::string module = "Ds2BeamSearchConfig: ";
|
||||
opts->Register("dict", &dict_file, module + "vocab file path.");
|
||||
opts->Register(
|
||||
"lm-path", &lm_path, module + "ngram language model path.");
|
||||
opts->Register("alpha", &alpha, module + "alpha");
|
||||
opts->Register("beta", &beta, module + "beta");
|
||||
opts->Register("beam-size",
|
||||
&beam_size,
|
||||
module + "beam size for beam search method");
|
||||
opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs");
|
||||
opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff top n");
|
||||
opts->Register(
|
||||
"num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch");
|
||||
|
||||
opts->Register("blank", &blank, "blank id, default is 0.");
|
||||
|
||||
module = "U2BeamSearchConfig: ";
|
||||
opts->Register(
|
||||
"first-beam-size", &first_beam_size, module + "first beam size.");
|
||||
opts->Register("second-beam-size",
|
||||
&second_beam_size,
|
||||
module + "second beam size.");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,13 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
@ -0,0 +1,64 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "decoder/ctc_beam_search_opt.h"
|
||||
#include "decoder/ctc_prefix_beam_search_score.h"
|
||||
#include "decoder/decoder_itf.h"
|
||||
|
||||
#include "kaldi/decoder/decodable-itf.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
class CTCPrefixBeamSearch : public DecoderInterface {
|
||||
public:
|
||||
explicit CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts);
|
||||
~CTCPrefixBeamSearch() {}
|
||||
|
||||
void InitDecoder();
|
||||
|
||||
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
|
||||
|
||||
std::string GetBestPath();
|
||||
|
||||
std::vector<std::pair<double, std::string>> GetNBestPath();
|
||||
|
||||
std::string GetFinalBestPath();
|
||||
|
||||
int NumFrameDecoded();
|
||||
|
||||
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
|
||||
std::vector<std::string>& nbest_words);
|
||||
|
||||
void AdvanceDecode(
|
||||
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
|
||||
void Reset();
|
||||
|
||||
private:
|
||||
void ResetPrefixes();
|
||||
int32 SearchOneChar(const bool& full_beam,
|
||||
const std::pair<size_t, BaseFloat>& log_prob_idx,
|
||||
const BaseFloat& min_cutoff);
|
||||
void CalculateApproxScore();
|
||||
void LMRescore();
|
||||
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
|
||||
|
||||
CTCBeamSearchOptions opts_;
|
||||
size_t blank_id_;
|
||||
int num_frame_decoded_;
|
||||
DISALLOW_COPY_AND_ASSIGN(CTCPrefixBeamSearch);
|
||||
};
|
||||
|
||||
}  // namespace ppspeech
|
@ -0,0 +1,68 @@
|
||||
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/common.h"
|
||||
#include "utils/math.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
struct PrefxiScore {
|
||||
// decoding, unit in log scale
|
||||
float b = -kFloatMax; // blank ending score
|
||||
float nb = -kFloatMax; // none-blank ending score
|
||||
|
||||
// timestamp, unit in log sclae
|
||||
float v_b = -kFloatMax; // viterbi blank ending score
|
||||
float v_nb = -kFloatMax; // niterbi none-blank ending score
|
||||
float cur_token_prob = -kFloatMax; // prob of current token
|
||||
std::vector<int> times_b; // times of viterbi blank path
|
||||
std::vector<int> times_nb; // times of viterbi non-blank path
|
||||
|
||||
// context state
|
||||
bool has_context = false;
|
||||
int context_state = 0;
|
||||
float context_score = 0;
|
||||
|
||||
// decoding score, sum
|
||||
float Score() const { return LogSumExp(b, nb); }
|
||||
|
||||
// decodign score with context bias
|
||||
float TotalScore() const { return Score() + context_score; }
|
||||
|
||||
// timestamp score, max
|
||||
float ViterbiScore() const { return std::max(v_b, v_nb); }
|
||||
|
||||
// get timestamp
|
||||
const std::vector<int>& Times() const {
|
||||
return v_b > v_nb ? times_b : times_nb;
|
||||
}
|
||||
};
|
||||
|
||||
// Hash functor for a prefix stored as std::vector<int>.
//
// The hash starts from the prefix length and folds in one element at a time
// using the combine step described in
// https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector
// An empty prefix hashes to 0.
struct PrefixScoreHash {
    std::size_t operator()(const std::vector<int>& prefix) const {
        std::size_t hash = prefix.size();
        for (const int token : prefix) {
            // Operand order and types are kept as-is on purpose: the sum of
            // `token` and the constant is evaluated first, before the shifted
            // hash terms are added.
            const std::size_t mixed =
                token + 0x9e3779b9 + (hash << 6) + (hash >> 2);
            hash ^= mixed;
        }
        return hash;
    }
};
|
||||
|
||||
// A prefix (vector of ints) paired with a second element.
// NOTE(review): the second element is PrefixScoreHash, i.e. the hash functor,
// not the score struct. Presumably PrefxiScore was intended here — confirm
// against the users of this alias before changing it.
using PrefixWithScoreType = std::pair<std::vector<int>, PrefixScoreHash>;
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,56 @@
|
||||
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/common.h"
|
||||
#include "kaldi/decoder/decodable-itf.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
class DecoderInterface {
|
||||
public:
|
||||
virtual ~DecoderInterface() {}
|
||||
|
||||
virtual void InitDecoder() = 0;
|
||||
|
||||
virtual void Reset() = 0;
|
||||
|
||||
virtual void AdvanceDecode(
|
||||
const std::shared_ptr<kaldi::DecodableInterface>& decodable) = 0;
|
||||
|
||||
|
||||
virtual std::string GetFinalBestPath() = 0;
|
||||
|
||||
virtual std::string GetPartialResult() = 0;
|
||||
|
||||
// void Decode();
|
||||
|
||||
// std::string GetBestPath();
|
||||
// std::vector<std::pair<double, std::string>> GetNBestPath();
|
||||
|
||||
// int NumFrameDecoded();
|
||||
// int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
|
||||
// std::vector<std::string>& nbest_words);
|
||||
|
||||
|
||||
private:
|
||||
// void AdvanceDecoding(kaldi::DecodableInterface* decodable);
|
||||
|
||||
// current decoding frame number
|
||||
int32 num_frame_decoded_;
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
Loading…
Reference in new issue