From 3ecf1ad4f5f2d61d31184c84be7ccfffc385611f Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Fri, 12 Jan 2018 21:33:00 +0800 Subject: [PATCH] Decouple ext scorer init & inference & decoding for the convenience of tuning --- examples/librispeech/run_tune.sh | 2 +- infer.py | 32 ++++--- model_utils/model.py | 158 +++++++++++++++++-------------- test.py | 31 ++++-- tools/tune.py | 104 +++++--------------- 5 files changed, 153 insertions(+), 174 deletions(-) diff --git a/examples/librispeech/run_tune.sh b/examples/librispeech/run_tune.sh index c3695d1c..9fc9cbb9 100644 --- a/examples/librispeech/run_tune.sh +++ b/examples/librispeech/run_tune.sh @@ -7,7 +7,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 \ python -u tools/tune.py \ --num_batches=-1 \ --batch_size=128 \ ---trainer_count=8 \ +--trainer_count=4 \ --beam_size=500 \ --num_proc_bsearch=12 \ --num_conv_layers=2 \ diff --git a/infer.py b/infer.py index b801c507..1539fbaa 100644 --- a/infer.py +++ b/infer.py @@ -90,18 +90,26 @@ def infer(): # decoders only accept string encoded in utf-8 vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] - result_transcripts = ds2_model.infer_batch( - infer_data=infer_data, - decoding_method=args.decoding_method, - beam_alpha=args.alpha, - beam_beta=args.beta, - beam_size=args.beam_size, - cutoff_prob=args.cutoff_prob, - cutoff_top_n=args.cutoff_top_n, - vocab_list=vocab_list, - language_model_path=args.lang_model_path, - num_processes=args.num_proc_bsearch, - feeding_dict=data_generator.feeding) + probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, + feeding_dict=data_generator.feeding) + if args.decoding_method == "ctc_greedy": + ds2_model.logger.info("start inference ...") + result_transcripts = ds2_model.infer_batch_greedy( + probs_split=probs_split, + vocab_list=vocab_list) + else: + ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, + vocab_list) + ds2_model.logger.info("start inference ...") + result_transcripts = ds2_model.infer_batch_beam_search( + probs_split=probs_split, + beam_alpha=args.alpha, + beam_beta=args.beta, + beam_size=args.beam_size, + cutoff_prob=args.cutoff_prob, + cutoff_top_n=args.cutoff_top_n, + vocab_list=vocab_list, + num_processes=args.num_proc_bsearch) error_rate_func = cer if args.error_rate_type == 'cer' else wer target_transcripts = [data[1] for data in infer_data] diff --git a/model_utils/model.py b/model_utils/model.py index 85d50053..f6d3ef05 100644 --- a/model_utils/model.py +++ b/model_utils/model.py @@ -173,43 +173,19 @@ class DeepSpeech2Model(object): # run inference return self._loss_inferer.infer(input=infer_data) - def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta, - beam_size, cutoff_prob, cutoff_top_n, vocab_list, - language_model_path, num_processes, feeding_dict): - """Model inference. Infer the transcription for a batch of speech - utterances. + def infer_probs_batch(self, infer_data, feeding_dict): + """Infer the prob matrices for a batch of speech utterances. :param infer_data: List of utterances to infer, with each utterance consisting of a tuple of audio features and transcription text (empty string). :type infer_data: list - :param decoding_method: Decoding method name, 'ctc_greedy' or - 'ctc_beam_search'. - :param decoding_method: string - :param beam_alpha: Parameter associated with language model. - :type beam_alpha: float - :param beam_beta: Parameter associated with word count. - :type beam_beta: float - :param beam_size: Width for Beam search. - :type beam_size: int - :param cutoff_prob: Cutoff probability in pruning, - default 1.0, no pruning. - :type cutoff_prob: float - :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n - characters with highest probs in vocabulary will be - used in beam search, default 40. - :type cutoff_top_n: int - :param vocab_list: List of tokens in the vocabulary, for decoding. - :type vocab_list: list - :param language_model_path: Filepath for language model. - :type language_model_path: basestring|None - :param num_processes: Number of processes (CPU) for decoder. - :type num_processes: int :param feeding_dict: Feeding is a map of field name and tuple index of the data that reader returns. :type feeding_dict: dict|list - :return: List of transcription texts. - :rtype: List of basestring + :return: List of 2-D probability matrix, and each consists of prob + vectors for one speech utterancce. + :rtype: List of matrix """ # define inferer if self._inferer == None: @@ -227,49 +203,91 @@ class DeepSpeech2Model(object): infer_results[start_pos[i]:start_pos[i + 1]] for i in xrange(0, len(adapted_infer_data)) ] - # run decoder + return probs_split + + def infer_batch_greedy(self, probs_split, vocab_list): + """ + :param probs_split: List of 2-D probability matrix, and each consists + of prob vectors for one speech utterancce. + :param probs_split: List of matrix + :param vocab_list: List of tokens in the vocabulary, for decoding. + :type vocab_list: list + :return: List of transcription texts. + :rtype: List of basestring + """ results = [] - if decoding_method == "ctc_greedy": - # best path decode - for i, probs in enumerate(probs_split): - output_transcription = ctc_greedy_decoder( - probs_seq=probs, vocabulary=vocab_list) - results.append(output_transcription) - elif decoding_method == "ctc_beam_search": - # initialize external scorer - if self._ext_scorer == None: - self._loaded_lm_path = language_model_path - self.logger.info("begin to initialize the external scorer " - "for decoding") - self._ext_scorer = Scorer(beam_alpha, beam_beta, - language_model_path, vocab_list) - - lm_char_based = self._ext_scorer.is_character_based() - lm_max_order = self._ext_scorer.get_max_order() - lm_dict_size = self._ext_scorer.get_dict_size() - self.logger.info("language model: " - "is_character_based = %d," % lm_char_based + - " max_order = %d," % lm_max_order + - " dict_size = %d" % lm_dict_size) - self.logger.info("end initializing scorer. Start decoding ...") - else: - self._ext_scorer.reset_params(beam_alpha, beam_beta) - assert self._loaded_lm_path == language_model_path - # beam search decode - num_processes = min(num_processes, len(probs_split)) - beam_search_results = ctc_beam_search_decoder_batch( - probs_split=probs_split, - vocabulary=vocab_list, - beam_size=beam_size, - num_processes=num_processes, - ext_scoring_func=self._ext_scorer, - cutoff_prob=cutoff_prob, - cutoff_top_n=cutoff_top_n) - - results = [result[0][1] for result in beam_search_results] + for i, probs in enumerate(probs_split): + output_transcription = ctc_greedy_decoder( + probs_seq=probs, vocabulary=vocab_list) + results.append(output_transcription) + return results + + def init_ext_scorer(self, beam_alpha, beam_beta, language_model_path, + vocab_list): + """Initialize the external scorer. + + """ + if language_model_path != '': + self.logger.info("begin to initialize the external scorer " + "for decoding") + self._ext_scorer = Scorer(beam_alpha, beam_beta, + language_model_path, vocab_list) + lm_char_based = self._ext_scorer.is_character_based() + lm_max_order = self._ext_scorer.get_max_order() + lm_dict_size = self._ext_scorer.get_dict_size() + self.logger.info("language model: " + "is_character_based = %d," % lm_char_based + + " max_order = %d," % lm_max_order + + " dict_size = %d" % lm_dict_size) + self.logger.info("end initializing scorer") else: - raise ValueError("Decoding method [%s] is not supported." % - decoding_method) + self._ext_scorer = None + self.logger.info("no language model provided, " + "decoding by pure beam search without scorer.") + + def infer_batch_beam_search(self, probs_split, beam_alpha, beam_beta, + beam_size, cutoff_prob, cutoff_top_n, + vocab_list, num_processes): + """Model inference. Infer the transcription for a batch of speech + utterances. + + :param probs_split: List of 2-D probability matrix, and each consists + of prob vectors for one speech utterancce. + :param probs_split: List of matrix + :param beam_alpha: Parameter associated with language model. + :type beam_alpha: float + :param beam_beta: Parameter associated with word count. + :type beam_beta: float + :param beam_size: Width for Beam search. + :type beam_size: int + :param cutoff_prob: Cutoff probability in pruning, + default 1.0, no pruning. + :type cutoff_prob: float + :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n + characters with highest probs in vocabulary will be + used in beam search, default 40. + :type cutoff_top_n: int + :param vocab_list: List of tokens in the vocabulary, for decoding. + :type vocab_list: list + :param num_processes: Number of processes (CPU) for decoder. + :type num_processes: int + :return: List of transcription texts. + :rtype: List of basestring + """ + if self._ext_scorer != None: + self._ext_scorer.reset_params(beam_alpha, beam_beta) + # beam search decode + num_processes = min(num_processes, len(probs_split)) + beam_search_results = ctc_beam_search_decoder_batch( + probs_split=probs_split, + vocabulary=vocab_list, + beam_size=beam_size, + num_processes=num_processes, + ext_scoring_func=self._ext_scorer, + cutoff_prob=cutoff_prob, + cutoff_top_n=cutoff_top_n) + + results = [result[0][1] for result in beam_search_results] return results def _adapt_feeding_dict(self, feeding_dict): diff --git a/test.py b/test.py index 5cf76648..24ce54a2 100644 --- a/test.py +++ b/test.py @@ -90,22 +90,33 @@ def evaluate(): # decoders only accept string encoded in utf-8 vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] + if args.decoding_method == "ctc_beam_search": + ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, + vocab_list) errors_func = char_errors if args.error_rate_type == 'cer' else word_errors errors_sum, len_refs, num_ins = 0.0, 0, 0 + ds2_model.logger.info("start evaluation ...") for infer_data in batch_reader(): - result_transcripts = ds2_model.infer_batch( + probs_split = ds2_model.infer_probs_batch( infer_data=infer_data, - decoding_method=args.decoding_method, - beam_alpha=args.alpha, - beam_beta=args.beta, - beam_size=args.beam_size, - cutoff_prob=args.cutoff_prob, - cutoff_top_n=args.cutoff_top_n, - vocab_list=vocab_list, - language_model_path=args.lang_model_path, - num_processes=args.num_proc_bsearch, feeding_dict=data_generator.feeding) + + if args.decoding_method == "ctc_greedy": + result_transcripts = ds2_model.infer_batch_greedy( + probs_split=probs_split, + vocab_list=vocab_list) + else: + result_transcripts = ds2_model.infer_batch_beam_search( + probs_split=probs_split, + beam_alpha=args.alpha, + beam_beta=args.beta, + beam_size=args.beam_size, + cutoff_prob=args.cutoff_prob, + cutoff_top_n=args.cutoff_top_n, + vocab_list=vocab_list, + num_processes=args.num_proc_bsearch) target_transcripts = [data[1] for data in infer_data] + for target, result in zip(target_transcripts, result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors diff --git a/tools/tune.py b/tools/tune.py index b1323319..83978be8 100644 --- a/tools/tune.py +++ b/tools/tune.py @@ -13,9 +13,7 @@ import logging import paddle.v2 as paddle import _init_paths from data_utils.data import DataGenerator -from decoders.swig_wrapper import Scorer -from decoders.swig_wrapper import ctc_beam_search_decoder_batch -from model_utils.model import deep_speech_v2_network +from model_utils.model import DeepSpeech2Model from utils.error_rate import char_errors, word_errors from utils.utility import add_arguments, print_arguments @@ -88,40 +86,7 @@ def tune(): augmentation_config='{}', specgram_type=args.specgram_type, num_threads=args.num_proc_data, - keep_transcription_text=True, - num_conv_layers=args.num_conv_layers) - - audio_data = paddle.layer.data( - name="audio_spectrogram", - type=paddle.data_type.dense_array(161 * 161)) - text_data = paddle.layer.data( - name="transcript_text", - type=paddle.data_type.integer_value_sequence(data_generator.vocab_size)) - seq_offset_data = paddle.layer.data( - name='sequence_offset', - type=paddle.data_type.integer_value_sequence(1)) - seq_len_data = paddle.layer.data( - name='sequence_length', - type=paddle.data_type.integer_value_sequence(1)) - index_range_datas = [] - for i in xrange(args.num_rnn_layers): - index_range_datas.append( - paddle.layer.data( - name='conv%d_index_range' % i, - type=paddle.data_type.dense_vector(6))) - - output_probs, _ = deep_speech_v2_network( - audio_data=audio_data, - text_data=text_data, - seq_offset_data=seq_offset_data, - seq_len_data=seq_len_data, - index_range_datas=index_range_datas, - dict_size=data_generator.vocab_size, - num_conv_layers=args.num_conv_layers, - num_rnn_layers=args.num_rnn_layers, - rnn_size=args.rnn_layer_size, - use_gru=args.use_gru, - share_rnn_weights=args.share_rnn_weights) + keep_transcription_text=True) batch_reader = data_generator.batch_reader_creator( manifest_path=args.tune_manifest, @@ -129,35 +94,17 @@ def tune(): sortagrad=False, shuffle_method=None) - # load parameters - if not os.path.isfile(args.model_path): - raise IOError("Invaid model path: %s" % args.model_path) - parameters = paddle.parameters.Parameters.from_tar( - gzip.open(args.model_path)) + ds2_model = DeepSpeech2Model( + vocab_size=data_generator.vocab_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_layer_size=args.rnn_layer_size, + use_gru=args.use_gru, + pretrained_model_path=args.model_path, + share_rnn_weights=args.share_rnn_weights) - inferer = paddle.inference.Inference( - output_layer=output_probs, parameters=parameters) # decoders only accept string encoded in utf-8 vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] - - # init logger - logger = logging.getLogger("") - logger.setLevel(level=logging.INFO) - # init external scorer - logger.info("begin to initialize the external scorer for tuning") - if not os.path.isfile(args.lang_model_path): - raise IOError("Invaid language model path: %s" % args.lang_model_path) - ext_scorer = Scorer( - alpha=args.alpha_from, - beta=args.beta_from, - model_path=args.lang_model_path, - vocabulary=vocab_list) - logger.info("language model: " - "is_character_based = %d," % ext_scorer.is_character_based() + - " max_order = %d," % ext_scorer.get_max_order() + - " dict_size = %d" % ext_scorer.get_dict_size()) - logger.info("end initializing scorer. Start tuning ...") - errors_func = char_errors if args.error_rate_type == 'cer' else word_errors # create grid for search cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas) @@ -168,37 +115,32 @@ def tune(): err_sum = [0.0 for i in xrange(len(params_grid))] err_ave = [0.0 for i in xrange(len(params_grid))] num_ins, len_refs, cur_batch = 0, 0, 0 + # initialize external scorer + ds2_model.init_ext_scorer(args.alpha_from, args.beta_from, + args.lang_model_path, vocab_list) ## incremental tuning parameters over multiple batches + ds2_model.logger.info("start tuning ...") for infer_data in batch_reader(): if (args.num_batches >= 0) and (cur_batch >= args.num_batches): break - infer_results = inferer.infer(input=infer_data, - feeding=data_generator.feeding) - start_pos = [0] * (len(infer_data) + 1) - for i in xrange(len(infer_data)): - start_pos[i + 1] = start_pos[i] + infer_data[i][3][0] - probs_split = [ - infer_results[start_pos[i]:start_pos[i + 1]] - for i in xrange(0, len(infer_data)) - ] - + probs_split = ds2_model.infer_probs_batch( + infer_data=infer_data, + feeding_dict=data_generator.feeding) target_transcripts = [ data[1] for data in infer_data ] num_ins += len(target_transcripts) # grid search for index, (alpha, beta) in enumerate(params_grid): - # reset alpha & beta - ext_scorer.reset_params(alpha, beta) - beam_search_results = ctc_beam_search_decoder_batch( + result_transcripts = ds2_model.infer_batch_beam_search( probs_split=probs_split, - vocabulary=vocab_list, + beam_alpha=alpha, + beam_beta=beta, beam_size=args.beam_size, - num_processes=args.num_proc_bsearch, cutoff_prob=args.cutoff_prob, cutoff_top_n=args.cutoff_top_n, - ext_scoring_func=ext_scorer, ) + vocab_list=vocab_list, + num_processes=args.num_proc_bsearch) - result_transcripts = [res[0][1] for res in beam_search_results] for target, result in zip(target_transcripts, result_transcripts): errors, len_ref = errors_func(target, result) err_sum[index] += errors @@ -235,7 +177,7 @@ def tune(): % (cur_batch, "%.3f" % params_grid[min_index][0], "%.3f" % params_grid[min_index][1])) - logger.info("finish tuning") + ds2_model.logger.info("finish tuning") def main():