add blank_id parameter

pull/809/head
huangyuxin 3 years ago
parent 48438066be
commit 04d9db199f

@ -35,7 +35,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
size_t beam_size, size_t beam_size,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) { Scorer *ext_scorer,
size_t blank_id) {
// dimension check // dimension check
size_t num_time_steps = probs_seq.size(); size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) { for (size_t i = 0; i < num_time_steps; ++i) {
@ -48,7 +49,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
// assign blank id // assign blank id
// size_t blank_id = vocabulary.size(); // size_t blank_id = vocabulary.size();
size_t blank_id = 0; // size_t blank_id = 0;
// assign space id // assign space id
auto it = std::find(vocabulary.begin(), vocabulary.end(), " "); auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
@ -57,7 +58,6 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
if ((size_t)space_id >= vocabulary.size()) { if ((size_t)space_id >= vocabulary.size()) {
space_id = -2; space_id = -2;
} }
// init prefixes' root // init prefixes' root
PathTrie root; PathTrie root;
root.score = root.log_prob_b_prev = 0.0; root.score = root.log_prob_b_prev = 0.0;
@ -218,7 +218,8 @@ ctc_beam_search_decoder_batch(
size_t num_processes, size_t num_processes,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) { Scorer *ext_scorer,
size_t blank_id) {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
// thread pool // thread pool
ThreadPool pool(num_processes); ThreadPool pool(num_processes);
@ -234,7 +235,8 @@ ctc_beam_search_decoder_batch(
beam_size, beam_size,
cutoff_prob, cutoff_prob,
cutoff_top_n, cutoff_top_n,
ext_scorer)); ext_scorer,
blank_id));
} }
// get decoding results // get decoding results

@ -43,7 +43,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
size_t beam_size, size_t beam_size,
double cutoff_prob = 1.0, double cutoff_prob = 1.0,
size_t cutoff_top_n = 40, size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr); Scorer *ext_scorer = nullptr,
size_t blank_id = 0);
/* CTC Beam Search Decoder for batch data /* CTC Beam Search Decoder for batch data
@ -70,6 +71,7 @@ ctc_beam_search_decoder_batch(
size_t num_processes, size_t num_processes,
double cutoff_prob = 1.0, double cutoff_prob = 1.0,
size_t cutoff_top_n = 40, size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr); Scorer *ext_scorer = nullptr,
size_t blank_id = 0);
#endif // CTC_BEAM_SEARCH_DECODER_H_ #endif // CTC_BEAM_SEARCH_DECODER_H_

@ -17,17 +17,18 @@
std::string ctc_greedy_decoder( std::string ctc_greedy_decoder(
const std::vector<std::vector<double>> &probs_seq, const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary) { const std::vector<std::string> &vocabulary,
size_t blank_id) {
// dimension check // dimension check
size_t num_time_steps = probs_seq.size(); size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) { for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(), VALID_CHECK_EQ(probs_seq[i].size(),
vocabulary.size() + 1, vocabulary.size(),
"The shape of probs_seq does not match with " "The shape of probs_seq does not match with "
"the shape of the vocabulary"); "the shape of the vocabulary");
} }
size_t blank_id = vocabulary.size(); // size_t blank_id = vocabulary.size();
std::vector<size_t> max_idx_vec(num_time_steps, 0); std::vector<size_t> max_idx_vec(num_time_steps, 0);
std::vector<size_t> idx_vec; std::vector<size_t> idx_vec;

@ -29,6 +29,7 @@
*/ */
std::string ctc_greedy_decoder( std::string ctc_greedy_decoder(
const std::vector<std::vector<double>>& probs_seq, const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::string>& vocabulary); const std::vector<std::string>& vocabulary,
size_t blank_id);
#endif // CTC_GREEDY_DECODER_H #endif // CTC_GREEDY_DECODER_H

