|
|
|
@ -16,7 +16,7 @@ from data_utils.data import DataGenerator
|
|
|
|
|
from decoders.swig_wrapper import Scorer
|
|
|
|
|
from decoders.swig_wrapper import ctc_beam_search_decoder_batch
|
|
|
|
|
from model_utils.model import deep_speech_v2_network
|
|
|
|
|
from utils.error_rate import wer, cer
|
|
|
|
|
from utils.error_rate import char_errors, word_errors
|
|
|
|
|
from utils.utility import add_arguments, print_arguments
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
|
|
@ -158,7 +158,7 @@ def tune():
|
|
|
|
|
" dict_size = %d" % ext_scorer.get_dict_size())
|
|
|
|
|
logger.info("end initializing scorer. Start tuning ...")
|
|
|
|
|
|
|
|
|
|
error_rate_func = cer if args.error_rate_type == 'cer' else wer
|
|
|
|
|
errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
|
|
|
|
|
# create grid for search
|
|
|
|
|
cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
|
|
|
|
|
cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
|
|
|
|
@ -167,7 +167,7 @@ def tune():
|
|
|
|
|
|
|
|
|
|
err_sum = [0.0 for i in xrange(len(params_grid))]
|
|
|
|
|
err_ave = [0.0 for i in xrange(len(params_grid))]
|
|
|
|
|
num_ins, cur_batch = 0, 0
|
|
|
|
|
num_ins, len_refs, cur_batch = 0, 0, 0
|
|
|
|
|
## incremental tuning parameters over multiple batches
|
|
|
|
|
for infer_data in batch_reader():
|
|
|
|
|
if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
|
|
|
|
@ -200,8 +200,14 @@ def tune():
|
|
|
|
|
|
|
|
|
|
result_transcripts = [res[0][1] for res in beam_search_results]
|
|
|
|
|
for target, result in zip(target_transcripts, result_transcripts):
|
|
|
|
|
err_sum[index] += error_rate_func(target, result)
|
|
|
|
|
err_ave[index] = err_sum[index] / num_ins
|
|
|
|
|
errors, len_ref = errors_func(target, result)
|
|
|
|
|
err_sum[index] += errors
|
|
|
|
|
# accumulate the length of references of every batch
|
|
|
|
|
# in the first iteration
|
|
|
|
|
if args.alpha_from == alpha and args.beta_from == beta:
|
|
|
|
|
len_refs += len_ref
|
|
|
|
|
|
|
|
|
|
err_ave[index] = err_sum[index] / len_refs
|
|
|
|
|
if index % 2 == 0:
|
|
|
|
|
sys.stdout.write('.')
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
@ -226,7 +232,7 @@ def tune():
|
|
|
|
|
err_ave_min = min(err_ave)
|
|
|
|
|
min_index = err_ave.index(err_ave_min)
|
|
|
|
|
print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)"
|
|
|
|
|
% (args.num_batches, "%.3f" % params_grid[min_index][0],
|
|
|
|
|
% (cur_batch, "%.3f" % params_grid[min_index][0],
|
|
|
|
|
"%.3f" % params_grid[min_index][1]))
|
|
|
|
|
|
|
|
|
|
logger.info("finish tuning")
|
|
|
|
|