diff --git a/test.py b/test.py
index 53f7e17af..224cea9b6 100644
--- a/test.py
+++ b/test.py
@@ -8,7 +8,7 @@ import functools
 import paddle.v2 as paddle
 from data_utils.data import DataGenerator
 from model_utils.model import DeepSpeech2Model
-from utils.error_rate import wer, cer
+from utils.error_rate import char_errors, word_errors
 from utils.utility import add_arguments, print_arguments
 
 parser = argparse.ArgumentParser(description=__doc__)
@@ -91,8 +91,8 @@ def evaluate():
     # decoders only accept string encoded in utf-8
     vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
 
-    error_rate_func = cer if args.error_rate_type == 'cer' else wer
-    error_sum, num_ins = 0.0, 0
+    errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
+    errors_sum, len_refs, num_ins = 0.0, 0, 0
     for infer_data in batch_reader():
         result_transcripts = ds2_model.infer_batch(
             infer_data=infer_data,
@@ -108,12 +108,14 @@ def evaluate():
             feeding_dict=data_generator.feeding)
         target_transcripts = [data[1] for data in infer_data]
         for target, result in zip(target_transcripts, result_transcripts):
-            error_sum += error_rate_func(target, result)
+            errors, len_ref = errors_func(target, result)
+            errors_sum += errors
+            len_refs += len_ref
             num_ins += 1
         print("Error rate [%s] (%d/?) = %f" %
-              (args.error_rate_type, num_ins, error_sum / num_ins))
+              (args.error_rate_type, num_ins, errors_sum / len_refs))
     print("Final error rate [%s] (%d/%d) = %f" %
-          (args.error_rate_type, num_ins, num_ins, error_sum / num_ins))
+          (args.error_rate_type, num_ins, num_ins, errors_sum / len_refs))
 
     ds2_model.logger.info("finish evaluation")
 
diff --git a/tools/tune.py b/tools/tune.py
index 47abf1413..b13233195 100644
--- a/tools/tune.py
+++ b/tools/tune.py
@@ -16,7 +16,7 @@ from data_utils.data import DataGenerator
 from decoders.swig_wrapper import Scorer
 from decoders.swig_wrapper import ctc_beam_search_decoder_batch
 from model_utils.model import deep_speech_v2_network
-from utils.error_rate import wer, cer
+from utils.error_rate import char_errors, word_errors
 from utils.utility import add_arguments, print_arguments
 
 parser = argparse.ArgumentParser(description=__doc__)
@@ -158,7 +158,7 @@ def tune():
                 " dict_size = %d" % ext_scorer.get_dict_size())
     logger.info("end initializing scorer. Start tuning ...")
 
-    error_rate_func = cer if args.error_rate_type == 'cer' else wer
+    errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
     # create grid for search
     cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
     cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
@@ -167,7 +167,7 @@ def tune():
 
     err_sum = [0.0 for i in xrange(len(params_grid))]
     err_ave = [0.0 for i in xrange(len(params_grid))]
-    num_ins, cur_batch = 0, 0
+    num_ins, len_refs, cur_batch = 0, 0, 0
     ## incremental tuning parameters over multiple batches
     for infer_data in batch_reader():
         if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
@@ -200,8 +200,14 @@ def tune():
 
             result_transcripts = [res[0][1] for res in beam_search_results]
             for target, result in zip(target_transcripts, result_transcripts):
-                err_sum[index] += error_rate_func(target, result)
-                err_ave[index] = err_sum[index] / num_ins
+                errors, len_ref = errors_func(target, result)
+                err_sum[index] += errors
+                # accumulate the length of references of every batch
+                # in the first iteration
+                if args.alpha_from == alpha and args.beta_from == beta:
+                    len_refs += len_ref
+
+            err_ave[index] = err_sum[index] / len_refs
             if index % 2 == 0:
                 sys.stdout.write('.')
                 sys.stdout.flush()
@@ -226,7 +232,7 @@ def tune():
     err_ave_min = min(err_ave)
     min_index = err_ave.index(err_ave_min)
     print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)"
-          % (args.num_batches, "%.3f" % params_grid[min_index][0],
+          % (cur_batch, "%.3f" % params_grid[min_index][0],
             "%.3f" % params_grid[min_index][1]))
 
     logger.info("finish tuning")