Merge pull request #51 from kuke/fix_err_rate

Correct the error rate's computation for multiple sentences
commit a835d41206 by Yibing Liu, 7 years ago (committed via GitHub)
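
Summary of the fix: the previous code summed per-sentence error rates and divided by the number of sentences, which over-weights short utterances. The corrected code accumulates raw edit-distance errors (errors_sum) and reference lengths (len_refs) separately and reports errors_sum / len_refs, i.e. the corpus-level WER/CER: total errors divided by total reference length.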

@@ -8,7 +8,7 @@ import functools
 import paddle.v2 as paddle
 from data_utils.data import DataGenerator
 from model_utils.model import DeepSpeech2Model
-from utils.error_rate import wer, cer
+from utils.error_rate import char_errors, word_errors
 from utils.utility import add_arguments, print_arguments

 parser = argparse.ArgumentParser(description=__doc__)
@@ -91,8 +91,8 @@ def evaluate():
     # decoders only accept string encoded in utf-8
     vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
-    error_rate_func = cer if args.error_rate_type == 'cer' else wer
-    error_sum, num_ins = 0.0, 0
+    errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
+    errors_sum, len_refs, num_ins = 0.0, 0, 0
     for infer_data in batch_reader():
         result_transcripts = ds2_model.infer_batch(
             infer_data=infer_data,
@@ -108,12 +108,14 @@ def evaluate():
             feeding_dict=data_generator.feeding)
         target_transcripts = [data[1] for data in infer_data]
         for target, result in zip(target_transcripts, result_transcripts):
-            error_sum += error_rate_func(target, result)
+            errors, len_ref = errors_func(target, result)
+            errors_sum += errors
+            len_refs += len_ref
             num_ins += 1
         print("Error rate [%s] (%d/?) = %f" %
-              (args.error_rate_type, num_ins, error_sum / num_ins))
+              (args.error_rate_type, num_ins, errors_sum / len_refs))
     print("Final error rate [%s] (%d/%d) = %f" %
-          (args.error_rate_type, num_ins, num_ins, error_sum / num_ins))
+          (args.error_rate_type, num_ins, num_ins, errors_sum / len_refs))
     ds2_model.logger.info("finish evaluation")
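
A minimal, self-contained sketch of the difference the evaluate() change makes, assuming (as the diff suggests) that word_errors(reference, hypothesis) returns an (edit_distance, reference_length) tuple; the levenshtein helper below is a stand-in written for illustration, not the repo's implementation:

def levenshtein(ref, hyp):
    """Edit distance between two token lists (single-row dynamic programming)."""
    dist = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        prev, dist[0] = dist[0], i
        for j in range(1, len(hyp) + 1):
            cur = min(dist[j] + 1,                        # deletion
                      dist[j - 1] + 1,                    # insertion
                      prev + (ref[i - 1] != hyp[j - 1]))  # substitution
            prev, dist[j] = dist[j], cur
    return dist[-1]

def word_errors(reference, hypothesis):
    """Return (edit_distance, reference_length) counted in words."""
    ref_words = reference.split()
    return float(levenshtein(ref_words, hypothesis.split())), len(ref_words)

pairs = [("the cat sat on the mat", "the cat sat on the mat"),  # 0 errors / 6 words
         ("hi", "bye")]                                         # 1 error  / 1 word
rate_sum, errors_sum, len_refs = 0.0, 0.0, 0
for target, result in pairs:
    errors, len_ref = word_errors(target, result)
    rate_sum += errors / len_ref   # old scheme: mean of per-sentence rates
    errors_sum += errors           # new scheme: pool raw errors ...
    len_refs += len_ref            # ... and reference lengths
print("mean of per-sentence WERs = %f" % (rate_sum / len(pairs)))  # 0.5
print("corpus-level WER          = %f" % (errors_sum / len_refs))  # 1/7 ~ 0.142857

When sentence lengths differ, the two quantities diverge; the pooled form weights every reference word equally, which matches how WER/CER is usually reported over a test set.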

@@ -16,7 +16,7 @@ from data_utils.data import DataGenerator
 from decoders.swig_wrapper import Scorer
 from decoders.swig_wrapper import ctc_beam_search_decoder_batch
 from model_utils.model import deep_speech_v2_network
-from utils.error_rate import wer, cer
+from utils.error_rate import char_errors, word_errors
 from utils.utility import add_arguments, print_arguments

 parser = argparse.ArgumentParser(description=__doc__)
@@ -158,7 +158,7 @@ def tune():
                 " dict_size = %d" % ext_scorer.get_dict_size())
     logger.info("end initializing scorer. Start tuning ...")
-    error_rate_func = cer if args.error_rate_type == 'cer' else wer
+    errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
     # create grid for search
     cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
     cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
@@ -167,7 +167,7 @@ def tune():
     err_sum = [0.0 for i in xrange(len(params_grid))]
     err_ave = [0.0 for i in xrange(len(params_grid))]
-    num_ins, cur_batch = 0, 0
+    num_ins, len_refs, cur_batch = 0, 0, 0
     ## incremental tuning parameters over multiple batches
     for infer_data in batch_reader():
         if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
@@ -200,8 +200,14 @@ def tune():
             result_transcripts = [res[0][1] for res in beam_search_results]
             for target, result in zip(target_transcripts, result_transcripts):
-                err_sum[index] += error_rate_func(target, result)
-            err_ave[index] = err_sum[index] / num_ins
+                errors, len_ref = errors_func(target, result)
+                err_sum[index] += errors
+                # accumulate the length of references of every batch
+                # in the first iteration
+                if args.alpha_from == alpha and args.beta_from == beta:
+                    len_refs += len_ref
+            err_ave[index] = err_sum[index] / len_refs
             if index % 2 == 0:
                 sys.stdout.write('.')
                 sys.stdout.flush()
@@ -226,7 +232,7 @@ def tune():
     err_ave_min = min(err_ave)
     min_index = err_ave.index(err_ave_min)
     print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)"
-          % (args.num_batches, "%.3f" % params_grid[min_index][0],
+          % (cur_batch, "%.3f" % params_grid[min_index][0],
              "%.3f" % params_grid[min_index][1]))
     logger.info("finish tuning")
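
In the tuning loop, the same batch of references is scored once per (alpha, beta) candidate, so reference lengths must be counted only once rather than once per grid point. A runnable toy sketch of that guard, under assumptions: decode() is a hypothetical stand-in for the beam-search decoder, and word_errors is a simplified placeholder following the (edit_distance, reference_length) contract from above:

import itertools

def word_errors(reference, hypothesis):
    # simplified stand-in metric; see the edit-distance sketch above
    ref_words, hyp_words = reference.split(), hypothesis.split()
    errors = sum(a != b for a, b in zip(ref_words, hyp_words))
    errors += abs(len(ref_words) - len(hyp_words))
    return float(errors), len(ref_words)

def decode(sentence, alpha, beta):
    # hypothetical decoder stub; a real decoder's output depends on (alpha, beta)
    return sentence

batches = [["the cat sat", "hello world"], ["good morning"]]
params_grid = list(itertools.product([1.0, 2.0], [0.1, 0.2]))  # (alpha, beta) pairs
err_sum = [0.0] * len(params_grid)
len_refs = 0
for batch in batches:
    for index, (alpha, beta) in enumerate(params_grid):
        for target in batch:
            errors, len_ref = word_errors(target, decode(target, alpha, beta))
            err_sum[index] += errors
            # every grid point scores the same references, so count
            # their lengths only while processing the first pair
            if index == 0:
                len_refs += len_ref
err_ave = [s / len_refs for s in err_sum]
print("corpus-level error rate per grid point: %s" % err_ave)

The diff implements the same guard by comparing (alpha, beta) against (args.alpha_from, args.beta_from), the first point the grid emits. It also switches the final report from args.num_batches to cur_batch, so the message stays correct when tuning runs over all batches (the loop only breaks on args.num_batches when it is non-negative).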