@ -85,9 +85,8 @@ FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
# yapf: disable # yapf: disable
FILES = [ FILES = [
fn for fn in FILES fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')
if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith( or fn.endswith('unittest.cc'))
'unittest.cc'))
] ]
# yapf: enable # yapf: enable

@ -32,7 +32,7 @@ class Scorer(swig_decoders.Scorer):
swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary) swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
def ctc_greedy_decoder(probs_seq, vocabulary): def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
"""Wrapper for ctc best path decoder in swig. """Wrapper for ctc best path decoder in swig.
:param probs_seq: 2-D list of probability distributions over each time :param probs_seq: 2-D list of probability distributions over each time
@ -44,7 +44,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
:return: Decoding result string. :return: Decoding result string.
:rtype: str :rtype: str
""" """
result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary) result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary,
blank_id)
return result return result
@ -53,7 +54,8 @@ def ctc_beam_search_decoder(probs_seq,
beam_size, beam_size,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
ext_scoring_func=None): ext_scoring_func=None,
blank_id=0):
"""Wrapper for the CTC Beam Search Decoder. """Wrapper for the CTC Beam Search Decoder.
:param probs_seq: 2-D list of probability distributions over each time :param probs_seq: 2-D list of probability distributions over each time
@ -81,7 +83,7 @@ def ctc_beam_search_decoder(probs_seq,
""" """
beam_results = swig_decoders.ctc_beam_search_decoder( beam_results = swig_decoders.ctc_beam_search_decoder(
probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n, probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
ext_scoring_func) ext_scoring_func, blank_id)
beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results] beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
return beam_results return beam_results
@ -92,7 +94,8 @@ def ctc_beam_search_decoder_batch(probs_split,
num_processes, num_processes,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
ext_scoring_func=None): ext_scoring_func=None,
blank_id=0):
"""Wrapper for the batched CTC beam search decoder. """Wrapper for the batched CTC beam search decoder.
:param probs_seq: 3-D list with each element as an instance of 2-D list :param probs_seq: 3-D list with each element as an instance of 2-D list
@ -125,7 +128,7 @@ def ctc_beam_search_decoder_batch(probs_split,
batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch( batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch(
probs_split, vocabulary, beam_size, num_processes, cutoff_prob, probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
cutoff_top_n, ext_scoring_func) cutoff_top_n, ext_scoring_func, blank_id)
batch_beam_results = [[(res[0], res[1]) for res in beam_results] batch_beam_results = [[(res[0], res[1]) for res in beam_results]
for beam_results in batch_beam_results] for beam_results in batch_beam_results]
return batch_beam_results return batch_beam_results

@ -141,7 +141,8 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=3, num_rnn_layers=3,
rnn_size=1024, rnn_size=1024,
use_gru=False, use_gru=False,
share_rnn_weights=True): share_rnn_weights=True,
blank_id=0):
super().__init__() super().__init__()
self.encoder = CRNNEncoder( self.encoder = CRNNEncoder(
feat_size=feat_size, feat_size=feat_size,
@ -156,7 +157,7 @@ class DeepSpeech2Model(nn.Layer):
self.decoder = CTCDecoder( self.decoder = CTCDecoder(
odim=dict_size, # <blank> is in vocab odim=dict_size, # <blank> is in vocab
enc_n_units=self.encoder.output_size, enc_n_units=self.encoder.output_size,
blank_id=0, # first token is <blank> blank_id=blank_id,
dropout_rate=0.0, dropout_rate=0.0,
reduction=True, # sum reduction=True, # sum
batch_average=True) # sum / batch_size batch_average=True) # sum / batch_size
@ -221,7 +222,8 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=config.model.num_rnn_layers, num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru, use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights) share_rnn_weights=config.model.share_rnn_weights,
blank_id=config.model.blank_id)
infos = Checkpoint().load_parameters( infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path) model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}") logger.info(f"checkpoint info: {infos}")
@ -246,7 +248,8 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=config.num_rnn_layers, num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size, rnn_size=config.rnn_layer_size,
use_gru=config.use_gru, use_gru=config.use_gru,
share_rnn_weights=config.share_rnn_weights) share_rnn_weights=config.share_rnn_weights,
blank_id=config.blank_id)
return model return model
@ -258,7 +261,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
num_rnn_layers=3, num_rnn_layers=3,
rnn_size=1024, rnn_size=1024,
use_gru=False, use_gru=False,
share_rnn_weights=True): share_rnn_weights=True,
blank_id=0):
super().__init__( super().__init__(
feat_size=feat_size, feat_size=feat_size,
dict_size=dict_size, dict_size=dict_size,
@ -266,7 +270,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
num_rnn_layers=num_rnn_layers, num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size, rnn_size=rnn_size,
use_gru=use_gru, use_gru=use_gru,
share_rnn_weights=share_rnn_weights) share_rnn_weights=share_rnn_weights,
blank_id=blank_id)
def forward(self, audio, audio_len): def forward(self, audio, audio_len):
"""export model function """export model function

