Decouple ext scorer init & inference & decoding for the convenience of tuning

pull/122/head
Yibing Liu 7 years ago
parent f896f0ef05
commit 3ecf1ad4f5
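
This commit splits the old all-in-one DeepSpeech2Model.infer_batch into three independent steps: external scorer initialization (init_ext_scorer), the acoustic forward pass (infer_probs_batch), and decoding (infer_batch_greedy or infer_batch_beam_search). The probability matrices can then be computed once and re-decoded under different settings. A minimal sketch of the new call sequence, assuming an already constructed ds2_model and data_generator, with bare names like alpha and beam_size standing in for the scripts' CLI arguments:

    # set up the external scorer once (only needed for beam search with an LM)
    ds2_model.init_ext_scorer(alpha, beta, lang_model_path, vocab_list)
    # run the acoustic model once: one CTC probability matrix per utterance
    probs_split = ds2_model.infer_probs_batch(
        infer_data=infer_data, feeding_dict=data_generator.feeding)
    # decoding only; can be repeated on the cached probs_split with new settings
    transcripts = ds2_model.infer_batch_beam_search(
        probs_split=probs_split, beam_alpha=alpha, beam_beta=beta,
        beam_size=beam_size, cutoff_prob=cutoff_prob,
        cutoff_top_n=cutoff_top_n, vocab_list=vocab_list,
        num_processes=num_proc_bsearch)

The main beneficiary is tools/tune.py, which previously had to rebuild the network, load parameters, and drive the scorer by hand; the hunks below delete that duplication.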

@@ -7,7 +7,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 \
 python -u tools/tune.py \
 --num_batches=-1 \
 --batch_size=128 \
---trainer_count=8 \
+--trainer_count=4 \
 --beam_size=500 \
 --num_proc_bsearch=12 \
 --num_conv_layers=2 \

@@ -90,18 +90,26 @@ def infer():
     # decoders only accept string encoded in utf-8
     vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
-    result_transcripts = ds2_model.infer_batch(
-        infer_data=infer_data,
-        decoding_method=args.decoding_method,
-        beam_alpha=args.alpha,
-        beam_beta=args.beta,
-        beam_size=args.beam_size,
-        cutoff_prob=args.cutoff_prob,
-        cutoff_top_n=args.cutoff_top_n,
-        vocab_list=vocab_list,
-        language_model_path=args.lang_model_path,
-        num_processes=args.num_proc_bsearch,
-        feeding_dict=data_generator.feeding)
+    probs_split = ds2_model.infer_probs_batch(infer_data=infer_data,
+                                              feeding_dict=data_generator.feeding)
+    if args.decoding_method == "ctc_greedy":
+        ds2_model.logger.info("start inference ...")
+        result_transcripts = ds2_model.infer_batch_greedy(
+            probs_split=probs_split,
+            vocab_list=vocab_list)
+    else:
+        ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
+                                  vocab_list)
+        ds2_model.logger.info("start inference ...")
+        result_transcripts = ds2_model.infer_batch_beam_search(
+            probs_split=probs_split,
+            beam_alpha=args.alpha,
+            beam_beta=args.beta,
+            beam_size=args.beam_size,
+            cutoff_prob=args.cutoff_prob,
+            cutoff_top_n=args.cutoff_top_n,
+            vocab_list=vocab_list,
+            num_processes=args.num_proc_bsearch)
     error_rate_func = cer if args.error_rate_type == 'cer' else wer
     target_transcripts = [data[1] for data in infer_data]

@@ -173,43 +173,19 @@ class DeepSpeech2Model(object):
         # run inference
         return self._loss_inferer.infer(input=infer_data)
 
-    def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
-                    beam_size, cutoff_prob, cutoff_top_n, vocab_list,
-                    language_model_path, num_processes, feeding_dict):
-        """Model inference. Infer the transcription for a batch of speech
-        utterances.
+    def infer_probs_batch(self, infer_data, feeding_dict):
+        """Infer the prob matrices for a batch of speech utterances.
 
         :param infer_data: List of utterances to infer, with each utterance
                            consisting of a tuple of audio features and
                            transcription text (empty string).
         :type infer_data: list
-        :param decoding_method: Decoding method name, 'ctc_greedy' or
-                                'ctc_beam_search'.
-        :param decoding_method: string
-        :param beam_alpha: Parameter associated with language model.
-        :type beam_alpha: float
-        :param beam_beta: Parameter associated with word count.
-        :type beam_beta: float
-        :param beam_size: Width for Beam search.
-        :type beam_size: int
-        :param cutoff_prob: Cutoff probability in pruning,
-                            default 1.0, no pruning.
-        :type cutoff_prob: float
-        :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
-                             characters with highest probs in vocabulary will be
-                             used in beam search, default 40.
-        :type cutoff_top_n: int
-        :param vocab_list: List of tokens in the vocabulary, for decoding.
-        :type vocab_list: list
-        :param language_model_path: Filepath for language model.
-        :type language_model_path: basestring|None
-        :param num_processes: Number of processes (CPU) for decoder.
-        :type num_processes: int
         :param feeding_dict: Feeding is a map of field name and tuple index
                              of the data that reader returns.
         :type feeding_dict: dict|list
-        :return: List of transcription texts.
-        :rtype: List of basestring
+        :return: List of 2-D probability matrices, each consisting of the prob
+                 vectors for one speech utterance.
+        :rtype: List of matrix
         """
         # define inferer
         if self._inferer == None:
@@ -227,23 +203,35 @@ class DeepSpeech2Model(object):
             infer_results[start_pos[i]:start_pos[i + 1]]
             for i in xrange(0, len(adapted_infer_data))
         ]
-        # run decoder
-        results = []
-        if decoding_method == "ctc_greedy":
-            # best path decode
-            for i, probs in enumerate(probs_split):
-                output_transcription = ctc_greedy_decoder(
-                    probs_seq=probs, vocabulary=vocab_list)
-                results.append(output_transcription)
-        elif decoding_method == "ctc_beam_search":
-            # initialize external scorer
-            if self._ext_scorer == None:
-                self._loaded_lm_path = language_model_path
-                self.logger.info("begin to initialize the external scorer "
-                                 "for decoding")
-                self._ext_scorer = Scorer(beam_alpha, beam_beta,
-                                          language_model_path, vocab_list)
-                lm_char_based = self._ext_scorer.is_character_based()
-                lm_max_order = self._ext_scorer.get_max_order()
-                lm_dict_size = self._ext_scorer.get_dict_size()
+        return probs_split
+
+    def infer_batch_greedy(self, probs_split, vocab_list):
+        """Decode a batch of probability matrices by best path (greedy) search.
+
+        :param probs_split: List of 2-D probability matrices, each consisting
+                            of the prob vectors for one speech utterance.
+        :type probs_split: List of matrix
+        :param vocab_list: List of tokens in the vocabulary, for decoding.
+        :type vocab_list: list
+        :return: List of transcription texts.
+        :rtype: List of basestring
+        """
+        results = []
+        for i, probs in enumerate(probs_split):
+            output_transcription = ctc_greedy_decoder(
+                probs_seq=probs, vocabulary=vocab_list)
+            results.append(output_transcription)
+        return results
+
+    def init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
+                        vocab_list):
+        """Initialize the external scorer."""
+        if language_model_path != '':
+            self.logger.info("begin to initialize the external scorer "
+                             "for decoding")
+            self._ext_scorer = Scorer(beam_alpha, beam_beta,
+                                      language_model_path, vocab_list)
+            lm_char_based = self._ext_scorer.is_character_based()
+            lm_max_order = self._ext_scorer.get_max_order()
+            lm_dict_size = self._ext_scorer.get_dict_size()
@@ -251,10 +239,43 @@ class DeepSpeech2Model(object):
                              "is_character_based = %d," % lm_char_based +
                              " max_order = %d," % lm_max_order +
                              " dict_size = %d" % lm_dict_size)
-                self.logger.info("end initializing scorer. Start decoding ...")
-            else:
-                self._ext_scorer.reset_params(beam_alpha, beam_beta)
-                assert self._loaded_lm_path == language_model_path
+            self.logger.info("end initializing scorer")
+        else:
+            self._ext_scorer = None
+            self.logger.info("no language model provided, "
+                             "decoding by pure beam search without scorer.")
+
+    def infer_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
+                                beam_size, cutoff_prob, cutoff_top_n,
+                                vocab_list, num_processes):
+        """Model inference. Infer the transcription for a batch of speech
+        utterances.
+
+        :param probs_split: List of 2-D probability matrices, each consisting
+                            of the prob vectors for one speech utterance.
+        :type probs_split: List of matrix
+        :param beam_alpha: Parameter associated with language model.
+        :type beam_alpha: float
+        :param beam_beta: Parameter associated with word count.
+        :type beam_beta: float
+        :param beam_size: Width for beam search.
+        :type beam_size: int
+        :param cutoff_prob: Cutoff probability in pruning,
+                            default 1.0, no pruning.
+        :type cutoff_prob: float
+        :param cutoff_top_n: Cutoff number in pruning; only the top cutoff_top_n
+                             characters with the highest probs in the vocabulary
+                             will be used in beam search, default 40.
+        :type cutoff_top_n: int
+        :param vocab_list: List of tokens in the vocabulary, for decoding.
+        :type vocab_list: list
+        :param num_processes: Number of processes (CPU) for decoder.
+        :type num_processes: int
+        :return: List of transcription texts.
+        :rtype: List of basestring
+        """
+        if self._ext_scorer != None:
+            self._ext_scorer.reset_params(beam_alpha, beam_beta)
         # beam search decode
         num_processes = min(num_processes, len(probs_split))
         beam_search_results = ctc_beam_search_decoder_batch(
@@ -267,9 +288,6 @@ class DeepSpeech2Model(object):
             cutoff_top_n=cutoff_top_n)
         results = [result[0][1] for result in beam_search_results]
-        else:
-            raise ValueError("Decoding method [%s] is not supported." %
-                             decoding_method)
         return results
 
     def _adapt_feeding_dict(self, feeding_dict):
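
For reference, each element of probs_split returned by infer_probs_batch above is a 2-D matrix of per-timestep CTC posteriors for one utterance, roughly one row per time step and one column per vocabulary token plus the blank. A quick sketch of the scorer-free greedy path over such a batch, with the same placeholder inputs as before:

    probs_split = ds2_model.infer_probs_batch(
        infer_data=infer_data, feeding_dict=data_generator.feeding)
    transcripts = ds2_model.infer_batch_greedy(
        probs_split=probs_split, vocab_list=vocab_list)

Since the scorer now lives on the model, a subsequent infer_batch_beam_search call only resets alpha and beta through reset_params instead of reloading the language model.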

@@ -90,22 +90,33 @@ def evaluate():
     # decoders only accept string encoded in utf-8
     vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
+    if args.decoding_method == "ctc_beam_search":
+        ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
+                                  vocab_list)
     errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
     errors_sum, len_refs, num_ins = 0.0, 0, 0
+    ds2_model.logger.info("start evaluation ...")
     for infer_data in batch_reader():
-        result_transcripts = ds2_model.infer_batch(
-            infer_data=infer_data,
-            decoding_method=args.decoding_method,
-            beam_alpha=args.alpha,
-            beam_beta=args.beta,
-            beam_size=args.beam_size,
-            cutoff_prob=args.cutoff_prob,
-            cutoff_top_n=args.cutoff_top_n,
-            vocab_list=vocab_list,
-            language_model_path=args.lang_model_path,
-            num_processes=args.num_proc_bsearch,
-            feeding_dict=data_generator.feeding)
+        probs_split = ds2_model.infer_probs_batch(
+            infer_data=infer_data,
+            feeding_dict=data_generator.feeding)
+        if args.decoding_method == "ctc_greedy":
+            result_transcripts = ds2_model.infer_batch_greedy(
+                probs_split=probs_split,
+                vocab_list=vocab_list)
+        else:
+            result_transcripts = ds2_model.infer_batch_beam_search(
+                probs_split=probs_split,
+                beam_alpha=args.alpha,
+                beam_beta=args.beta,
+                beam_size=args.beam_size,
+                cutoff_prob=args.cutoff_prob,
+                cutoff_top_n=args.cutoff_top_n,
+                vocab_list=vocab_list,
+                num_processes=args.num_proc_bsearch)
         target_transcripts = [data[1] for data in infer_data]
         for target, result in zip(target_transcripts, result_transcripts):
             errors, len_ref = errors_func(target, result)
             errors_sum += errors

@@ -13,9 +13,7 @@ import logging
 import paddle.v2 as paddle
 import _init_paths
 from data_utils.data import DataGenerator
-from decoders.swig_wrapper import Scorer
-from decoders.swig_wrapper import ctc_beam_search_decoder_batch
-from model_utils.model import deep_speech_v2_network
+from model_utils.model import DeepSpeech2Model
 from utils.error_rate import char_errors, word_errors
 from utils.utility import add_arguments, print_arguments
@@ -88,40 +86,7 @@ def tune():
         augmentation_config='{}',
         specgram_type=args.specgram_type,
         num_threads=args.num_proc_data,
-        keep_transcription_text=True,
-        num_conv_layers=args.num_conv_layers)
-
-    audio_data = paddle.layer.data(
-        name="audio_spectrogram",
-        type=paddle.data_type.dense_array(161 * 161))
-    text_data = paddle.layer.data(
-        name="transcript_text",
-        type=paddle.data_type.integer_value_sequence(data_generator.vocab_size))
-    seq_offset_data = paddle.layer.data(
-        name='sequence_offset',
-        type=paddle.data_type.integer_value_sequence(1))
-    seq_len_data = paddle.layer.data(
-        name='sequence_length',
-        type=paddle.data_type.integer_value_sequence(1))
-    index_range_datas = []
-    for i in xrange(args.num_rnn_layers):
-        index_range_datas.append(
-            paddle.layer.data(
-                name='conv%d_index_range' % i,
-                type=paddle.data_type.dense_vector(6)))
-
-    output_probs, _ = deep_speech_v2_network(
-        audio_data=audio_data,
-        text_data=text_data,
-        seq_offset_data=seq_offset_data,
-        seq_len_data=seq_len_data,
-        index_range_datas=index_range_datas,
-        dict_size=data_generator.vocab_size,
-        num_conv_layers=args.num_conv_layers,
-        num_rnn_layers=args.num_rnn_layers,
-        rnn_size=args.rnn_layer_size,
-        use_gru=args.use_gru,
-        share_rnn_weights=args.share_rnn_weights)
+        keep_transcription_text=True)
 
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.tune_manifest,
@@ -129,35 +94,17 @@ def tune():
         sortagrad=False,
         shuffle_method=None)
 
-    # load parameters
-    if not os.path.isfile(args.model_path):
-        raise IOError("Invaid model path: %s" % args.model_path)
-    parameters = paddle.parameters.Parameters.from_tar(
-        gzip.open(args.model_path))
-
-    inferer = paddle.inference.Inference(
-        output_layer=output_probs, parameters=parameters)
+    ds2_model = DeepSpeech2Model(
+        vocab_size=data_generator.vocab_size,
+        num_conv_layers=args.num_conv_layers,
+        num_rnn_layers=args.num_rnn_layers,
+        rnn_layer_size=args.rnn_layer_size,
+        use_gru=args.use_gru,
+        pretrained_model_path=args.model_path,
+        share_rnn_weights=args.share_rnn_weights)
+
     # decoders only accept string encoded in utf-8
     vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
-    # init logger
-    logger = logging.getLogger("")
-    logger.setLevel(level=logging.INFO)
-    # init external scorer
-    logger.info("begin to initialize the external scorer for tuning")
-    if not os.path.isfile(args.lang_model_path):
-        raise IOError("Invaid language model path: %s" % args.lang_model_path)
-    ext_scorer = Scorer(
-        alpha=args.alpha_from,
-        beta=args.beta_from,
-        model_path=args.lang_model_path,
-        vocabulary=vocab_list)
-    logger.info("language model: "
-                "is_character_based = %d," % ext_scorer.is_character_based() +
-                " max_order = %d," % ext_scorer.get_max_order() +
-                " dict_size = %d" % ext_scorer.get_dict_size())
-    logger.info("end initializing scorer. Start tuning ...")
+
     errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
     # create grid for search
     cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
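
For context, the params_grid iterated in the next hunk is the Cartesian product of the candidate alpha and beta values. cand_alphas comes from the unchanged line above; the rest of the construction lies outside this diff, so the following is only a plausible sketch (the use of itertools.product is an assumption):

    import itertools
    cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
    cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
    params_grid = [(alpha, beta)
                   for alpha, beta in itertools.product(cand_alphas, cand_betas)]

Each batch's probs_split is then computed once and decoded len(params_grid) times.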
@@ -168,37 +115,32 @@ def tune():
     err_sum = [0.0 for i in xrange(len(params_grid))]
     err_ave = [0.0 for i in xrange(len(params_grid))]
     num_ins, len_refs, cur_batch = 0, 0, 0
+    # initialize external scorer
+    ds2_model.init_ext_scorer(args.alpha_from, args.beta_from,
+                              args.lang_model_path, vocab_list)
     ## incremental tuning parameters over multiple batches
+    ds2_model.logger.info("start tuning ...")
     for infer_data in batch_reader():
         if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
             break
-        infer_results = inferer.infer(input=infer_data,
-                                      feeding=data_generator.feeding)
-        start_pos = [0] * (len(infer_data) + 1)
-        for i in xrange(len(infer_data)):
-            start_pos[i + 1] = start_pos[i] + infer_data[i][3][0]
-        probs_split = [
-            infer_results[start_pos[i]:start_pos[i + 1]]
-            for i in xrange(0, len(infer_data))
-        ]
+        probs_split = ds2_model.infer_probs_batch(
+            infer_data=infer_data,
+            feeding_dict=data_generator.feeding)
 
         target_transcripts = [ data[1] for data in infer_data ]
 
         num_ins += len(target_transcripts)
         # grid search
         for index, (alpha, beta) in enumerate(params_grid):
-            # reset alpha & beta
-            ext_scorer.reset_params(alpha, beta)
-            beam_search_results = ctc_beam_search_decoder_batch(
+            result_transcripts = ds2_model.infer_batch_beam_search(
                 probs_split=probs_split,
-                vocabulary=vocab_list,
+                beam_alpha=alpha,
+                beam_beta=beta,
                 beam_size=args.beam_size,
-                num_processes=args.num_proc_bsearch,
                 cutoff_prob=args.cutoff_prob,
                 cutoff_top_n=args.cutoff_top_n,
-                ext_scoring_func=ext_scorer, )
-
-            result_transcripts = [res[0][1] for res in beam_search_results]
+                vocab_list=vocab_list,
+                num_processes=args.num_proc_bsearch)
             for target, result in zip(target_transcripts, result_transcripts):
                 errors, len_ref = errors_func(target, result)
                 err_sum[index] += errors
@@ -235,7 +177,7 @@ def tune():
             % (cur_batch, "%.3f" % params_grid[min_index][0],
                "%.3f" % params_grid[min_index][1]))
 
-    logger.info("finish tuning")
+    ds2_model.logger.info("finish tuning")
 
 
 def main():
