avoid repeated infer for same batch in the tuning of DS2

pull/2/head
Yibing Liu 7 years ago
parent 6cac602bc5
commit 1db71425c2

@ -7,10 +7,14 @@ import sys
import numpy as np import numpy as np
import argparse import argparse
import functools import functools
import gzip
import logging
import paddle.v2 as paddle import paddle.v2 as paddle
import _init_paths import _init_paths
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model_utils.model import DeepSpeech2Model from decoders.swig_wrapper import Scorer
from decoders.swig_wrapper import ctc_beam_search_decoder_batch
from model_utils.model import deep_speech_v2_network
from utils.error_rate import wer, cer from utils.error_rate import wer, cer
from utils.utility import add_arguments, print_arguments from utils.utility import add_arguments, print_arguments
@ -66,6 +70,9 @@ add_arg('specgram_type', str,
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(
format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
def tune(): def tune():
"""Tune parameters alpha and beta incrementally.""" """Tune parameters alpha and beta incrementally."""
if not args.num_alphas >= 0: if not args.num_alphas >= 0:
@ -79,29 +86,55 @@ def tune():
augmentation_config='{}', augmentation_config='{}',
specgram_type=args.specgram_type, specgram_type=args.specgram_type,
num_threads=1) num_threads=1)
audio_data = paddle.layer.data(
name="audio_spectrogram",
type=paddle.data_type.dense_array(161 * 161))
text_data = paddle.layer.data(
name="transcript_text",
type=paddle.data_type.integer_value_sequence(data_generator.vocab_size))
output_probs, _ = deep_speech_v2_network(
audio_data=audio_data,
text_data=text_data,
dict_size=data_generator.vocab_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_size=args.rnn_layer_size,
use_gru=args.use_gru,
share_rnn_weights=args.share_rnn_weights)
batch_reader = data_generator.batch_reader_creator( batch_reader = data_generator.batch_reader_creator(
manifest_path=args.tune_manifest, manifest_path=args.tune_manifest,
batch_size=args.batch_size, batch_size=args.batch_size,
sortagrad=False, sortagrad=False,
shuffle_method=None) shuffle_method=None)
tune_data = batch_reader().next()
target_transcripts = [
''.join([data_generator.vocab_list[token] for token in transcript])
for _, transcript in tune_data
]
ds2_model = DeepSpeech2Model( # load parameters
vocab_size=data_generator.vocab_size, parameters = paddle.parameters.Parameters.from_tar(
num_conv_layers=args.num_conv_layers, gzip.open(args.model_path))
num_rnn_layers=args.num_rnn_layers,
rnn_layer_size=args.rnn_layer_size,
use_gru=args.use_gru,
pretrained_model_path=args.model_path,
share_rnn_weights=args.share_rnn_weights)
inferer = paddle.inference.Inference(
output_layer=output_probs, parameters=parameters)
# decoders only accept string encoded in utf-8 # decoders only accept string encoded in utf-8
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
# init logger
logger = logging.getLogger("")
logger.setLevel(level=logging.INFO)
# init external scorer
logger.info("begin to initialize the external scorer for tuning")
ext_scorer = Scorer(
alpha=args.alpha_from,
beta=args.beta_from,
model_path=args.lang_model_path,
vocabulary=vocab_list)
logger.info("language model: "
"is_character_based = %d," % ext_scorer.is_character_based() +
" max_order = %d," % ext_scorer.get_max_order() +
" dict_size = %d" % ext_scorer.get_dict_size())
logger.info("end initializing scorer. Start tuning ...")
error_rate_func = cer if args.error_rate_type == 'cer' else wer error_rate_func = cer if args.error_rate_type == 'cer' else wer
# create grid for search # create grid for search
cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas) cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
@ -116,6 +149,13 @@ def tune():
for infer_data in batch_reader(): for infer_data in batch_reader():
if (args.num_batches >= 0) and (cur_batch >= args.num_batches): if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
break break
infer_results = inferer.infer(input=infer_data)
num_steps = len(infer_results) // len(infer_data)
probs_split = [
infer_results[i * num_steps:(i + 1) * num_steps]
for i in xrange(len(infer_data))
]
target_transcripts = [ target_transcripts = [
''.join([data_generator.vocab_list[token] for token in transcript]) ''.join([data_generator.vocab_list[token] for token in transcript])
@ -125,18 +165,18 @@ def tune():
num_ins += len(target_transcripts) num_ins += len(target_transcripts)
# grid search # grid search
for index, (alpha, beta) in enumerate(params_grid): for index, (alpha, beta) in enumerate(params_grid):
result_transcripts = ds2_model.infer_batch( # reset alpha & beta
infer_data=infer_data, ext_scorer.reset_params(alpha, beta)
decoding_method='ctc_beam_search', beam_search_results = ctc_beam_search_decoder_batch(
beam_alpha=alpha, probs_split=probs_split,
beam_beta=beta, vocabulary=vocab_list,
beam_size=args.beam_size, beam_size=args.beam_size,
num_processes=args.num_proc_bsearch,
cutoff_prob=args.cutoff_prob, cutoff_prob=args.cutoff_prob,
cutoff_top_n=args.cutoff_top_n, cutoff_top_n=args.cutoff_top_n,
vocab_list=vocab_list, ext_scoring_func=ext_scorer, )
language_model_path=args.lang_model_path,
num_processes=args.num_proc_bsearch)
result_transcripts = [res[0][1] for res in beam_search_results]
for target, result in zip(target_transcripts, result_transcripts): for target, result in zip(target_transcripts, result_transcripts):
err_sum[index] += error_rate_func(target, result) err_sum[index] += error_rate_func(target, result)
err_ave[index] = err_sum[index] / num_ins err_ave[index] = err_sum[index] / num_ins
@ -167,7 +207,7 @@ def tune():
% (args.num_batches, "%.3f" % params_grid[min_index][0], % (args.num_batches, "%.3f" % params_grid[min_index][0],
"%.3f" % params_grid[min_index][1])) "%.3f" % params_grid[min_index][1]))
ds2_model.logger.info("finish inference") logger.info("finish tuning")
def main(): def main():

Loading…
Cancel
Save