From 04d9db199fca74915eeeb1783b5c7453d209ff58 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Mon, 6 Sep 2021 06:20:29 +0000 Subject: [PATCH] add blank_id parameter --- .../decoders/swig/ctc_beam_search_decoder.cpp | 12 +++++++----- .../decoders/swig/ctc_beam_search_decoder.h | 6 ++++-- .../decoders/swig/ctc_greedy_decoder.cpp | 7 ++++--- deepspeech/decoders/swig/ctc_greedy_decoder.h | 3 ++- deepspeech/decoders/swig/setup.py | 5 ++--- deepspeech/decoders/swig_wrapper.py | 15 +++++++++------ deepspeech/models/ds2/deepspeech2.py | 17 +++++++++++------ deepspeech/models/ds2_online/deepspeech2.py | 18 ++++++++++++------ deepspeech/modules/ctc.py | 5 +++-- examples/aishell/s0/conf/deepspeech2.yaml | 1 + .../aishell/s0/conf/deepspeech2_online.yaml | 13 +++++++------ examples/librispeech/s0/conf/deepspeech2.yaml | 1 + .../s0/conf/deepspeech2_online.yaml | 1 + examples/tiny/s0/conf/deepspeech2.yaml | 1 + examples/tiny/s0/conf/deepspeech2_online.yaml | 1 + 15 files changed, 66 insertions(+), 40 deletions(-) diff --git a/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp b/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp index 4dcc7c89..fcb1f764 100644 --- a/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp +++ b/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp @@ -35,7 +35,8 @@ std::vector> ctc_beam_search_decoder( size_t beam_size, double cutoff_prob, size_t cutoff_top_n, - Scorer *ext_scorer) { + Scorer *ext_scorer, + size_t blank_id) { // dimension check size_t num_time_steps = probs_seq.size(); for (size_t i = 0; i < num_time_steps; ++i) { @@ -48,7 +49,7 @@ std::vector> ctc_beam_search_decoder( // assign blank id // size_t blank_id = vocabulary.size(); - size_t blank_id = 0; + // size_t blank_id = 0; // assign space id auto it = std::find(vocabulary.begin(), vocabulary.end(), " "); @@ -57,7 +58,6 @@ std::vector> ctc_beam_search_decoder( if ((size_t)space_id >= vocabulary.size()) { space_id = -2; } - // init prefixes' root PathTrie root; root.score = root.log_prob_b_prev = 0.0; @@ -218,7 +218,8 @@ ctc_beam_search_decoder_batch( size_t num_processes, double cutoff_prob, size_t cutoff_top_n, - Scorer *ext_scorer) { + Scorer *ext_scorer, + size_t blank_id) { VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); // thread pool ThreadPool pool(num_processes); @@ -234,7 +235,8 @@ ctc_beam_search_decoder_batch( beam_size, cutoff_prob, cutoff_top_n, - ext_scorer)); + ext_scorer, + blank_id)); } // get decoding results diff --git a/deepspeech/decoders/swig/ctc_beam_search_decoder.h b/deepspeech/decoders/swig/ctc_beam_search_decoder.h index c31510da..eaba9da8 100644 --- a/deepspeech/decoders/swig/ctc_beam_search_decoder.h +++ b/deepspeech/decoders/swig/ctc_beam_search_decoder.h @@ -43,7 +43,8 @@ std::vector> ctc_beam_search_decoder( size_t beam_size, double cutoff_prob = 1.0, size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr); + Scorer *ext_scorer = nullptr, + size_t blank_id = 0); /* CTC Beam Search Decoder for batch data @@ -70,6 +71,7 @@ ctc_beam_search_decoder_batch( size_t num_processes, double cutoff_prob = 1.0, size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr); + Scorer *ext_scorer = nullptr, + size_t blank_id = 0); #endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/deepspeech/decoders/swig/ctc_greedy_decoder.cpp b/deepspeech/decoders/swig/ctc_greedy_decoder.cpp index 1c735c42..18008cce 100644 --- a/deepspeech/decoders/swig/ctc_greedy_decoder.cpp +++ b/deepspeech/decoders/swig/ctc_greedy_decoder.cpp @@ -17,17 +17,18 @@ std::string ctc_greedy_decoder( const std::vector> &probs_seq, - const std::vector &vocabulary) { + const std::vector &vocabulary, + size_t blank_id) { // dimension check size_t num_time_steps = probs_seq.size(); for (size_t i = 0; i < num_time_steps; ++i) { VALID_CHECK_EQ(probs_seq[i].size(), - vocabulary.size() + 1, + vocabulary.size(), "The shape of probs_seq does not match with " "the shape of the vocabulary"); } - size_t blank_id = vocabulary.size(); + // size_t blank_id = vocabulary.size(); std::vector max_idx_vec(num_time_steps, 0); std::vector idx_vec; diff --git a/deepspeech/decoders/swig/ctc_greedy_decoder.h b/deepspeech/decoders/swig/ctc_greedy_decoder.h index 5e8c5c25..dd1b3331 100644 --- a/deepspeech/decoders/swig/ctc_greedy_decoder.h +++ b/deepspeech/decoders/swig/ctc_greedy_decoder.h @@ -29,6 +29,7 @@ */ std::string ctc_greedy_decoder( const std::vector>& probs_seq, - const std::vector& vocabulary); + const std::vector& vocabulary, + size_t blank_id); #endif // CTC_GREEDY_DECODER_H diff --git a/deepspeech/decoders/swig/setup.py b/deepspeech/decoders/swig/setup.py index 8fb79296..c089f96c 100644 --- a/deepspeech/decoders/swig/setup.py +++ b/deepspeech/decoders/swig/setup.py @@ -85,9 +85,8 @@ FILES += glob.glob('openfst-1.6.3/src/lib/*.cc') # yapf: disable FILES = [ - fn for fn in FILES - if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith( - 'unittest.cc')) + fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc') + or fn.endswith('unittest.cc')) ] # yapf: enable diff --git a/deepspeech/decoders/swig_wrapper.py b/deepspeech/decoders/swig_wrapper.py index 3ffdb9c7..d883d430 100644 --- a/deepspeech/decoders/swig_wrapper.py +++ b/deepspeech/decoders/swig_wrapper.py @@ -32,7 +32,7 @@ class Scorer(swig_decoders.Scorer): swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary) -def ctc_greedy_decoder(probs_seq, vocabulary): +def ctc_greedy_decoder(probs_seq, vocabulary, blank_id): """Wrapper for ctc best path decoder in swig. :param probs_seq: 2-D list of probability distributions over each time @@ -44,7 +44,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary): :return: Decoding result string. :rtype: str """ - result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary) + result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary, + blank_id) return result @@ -53,7 +54,8 @@ def ctc_beam_search_decoder(probs_seq, beam_size, cutoff_prob=1.0, cutoff_top_n=40, - ext_scoring_func=None): + ext_scoring_func=None, + blank_id=0): """Wrapper for the CTC Beam Search Decoder. :param probs_seq: 2-D list of probability distributions over each time @@ -81,7 +83,7 @@ def ctc_beam_search_decoder(probs_seq, """ beam_results = swig_decoders.ctc_beam_search_decoder( probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n, - ext_scoring_func) + ext_scoring_func, blank_id) beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results] return beam_results @@ -92,7 +94,8 @@ def ctc_beam_search_decoder_batch(probs_split, num_processes, cutoff_prob=1.0, cutoff_top_n=40, - ext_scoring_func=None): + ext_scoring_func=None, + blank_id=0): """Wrapper for the batched CTC beam search decoder. :param probs_seq: 3-D list with each element as an instance of 2-D list @@ -125,7 +128,7 @@ def ctc_beam_search_decoder_batch(probs_split, batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch( probs_split, vocabulary, beam_size, num_processes, cutoff_prob, - cutoff_top_n, ext_scoring_func) + cutoff_top_n, ext_scoring_func, blank_id) batch_beam_results = [[(res[0], res[1]) for res in beam_results] for beam_results in batch_beam_results] return batch_beam_results diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py index 5f8f3255..620d9008 100644 --- a/deepspeech/models/ds2/deepspeech2.py +++ b/deepspeech/models/ds2/deepspeech2.py @@ -141,7 +141,8 @@ class DeepSpeech2Model(nn.Layer): num_rnn_layers=3, rnn_size=1024, use_gru=False, - share_rnn_weights=True): + share_rnn_weights=True, + blank_id=0): super().__init__() self.encoder = CRNNEncoder( feat_size=feat_size, @@ -156,7 +157,7 @@ class DeepSpeech2Model(nn.Layer): self.decoder = CTCDecoder( odim=dict_size, # is in vocab enc_n_units=self.encoder.output_size, - blank_id=0, # first token is + blank_id=blank_id, dropout_rate=0.0, reduction=True, # sum batch_average=True) # sum / batch_size @@ -221,7 +222,8 @@ class DeepSpeech2Model(nn.Layer): num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size, use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) + share_rnn_weights=config.model.share_rnn_weights, + blank_id=config.model.blank_id) infos = Checkpoint().load_parameters( model, checkpoint_path=checkpoint_path) logger.info(f"checkpoint info: {infos}") @@ -246,7 +248,8 @@ class DeepSpeech2Model(nn.Layer): num_rnn_layers=config.num_rnn_layers, rnn_size=config.rnn_layer_size, use_gru=config.use_gru, - share_rnn_weights=config.share_rnn_weights) + share_rnn_weights=config.share_rnn_weights, + blank_id=config.blank_id) return model @@ -258,7 +261,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model): num_rnn_layers=3, rnn_size=1024, use_gru=False, - share_rnn_weights=True): + share_rnn_weights=True, + blank_id=0): super().__init__( feat_size=feat_size, dict_size=dict_size, @@ -266,7 +270,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model): num_rnn_layers=num_rnn_layers, rnn_size=rnn_size, use_gru=use_gru, - share_rnn_weights=share_rnn_weights) + share_rnn_weights=share_rnn_weights, + blank_id=blank_id) def forward(self, audio, audio_len): """export model function diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index f597a578..f049f415 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -254,6 +254,7 @@ class DeepSpeech2ModelOnline(nn.Layer): num_fc_layers=2, fc_layers_size_list=[512, 256], use_gru=True, #Use gru if set True. Use simple rnn if set False. + blank_id=0, # index of blank in vocob.txt )) if config is not None: config.merge_from_other_cfg(default) @@ -268,7 +269,8 @@ class DeepSpeech2ModelOnline(nn.Layer): rnn_direction='forward', num_fc_layers=2, fc_layers_size_list=[512, 256], - use_gru=False): + use_gru=False, + blank_id=0): super().__init__() self.encoder = CRNNEncoder( feat_size=feat_size, @@ -284,7 +286,7 @@ class DeepSpeech2ModelOnline(nn.Layer): self.decoder = CTCDecoder( odim=dict_size, # is in vocab enc_n_units=self.encoder.output_size, - blank_id=0, # first token is + blank_id=blank_id, dropout_rate=0.0, reduction=True, # sum batch_average=True) # sum / batch_size @@ -353,7 +355,8 @@ class DeepSpeech2ModelOnline(nn.Layer): rnn_direction=config.model.rnn_direction, num_fc_layers=config.model.num_fc_layers, fc_layers_size_list=config.model.fc_layers_size_list, - use_gru=config.model.use_gru) + use_gru=config.model.use_gru, + blank_id=config.model.blank_id) infos = Checkpoint().load_parameters( model, checkpoint_path=checkpoint_path) logger.info(f"checkpoint info: {infos}") @@ -380,7 +383,8 @@ class DeepSpeech2ModelOnline(nn.Layer): rnn_direction=config.rnn_direction, num_fc_layers=config.num_fc_layers, fc_layers_size_list=config.fc_layers_size_list, - use_gru=config.use_gru) + use_gru=config.use_gru, + blank_id=config.blank_id) return model @@ -394,7 +398,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): rnn_direction='forward', num_fc_layers=2, fc_layers_size_list=[512, 256], - use_gru=False): + use_gru=False, + blank_id=0): super().__init__( feat_size=feat_size, dict_size=dict_size, @@ -404,7 +409,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): rnn_direction=rnn_direction, num_fc_layers=num_fc_layers, fc_layers_size_list=fc_layers_size_list, - use_gru=use_gru) + use_gru=use_gru, + blank_id=blank_id) def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box): diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py index 10c04638..c330caf1 100644 --- a/deepspeech/modules/ctc.py +++ b/deepspeech/modules/ctc.py @@ -136,7 +136,7 @@ class CTCDecoder(nn.Layer): results = [] for i, probs in enumerate(probs_split): output_transcription = ctc_greedy_decoder( - probs_seq=probs, vocabulary=vocab_list) + probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id) results.append(output_transcription) return results @@ -216,7 +216,8 @@ class CTCDecoder(nn.Layer): num_processes=num_processes, ext_scoring_func=self._ext_scorer, cutoff_prob=cutoff_prob, - cutoff_top_n=cutoff_top_n) + cutoff_top_n=cutoff_top_n, + blank_id=self.blank_id) results = [result[0][1] for result in beam_search_results] return results diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 7f0a1462..0f465a8f 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -40,6 +40,7 @@ model: rnn_layer_size: 1024 use_gru: True share_rnn_weights: False + blank_id: 0 training: n_epoch: 80 diff --git a/examples/aishell/s0/conf/deepspeech2_online.yaml b/examples/aishell/s0/conf/deepspeech2_online.yaml index fdc3a536..9f05d8dd 100644 --- a/examples/aishell/s0/conf/deepspeech2_online.yaml +++ b/examples/aishell/s0/conf/deepspeech2_online.yaml @@ -36,17 +36,18 @@ collator: model: num_conv_layers: 2 - num_rnn_layers: 3 + num_rnn_layers: 5 rnn_layer_size: 1024 rnn_direction: forward # [forward, bidirect] - num_fc_layers: 1 - fc_layers_size_list: 512, + num_fc_layers: 0 + fc_layers_size_list: -1, use_gru: False - + blank_id: 0 + training: n_epoch: 50 lr: 2e-3 - lr_decay: 0.91 # 0.83 + lr_decay: 0.9 # 0.83 weight_decay: 1e-06 global_grad_clip: 3.0 log_interval: 100 @@ -59,7 +60,7 @@ decoding: error_rate_type: cer decoding_method: ctc_beam_search lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm - alpha: 1.9 + alpha: 2.2 #1.9 beta: 5.0 beam_size: 300 cutoff_prob: 0.99 diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index dab8d046..031d684d 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -40,6 +40,7 @@ model: rnn_layer_size: 2048 use_gru: False share_rnn_weights: True + blank_id: 0 training: n_epoch: 50 diff --git a/examples/librispeech/s0/conf/deepspeech2_online.yaml b/examples/librispeech/s0/conf/deepspeech2_online.yaml index 2e4aed40..27f59f3f 100644 --- a/examples/librispeech/s0/conf/deepspeech2_online.yaml +++ b/examples/librispeech/s0/conf/deepspeech2_online.yaml @@ -42,6 +42,7 @@ model: num_fc_layers: 2 fc_layers_size_list: 512, 256 use_gru: False + blank_id: 0 training: n_epoch: 50 diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml index ab9a00d9..aeb33f58 100644 --- a/examples/tiny/s0/conf/deepspeech2.yaml +++ b/examples/tiny/s0/conf/deepspeech2.yaml @@ -41,6 +41,7 @@ model: rnn_layer_size: 2048 use_gru: False share_rnn_weights: True + blank_id: 0 training: n_epoch: 10 diff --git a/examples/tiny/s0/conf/deepspeech2_online.yaml b/examples/tiny/s0/conf/deepspeech2_online.yaml index 333c2b9a..4205a04a 100644 --- a/examples/tiny/s0/conf/deepspeech2_online.yaml +++ b/examples/tiny/s0/conf/deepspeech2_online.yaml @@ -43,6 +43,7 @@ model: num_fc_layers: 2 fc_layers_size_list: 512, 256 use_gru: True + blank_id: 0 training: n_epoch: 10