fix as comment

pull/2524/head
Hui Zhang 2 years ago
parent 850096a3a0
commit 606e2c237f

@ -37,13 +37,13 @@ struct CTCBeamSearchOptions {
// u2 // u2
int first_beam_size; int first_beam_size;
int second_beam_size; int second_beam_size;
CTCBeamSearchOptions() explicit CTCBeamSearchOptions()
: blank(0), : blank(0),
dict_file("vocab.txt"), dict_file("vocab.txt"),
lm_path(""), lm_path(""),
beam_size(300),
alpha(1.9f), alpha(1.9f),
beta(5.0), beta(5.0),
beam_size(300),
cutoff_prob(0.99f), cutoff_prob(0.99f),
cutoff_top_n(40), cutoff_top_n(40),
num_proc_bsearch(10), num_proc_bsearch(10),

@ -31,7 +31,7 @@ using paddle::platform::TracerEventType;
namespace ppspeech { namespace ppspeech {
CTCPrefixBeamSearch::CTCPrefixBeamSearch(const std::string vocab_path, CTCPrefixBeamSearch::CTCPrefixBeamSearch(const std::string& vocab_path,
const CTCBeamSearchOptions& opts) const CTCBeamSearchOptions& opts)
: opts_(opts) { : opts_(opts) {
unit_table_ = std::shared_ptr<fst::SymbolTable>( unit_table_ = std::shared_ptr<fst::SymbolTable>(
@ -55,10 +55,7 @@ void CTCPrefixBeamSearch::Reset() {
// empty hyp with Score // empty hyp with Score
std::vector<int> empty; std::vector<int> empty;
PrefixScore prefix_score; PrefixScore prefix_score;
prefix_score.b = 0.0f; // log(1) prefix_score.InitEmpty();
prefix_score.nb = -kBaseFloatMax; // log(0)
prefix_score.v_b = 0.0f; // log(1)
prefix_score.v_nb = 0.0f; // log(1)
cur_hyps_[empty] = prefix_score; cur_hyps_[empty] = prefix_score;
outputs_.emplace_back(empty); outputs_.emplace_back(empty);
@ -287,19 +284,7 @@ void CTCPrefixBeamSearch::UpdateOutputs(
int s = 0; int s = 0;
int e = 0; int e = 0;
for (int i = 0; i < input.size(); ++i) { for (int i = 0; i < input.size(); ++i) {
// if (s < start_boundaries.size() && i == start_boundaries[s]){
// // <context>
// output.emplace_back(context_graph_->start_tag_id());
// ++s;
// }
output.emplace_back(input[i]); output.emplace_back(input[i]);
// if (e < end_boundaries.size() && i == end_boundaries[e]){
// // </context>
// output.emplace_back(context_graph_->end_tag_id());
// ++e;
// }
} }
outputs_.emplace_back(output); outputs_.emplace_back(output);

@ -27,7 +27,7 @@ namespace ppspeech {
class ContextGraph; class ContextGraph;
class CTCPrefixBeamSearch : public DecoderBase { class CTCPrefixBeamSearch : public DecoderBase {
public: public:
explicit CTCPrefixBeamSearch(const std::string vocab_path, explicit CTCPrefixBeamSearch(const std::string& vocab_path,
const CTCBeamSearchOptions& opts); const CTCBeamSearchOptions& opts);
~CTCPrefixBeamSearch() {} ~CTCPrefixBeamSearch() {}
@ -77,7 +77,7 @@ class CTCPrefixBeamSearch : public DecoderBase {
private: private:
CTCBeamSearchOptions opts_; CTCBeamSearchOptions opts_;
std::shared_ptr<fst::SymbolTable> unit_table_; std::shared_ptr<fst::SymbolTable> unit_table_{nullptr};
std::unordered_map<std::vector<int>, PrefixScore, PrefixScoreHash> std::unordered_map<std::vector<int>, PrefixScore, PrefixScoreHash>
cur_hyps_; cur_hyps_;
@ -92,7 +92,7 @@ class CTCPrefixBeamSearch : public DecoderBase {
// Outputs contain the hypotheses_ and tags lik: <context> and </context> // Outputs contain the hypotheses_ and tags lik: <context> and </context>
std::vector<std::vector<int>> outputs_; std::vector<std::vector<int>> outputs_;
std::shared_ptr<ContextGraph> context_graph_ = nullptr; std::shared_ptr<ContextGraph> context_graph_{nullptr};
DISALLOW_COPY_AND_ASSIGN(CTCPrefixBeamSearch); DISALLOW_COPY_AND_ASSIGN(CTCPrefixBeamSearch);
}; };

@ -64,12 +64,11 @@ int main(int argc, char* argv[]) {
// nnet // nnet
ppspeech::ModelOptions model_opts; ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path; model_opts.model_path = FLAGS_model_path;
std::shared_ptr<ppspeech::U2Nnet> nnet(new ppspeech::U2Nnet(model_opts)); std::shared_ptr<ppspeech::U2Nnet> nnet = std::make_shared<ppspeech::U2Nnet>(model_opts);
// decodeable // decodeable
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache()); std::shared_ptr<ppspeech::DataCache> raw_data = std::make_shared<ppspeech::DataCache>();
std::shared_ptr<ppspeech::Decodable> decodable( std::shared_ptr<ppspeech::Decodable> decodable = std::make_shared<ppspeech::Decodable>(nnet, raw_data);
new ppspeech::Decodable(nnet, raw_data));
// decoder // decoder
ppspeech::CTCBeamSearchOptions opts; ppspeech::CTCBeamSearchOptions opts;

@ -73,6 +73,13 @@ struct PrefixScore {
int prefix_len) { int prefix_len) {
CHECK(false); CHECK(false);
} }
void InitEmpty() {
b = 0.0f; // log(1)
nb = -kBaseFloatMax; // log(0)
v_b = 0.0f; // log(1)
v_nb = 0.0f; // log(1)
}
}; };
struct PrefixScoreHash { struct PrefixScoreHash {

@ -31,8 +31,8 @@ namespace ppspeech {
struct TLGDecoderOptions { struct TLGDecoderOptions {
kaldi::LatticeFasterDecoderConfig opts{}; kaldi::LatticeFasterDecoderConfig opts{};
// todo remove later, add into decode resource // todo remove later, add into decode resource
std::string word_symbol_table{}; std::string word_symbol_table;
std::string fst_path{}; std::string fst_path;
static TLGDecoderOptions InitFromFlags() { static TLGDecoderOptions InitFromFlags() {
TLGDecoderOptions decoder_opts; TLGDecoderOptions decoder_opts;

Loading…
Cancel
Save