@ -254,6 +254,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
num_fc_layers=2, num_fc_layers=2,
fc_layers_size_list=[512, 256], fc_layers_size_list=[512, 256],
use_gru=True, #Use gru if set True. Use simple rnn if set False. use_gru=True, #Use gru if set True. Use simple rnn if set False.
blank_id=0, # index of blank in vocob.txt
)) ))
if config is not None: if config is not None:
config.merge_from_other_cfg(default) config.merge_from_other_cfg(default)
@ -268,7 +269,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
rnn_direction='forward', rnn_direction='forward',
num_fc_layers=2, num_fc_layers=2,
fc_layers_size_list=[512, 256], fc_layers_size_list=[512, 256],
use_gru=False): use_gru=False,
blank_id=0):
super().__init__() super().__init__()
self.encoder = CRNNEncoder( self.encoder = CRNNEncoder(
feat_size=feat_size, feat_size=feat_size,
@ -284,7 +286,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
self.decoder = CTCDecoder( self.decoder = CTCDecoder(
odim=dict_size, # <blank> is in vocab odim=dict_size, # <blank> is in vocab
enc_n_units=self.encoder.output_size, enc_n_units=self.encoder.output_size,
blank_id=0, # first token is <blank> blank_id=blank_id,
dropout_rate=0.0, dropout_rate=0.0,
reduction=True, # sum reduction=True, # sum
batch_average=True) # sum / batch_size batch_average=True) # sum / batch_size
@ -353,7 +355,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
rnn_direction=config.model.rnn_direction, rnn_direction=config.model.rnn_direction,
num_fc_layers=config.model.num_fc_layers, num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list, fc_layers_size_list=config.model.fc_layers_size_list,
use_gru=config.model.use_gru) use_gru=config.model.use_gru,
blank_id=config.model.blank_id)
infos = Checkpoint().load_parameters( infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path) model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}") logger.info(f"checkpoint info: {infos}")
@ -380,7 +383,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
rnn_direction=config.rnn_direction, rnn_direction=config.rnn_direction,
num_fc_layers=config.num_fc_layers, num_fc_layers=config.num_fc_layers,
fc_layers_size_list=config.fc_layers_size_list, fc_layers_size_list=config.fc_layers_size_list,
use_gru=config.use_gru) use_gru=config.use_gru,
blank_id=config.blank_id)
return model return model
@ -394,7 +398,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
rnn_direction='forward', rnn_direction='forward',
num_fc_layers=2, num_fc_layers=2,
fc_layers_size_list=[512, 256], fc_layers_size_list=[512, 256],
use_gru=False): use_gru=False,
blank_id=0):
super().__init__( super().__init__(
feat_size=feat_size, feat_size=feat_size,
dict_size=dict_size, dict_size=dict_size,
@ -404,7 +409,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
rnn_direction=rnn_direction, rnn_direction=rnn_direction,
num_fc_layers=num_fc_layers, num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list, fc_layers_size_list=fc_layers_size_list,
use_gru=use_gru) use_gru=use_gru,
blank_id=blank_id)
def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box, def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
chunk_state_c_box): chunk_state_c_box):

