expose param cutoff_top_n

pull/2/head
Yibing Liu 7 years ago
parent 1cf61a15f4
commit 7e093ed1a3

@@ -22,8 +22,6 @@ class TextFeaturizer(object):
     def __init__(self, vocab_filepath):
         self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
             vocab_filepath)
-        # from unicode to string
-        self._vocab_list = [chars.encode("utf-8") for chars in self._vocab_list]
 
     def featurize(self, text):
         """Convert text string to a list of token indices in char-level.Note

@@ -42,8 +42,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
 def ctc_beam_search_decoder(probs_seq,
                             beam_size,
                             vocabulary,
-                            blank_id,
                             cutoff_prob=1.0,
+                            cutoff_top_n=40,
                             ext_scoring_func=None,
                             nproc=False):
     """CTC Beam search decoder.
@@ -66,8 +66,6 @@ def ctc_beam_search_decoder(probs_seq,
     :type beam_size: int
     :param vocabulary: Vocabulary list.
     :type vocabulary: list
-    :param blank_id: ID of blank.
-    :type blank_id: int
     :param cutoff_prob: Cutoff probability in pruning,
                         default 1.0, no pruning.
     :type cutoff_prob: float
@@ -87,9 +85,8 @@ def ctc_beam_search_decoder(probs_seq,
         raise ValueError("The shape of prob_seq does not match with the "
                          "shape of the vocabulary.")
 
-    # blank_id check
-    if not blank_id < len(probs_seq[0]):
-        raise ValueError("blank_id shouldn't be greater than probs dimension")
+    # blank_id assign
+    blank_id = len(vocabulary)
 
     # If the decoder called in the multiprocesses, then use the global scorer
     # instantiated in ctc_beam_search_decoder_batch().
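
Instead of validating a caller-supplied blank_id, the decoder now derives it: the convention assumed is that each row of probs_seq has len(vocabulary) + 1 entries, with the blank occupying the final index. A toy illustration of that layout (numbers are made up):

    vocabulary = [u"a", u"b", u"c"]
    blank_id = len(vocabulary)  # blank is the trailing column, index 3 here

    # one timestep of posteriors: 3 characters plus 1 blank, summing to 1.0
    probs_row = [0.2, 0.1, 0.3, 0.4]
    assert len(probs_row) == len(vocabulary) + 1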
@@ -114,7 +111,7 @@ def ctc_beam_search_decoder(probs_seq,
         prob_idx = list(enumerate(probs_seq[time_step]))
         cutoff_len = len(prob_idx)
         #If pruning is enabled
-        if cutoff_prob < 1.0:
+        if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len:
             prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
             cutoff_len, cum_prob = 0, 0.0
             for i in xrange(len(prob_idx)):
@@ -122,6 +119,7 @@ def ctc_beam_search_decoder(probs_seq,
                 cutoff_len += 1
                 if cum_prob >= cutoff_prob:
                     break
+            cutoff_len = min(cutoff_len, cutoff_top_n)
             prob_idx = prob_idx[0:cutoff_len]
 
         for l in prefix_set_prev:
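
Taken together, the two hunks above implement a double cutoff: when either knob is active, candidates are sorted by probability, trimmed to the smallest prefix whose mass reaches cutoff_prob, and then capped at cutoff_top_n. A self-contained sketch of that rule (prune_step_probs is a hypothetical helper, data is toy):

    def prune_step_probs(step_probs, cutoff_prob=1.0, cutoff_top_n=40):
        # returns the (index, prob) pairs kept for one timestep
        prob_idx = list(enumerate(step_probs))
        cutoff_len = len(prob_idx)
        if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len:
            prob_idx = sorted(prob_idx, key=lambda pair: pair[1], reverse=True)
            cutoff_len, cum_prob = 0, 0.0
            for _, prob in prob_idx:
                cum_prob += prob
                cutoff_len += 1
                if cum_prob >= cutoff_prob:
                    break
            cutoff_len = min(cutoff_len, cutoff_top_n)
        return prob_idx[0:cutoff_len]

    # keeps the two most probable entries, since 0.5 + 0.3 >= 0.8
    print(prune_step_probs([0.5, 0.1, 0.3, 0.1], cutoff_prob=0.8))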
@@ -191,9 +189,9 @@ def ctc_beam_search_decoder(probs_seq,
 def ctc_beam_search_decoder_batch(probs_split,
                                   beam_size,
                                   vocabulary,
-                                  blank_id,
                                   num_processes,
                                   cutoff_prob=1.0,
+                                  cutoff_top_n=40,
                                   ext_scoring_func=None):
     """CTC beam search decoder using multiple processes.
@@ -204,8 +202,6 @@ def ctc_beam_search_decoder_batch(probs_split,
     :type beam_size: int
     :param vocabulary: Vocabulary list.
     :type vocabulary: list
-    :param blank_id: ID of blank.
-    :type blank_id: int
     :param num_processes: Number of parallel processes.
     :type num_processes: int
     :param cutoff_prob: Cutoff probability in pruning,
@@ -232,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split,
     pool = multiprocessing.Pool(processes=num_processes)
     results = []
     for i, probs_list in enumerate(probs_split):
-        args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
-                nproc)
+        args = (probs_list, beam_size, vocabulary, cutoff_prob, cutoff_top_n,
+                None, nproc)
         results.append(pool.apply_async(ctc_beam_search_decoder, args))
     pool.close()
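
After this change a batch call drops blank_id and may set the new knob. A hedged usage sketch with toy inputs (the import path is an assumption; adjust to wherever the decoder module lives in the repo):

    # assumed import; the module path may differ in your checkout
    # from decoder import ctc_beam_search_decoder_batch

    vocab_list = [c.encode("utf-8") for c in [u"a", u"b", u"c"]]
    # two toy utterances, each T x (len(vocab_list) + 1), blank in the last column
    probs_split = [
        [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]],
        [[0.25, 0.25, 0.25, 0.25]],
    ]
    results = ctc_beam_search_decoder_batch(
        probs_split=probs_split,
        beam_size=20,
        vocabulary=vocab_list,
        num_processes=2,
        cutoff_prob=1.0,        # 1.0 disables the probability-mass cutoff
        cutoff_top_n=40,        # still caps candidates per timestep at 40
        ext_scoring_func=None)  # or a scorer instance for LM rescoring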

@@ -8,7 +8,7 @@ import kenlm
 import numpy as np
 
 
-class LmScorer(object):
+class Scorer(object):
     """External scorer to evaluate a prefix or whole sentence in
     beam search decoding, including the score from n-gram language
     model and word count.
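
The LmScorer to Scorer rename ripples to every caller. Assuming the constructor keeps its (alpha, beta, language-model path) argument order, instantiation now reads roughly as below (weights and path are placeholders):

    # before: ext_scorer = LmScorer(alpha, beta, language_model_path)
    # after, same arguments under the new name:
    ext_scorer = Scorer(2.15, 0.35, "path/to/kenlm_language_model.binary")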

@@ -128,7 +128,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
 
     // pruning of vacobulary
     size_t cutoff_len = prob.size();
-    if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) {
+    if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
       std::sort(
           prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
       if (cutoff_prob < 1.0) {

@@ -24,6 +24,7 @@ python -u infer.py \
 --alpha=2.15 \
 --beta=0.35 \
 --cutoff_prob=1.0 \
+--cutoff_top_n=40 \
 --use_gru=False \
 --use_gpu=True \
 --share_rnn_weights=True \

@@ -33,6 +33,7 @@ python -u infer.py \
 --alpha=2.15 \
 --beta=0.35 \
 --cutoff_prob=1.0 \
+--cutoff_top_n=40 \
 --use_gru=False \
 --use_gpu=True \
 --share_rnn_weights=True \

@@ -34,6 +34,7 @@ python -u test.py \
 --alpha=2.15 \
 --beta=0.35 \
 --cutoff_prob=1.0 \
+--cutoff_top_n=40 \
 --use_gru=False \
 --use_gpu=True \
 --share_rnn_weights=True \

@@ -23,7 +23,8 @@ add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
 add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
 add_arg('alpha', float, 2.15, "Coef of LM for beam search.")
 add_arg('beta', float, 0.35, "Coef of WC for beam search.")
 add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
+add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
 add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
 add_arg('use_gpu', bool, True, "Use GPU or not.")
 add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
@@ -85,6 +86,9 @@ def infer():
         pretrained_model_path=args.model_path,
         share_rnn_weights=args.share_rnn_weights)
 
+    # decoders only accept string encoded in utf-8
+    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
+
     result_transcripts = ds2_model.infer_batch(
         infer_data=infer_data,
         decoding_method=args.decoding_method,
@@ -92,7 +96,8 @@ def infer():
         beam_beta=args.beta,
         beam_size=args.beam_size,
         cutoff_prob=args.cutoff_prob,
-        vocab_list=data_generator.vocab_list,
+        cutoff_top_n=args.cutoff_top_n,
+        vocab_list=vocab_list,
         language_model_path=args.lang_model_path,
         num_processes=args.num_proc_bsearch)

@@ -148,8 +148,8 @@ class DeepSpeech2Model(object):
         return self._loss_inferer.infer(input=infer_data)
 
     def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
-                    beam_size, cutoff_prob, vocab_list, language_model_path,
-                    num_processes):
+                    beam_size, cutoff_prob, cutoff_top_n, vocab_list,
+                    language_model_path, num_processes):
         """Model inference. Infer the transcription for a batch of speech
         utterances.
@@ -169,6 +169,10 @@ class DeepSpeech2Model(object):
         :param cutoff_prob: Cutoff probability in pruning,
                             default 1.0, no pruning.
         :type cutoff_prob: float
+        :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
+                             characters with highest probs in vocabulary will be
+                             used in beam search, default 40.
+        :type cutoff_top_n: int
         :param vocab_list: List of tokens in the vocabulary, for decoding.
         :type vocab_list: list
         :param language_model_path: Filepath for language model.
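
One way to read the default documented above: with cutoff_prob left at 1.0, cutoff_top_n=40 only bites when the probability row is wider than 40 entries, so it is a no-op for an English character vocabulary but prunes hard for a large Mandarin one. A back-of-the-envelope check (vocabulary sizes are illustrative):

    cutoff_top_n = 40
    for vocab_size in (29, 4300):  # illustrative English / Mandarin sizes
        kept = min(vocab_size + 1, cutoff_top_n)  # +1 for the blank column
        print(vocab_size, "->", kept, "candidates per timestep")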
@@ -216,7 +220,8 @@ class DeepSpeech2Model(object):
                 beam_size=beam_size,
                 num_processes=num_processes,
                 ext_scoring_func=self._ext_scorer,
-                cutoff_prob=cutoff_prob)
+                cutoff_prob=cutoff_prob,
+                cutoff_top_n=cutoff_top_n)
             results = [result[0][1] for result in beam_search_results]
         else:

@@ -24,7 +24,8 @@ add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
 add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
 add_arg('alpha', float, 2.15, "Coef of LM for beam search.")
 add_arg('beta', float, 0.35, "Coef of WC for beam search.")
 add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
+add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
 add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
 add_arg('use_gpu', bool, True, "Use GPU or not.")
 add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
@@ -85,6 +86,9 @@ def evaluate():
         pretrained_model_path=args.model_path,
         share_rnn_weights=args.share_rnn_weights)
 
+    # decoders only accept string encoded in utf-8
+    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
+
     error_rate_func = cer if args.error_rate_type == 'cer' else wer
     error_sum, num_ins = 0.0, 0
     for infer_data in batch_reader():
@@ -95,7 +99,8 @@ def evaluate():
             beam_beta=args.beta,
             beam_size=args.beam_size,
             cutoff_prob=args.cutoff_prob,
-            vocab_list=data_generator.vocab_list,
+            cutoff_top_n=args.cutoff_top_n,
+            vocab_list=vocab_list,
             language_model_path=args.lang_model_path,
             num_processes=args.num_proc_bsearch)
         target_transcripts = [
