parent
5c8725e8cd
commit
bc1b6c2e7c
@ -0,0 +1,78 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "base/common.h"
|
||||||
|
#include "util/parse-options.h"
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
struct CTCBeamSearchOptions {
|
||||||
|
// common
|
||||||
|
int blank;
|
||||||
|
|
||||||
|
// ds2
|
||||||
|
std::string dict_file;
|
||||||
|
std::string lm_path;
|
||||||
|
int beam_size;
|
||||||
|
BaseFloat alpha;
|
||||||
|
BaseFloat beta;
|
||||||
|
BaseFloat cutoff_prob;
|
||||||
|
int cutoff_top_n;
|
||||||
|
int num_proc_bsearch;
|
||||||
|
|
||||||
|
// u2
|
||||||
|
int first_beam_size;
|
||||||
|
int second_beam_size;
|
||||||
|
CTCBeamSearchOptions()
|
||||||
|
: blank(0),
|
||||||
|
dict_file("vocab.txt"),
|
||||||
|
lm_path(""),
|
||||||
|
alpha(1.9f),
|
||||||
|
beta(5.0),
|
||||||
|
beam_size(300),
|
||||||
|
cutoff_prob(0.99f),
|
||||||
|
cutoff_top_n(40),
|
||||||
|
num_proc_bsearch(10),
|
||||||
|
first_beam_size(10),
|
||||||
|
second_beam_size(10) {}
|
||||||
|
|
||||||
|
void Register(kaldi::OptionsItf* opts) {
|
||||||
|
std::string module = "Ds2BeamSearchConfig: ";
|
||||||
|
opts->Register("dict", &dict_file, module + "vocab file path.");
|
||||||
|
opts->Register(
|
||||||
|
"lm-path", &lm_path, module + "ngram language model path.");
|
||||||
|
opts->Register("alpha", &alpha, module + "alpha");
|
||||||
|
opts->Register("beta", &beta, module + "beta");
|
||||||
|
opts->Register("beam-size",
|
||||||
|
&beam_size,
|
||||||
|
module + "beam size for beam search method");
|
||||||
|
opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs");
|
||||||
|
opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff top n");
|
||||||
|
opts->Register(
|
||||||
|
"num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch");
|
||||||
|
|
||||||
|
opts->Register("blank", &blank, "blank id, default is 0.");
|
||||||
|
|
||||||
|
module = "U2BeamSearchConfig: ";
|
||||||
|
opts->Register(
|
||||||
|
"first-beam-size", &first_beam_size, module + "first beam size.");
|
||||||
|
opts->Register("second-beam-size",
|
||||||
|
&second_beam_size,
|
||||||
|
module + "second beam size.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,13 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
@ -0,0 +1,64 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "decoder/ctc_beam_search_opt.h"
|
||||||
|
#include "decoder/ctc_prefix_beam_search_score.h"
|
||||||
|
#include "decoder/decoder_itf.h"
|
||||||
|
|
||||||
|
#include "kaldi/decoder/decodable-itf.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
class CTCPrefixBeamSearch : public DecoderInterface {
|
||||||
|
public:
|
||||||
|
explicit CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts);
|
||||||
|
~CTCPrefixBeamSearch() {}
|
||||||
|
|
||||||
|
void InitDecoder();
|
||||||
|
|
||||||
|
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
|
||||||
|
|
||||||
|
std::string GetBestPath();
|
||||||
|
|
||||||
|
std::vector<std::pair<double, std::string>> GetNBestPath();
|
||||||
|
|
||||||
|
std::string GetFinalBestPath();
|
||||||
|
|
||||||
|
int NumFrameDecoded();
|
||||||
|
|
||||||
|
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
|
||||||
|
std::vector<std::string>& nbest_words);
|
||||||
|
|
||||||
|
void AdvanceDecode(
|
||||||
|
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
|
||||||
|
void Reset();
|
||||||
|
|
||||||
|
private:
|
||||||
|
void ResetPrefixes();
|
||||||
|
int32 SearchOneChar(const bool& full_beam,
|
||||||
|
const std::pair<size_t, BaseFloat>& log_prob_idx,
|
||||||
|
const BaseFloat& min_cutoff);
|
||||||
|
void CalculateApproxScore();
|
||||||
|
void LMRescore();
|
||||||
|
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
|
||||||
|
|
||||||
|
CTCBeamSearchOptions opts_;
|
||||||
|
size_t blank_id_;
|
||||||
|
int num_frame_decoded_;
|
||||||
|
DISALLOW_COPY_AND_ASSIGN(CTCPrefixBeamSearch);
|
||||||
|
};
|
||||||
|
|
||||||
|
}  // namespace ppspeech
|
@ -0,0 +1,68 @@
|
|||||||
|
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "base/common.h"
|
||||||
|
#include "utils/math.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
struct PrefxiScore {
|
||||||
|
// decoding, unit in log scale
|
||||||
|
float b = -kFloatMax; // blank ending score
|
||||||
|
float nb = -kFloatMax; // none-blank ending score
|
||||||
|
|
||||||
|
// timestamp, unit in log sclae
|
||||||
|
float v_b = -kFloatMax; // viterbi blank ending score
|
||||||
|
float v_nb = -kFloatMax; // niterbi none-blank ending score
|
||||||
|
float cur_token_prob = -kFloatMax; // prob of current token
|
||||||
|
std::vector<int> times_b; // times of viterbi blank path
|
||||||
|
std::vector<int> times_nb; // times of viterbi non-blank path
|
||||||
|
|
||||||
|
// context state
|
||||||
|
bool has_context = false;
|
||||||
|
int context_state = 0;
|
||||||
|
float context_score = 0;
|
||||||
|
|
||||||
|
// decoding score, sum
|
||||||
|
float Score() const { return LogSumExp(b, nb); }
|
||||||
|
|
||||||
|
// decodign score with context bias
|
||||||
|
float TotalScore() const { return Score() + context_score; }
|
||||||
|
|
||||||
|
// timestamp score, max
|
||||||
|
float ViterbiScore() const { return std::max(v_b, v_nb); }
|
||||||
|
|
||||||
|
// get timestamp
|
||||||
|
const std::vector<int>& Times() const {
|
||||||
|
return v_b > v_nb ? times_b : times_nb;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Hash functor for a prefix stored as std::vector<int>.
struct PrefixScoreHash {
    // Mixing scheme taken from:
    // https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector
    //
    // The seed starts as the prefix length and each element is folded in, in
    // order. An empty prefix therefore hashes to 0.
    std::size_t operator()(const std::vector<int>& prefix) const {
        std::size_t seed = prefix.size();
        for (std::size_t pos = 0; pos < prefix.size(); ++pos) {
            const int token = prefix[pos];
            // `token + 0x9e3779b9` is evaluated in unsigned int arithmetic
            // before it is widened and combined with the shifted seed.
            const std::size_t mixed =
                token + 0x9e3779b9 + (seed << 6) + (seed >> 2);
            seed = seed ^ mixed;
        }
        return seed;
    }
};
|
||||||
|
|
||||||
|
// Pair of a prefix (std::vector<int>) and a PrefixScoreHash.
// NOTE(review): going by the alias name, the second element was presumably
// meant to be the per-prefix score struct (PrefxiScore) rather than the hash
// functor — confirm intent before relying on this alias.
using PrefixWithScoreType = std::pair<std::vector<int>, PrefixScoreHash>;
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
@ -0,0 +1,56 @@
|
|||||||
|
|
||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "base/common.h"
|
||||||
|
#include "kaldi/decoder/decodable-itf.h"
|
||||||
|
|
||||||
|
namespace ppspeech {
|
||||||
|
|
||||||
|
class DecoderInterface {
|
||||||
|
public:
|
||||||
|
virtual ~DecoderInterface() {}
|
||||||
|
|
||||||
|
virtual void InitDecoder() = 0;
|
||||||
|
|
||||||
|
virtual void Reset() = 0;
|
||||||
|
|
||||||
|
virtual void AdvanceDecode(
|
||||||
|
const std::shared_ptr<kaldi::DecodableInterface>& decodable) = 0;
|
||||||
|
|
||||||
|
|
||||||
|
virtual std::string GetFinalBestPath() = 0;
|
||||||
|
|
||||||
|
virtual std::string GetPartialResult() = 0;
|
||||||
|
|
||||||
|
// void Decode();
|
||||||
|
|
||||||
|
// std::string GetBestPath();
|
||||||
|
// std::vector<std::pair<double, std::string>> GetNBestPath();
|
||||||
|
|
||||||
|
// int NumFrameDecoded();
|
||||||
|
// int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
|
||||||
|
// std::vector<std::string>& nbest_words);
|
||||||
|
|
||||||
|
|
||||||
|
private:
|
||||||
|
// void AdvanceDecoding(kaldi::DecodableInterface* decodable);
|
||||||
|
|
||||||
|
// current decoding frame number
|
||||||
|
int32 num_frame_decoded_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ppspeech
|
Loading…
Reference in new issue