@ -136,7 +136,7 @@ class CTCDecoder(nn.Layer):
results = [] results = []
for i, probs in enumerate(probs_split): for i, probs in enumerate(probs_split):
output_transcription = ctc_greedy_decoder( output_transcription = ctc_greedy_decoder(
probs_seq=probs, vocabulary=vocab_list) probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id)
results.append(output_transcription) results.append(output_transcription)
return results return results
@ -216,7 +216,8 @@ class CTCDecoder(nn.Layer):
num_processes=num_processes, num_processes=num_processes,
ext_scoring_func=self._ext_scorer, ext_scoring_func=self._ext_scorer,
cutoff_prob=cutoff_prob, cutoff_prob=cutoff_prob,
cutoff_top_n=cutoff_top_n) cutoff_top_n=cutoff_top_n,
blank_id=self.blank_id)
results = [result[0][1] for result in beam_search_results] results = [result[0][1] for result in beam_search_results]
return results return results

@ -40,6 +40,7 @@ model:
rnn_layer_size: 1024 rnn_layer_size: 1024
use_gru: True use_gru: True
share_rnn_weights: False share_rnn_weights: False
blank_id: 0
training: training:
n_epoch: 80 n_epoch: 80

@ -36,17 +36,18 @@ collator:
model: model:
num_conv_layers: 2 num_conv_layers: 2
num_rnn_layers: 3 num_rnn_layers: 5
rnn_layer_size: 1024 rnn_layer_size: 1024
rnn_direction: forward # [forward, bidirect] rnn_direction: forward # [forward, bidirect]
num_fc_layers: 1 num_fc_layers: 0
fc_layers_size_list: 512, fc_layers_size_list: -1,
use_gru: False use_gru: False
blank_id: 0
training: training:
n_epoch: 50 n_epoch: 50
lr: 2e-3 lr: 2e-3
lr_decay: 0.91 # 0.83 lr_decay: 0.9 # 0.83
weight_decay: 1e-06 weight_decay: 1e-06
global_grad_clip: 3.0 global_grad_clip: 3.0
log_interval: 100 log_interval: 100
@ -59,7 +60,7 @@ decoding:
error_rate_type: cer error_rate_type: cer
decoding_method: ctc_beam_search decoding_method: ctc_beam_search
lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
alpha: 1.9 alpha: 2.2 #1.9
beta: 5.0 beta: 5.0
beam_size: 300 beam_size: 300
cutoff_prob: 0.99 cutoff_prob: 0.99

@ -40,6 +40,7 @@ model:
rnn_layer_size: 2048 rnn_layer_size: 2048
use_gru: False use_gru: False
share_rnn_weights: True share_rnn_weights: True
blank_id: 0
training: training:
n_epoch: 50 n_epoch: 50

@ -42,6 +42,7 @@ model:
num_fc_layers: 2 num_fc_layers: 2
fc_layers_size_list: 512, 256 fc_layers_size_list: 512, 256
use_gru: False use_gru: False
blank_id: 0
training: training:
n_epoch: 50 n_epoch: 50

@ -41,6 +41,7 @@ model:
rnn_layer_size: 2048 rnn_layer_size: 2048
use_gru: False use_gru: False
share_rnn_weights: True share_rnn_weights: True
blank_id: 0
training: training:
n_epoch: 10 n_epoch: 10

@ -43,6 +43,7 @@ model:
num_fc_layers: 2 num_fc_layers: 2
fc_layers_size_list: 512, 256 fc_layers_size_list: 512, 256
use_gru: True use_gru: True
blank_id: 0
training: training:
n_epoch: 10 n_epoch: 10

Loading…
Cancel
Save