Merge pull request #51 from kuke/fix_err_rate

Correct the error-rate computation for multiple sentences
Yibing Liu committed via GitHub
commit a835d41206
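Before this fix, the evaluation averaged per-sentence error rates, which weights a two-word utterance as heavily as a fifty-word one; the corrected code accumulates raw edit-distance errors and reference lengths across all sentences and divides once at the end. A minimal sketch of the difference, using made-up counts for two utterances:

# Two hypothetical utterances: 1 error over a 10-word reference,
# and 2 errors over a 2-word reference.
errors = [1.0, 2.0]
len_refs = [10, 2]

# old behaviour: mean of per-sentence error rates
rates = [e / n for e, n in zip(errors, len_refs)]
print(sum(rates) / len(rates))      # (0.1 + 1.0) / 2 = 0.55

# new behaviour: corpus-level rate, weighted by reference length
print(sum(errors) / sum(len_refs))  # 3.0 / 12 = 0.25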

@@ -8,7 +8,7 @@ import functools
 import paddle.v2 as paddle
 from data_utils.data import DataGenerator
 from model_utils.model import DeepSpeech2Model
-from utils.error_rate import wer, cer
+from utils.error_rate import char_errors, word_errors
 from utils.utility import add_arguments, print_arguments
 parser = argparse.ArgumentParser(description=__doc__)
@@ -91,8 +91,8 @@ def evaluate():
     # decoders only accept string encoded in utf-8
     vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
-    error_rate_func = cer if args.error_rate_type == 'cer' else wer
-    error_sum, num_ins = 0.0, 0
+    errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
+    errors_sum, len_refs, num_ins = 0.0, 0, 0
     for infer_data in batch_reader():
         result_transcripts = ds2_model.infer_batch(
             infer_data=infer_data,
@@ -108,12 +108,14 @@ def evaluate():
             feeding_dict=data_generator.feeding)
         target_transcripts = [data[1] for data in infer_data]
         for target, result in zip(target_transcripts, result_transcripts):
-            error_sum += error_rate_func(target, result)
+            errors, len_ref = errors_func(target, result)
+            errors_sum += errors
+            len_refs += len_ref
             num_ins += 1
         print("Error rate [%s] (%d/?) = %f" %
-              (args.error_rate_type, num_ins, error_sum / num_ins))
+              (args.error_rate_type, num_ins, errors_sum / len_refs))
     print("Final error rate [%s] (%d/%d) = %f" %
-          (args.error_rate_type, num_ins, num_ins, error_sum / num_ins))
+          (args.error_rate_type, num_ins, num_ins, errors_sum / len_refs))
     ds2_model.logger.info("finish evaluation")
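For reference, char_errors and word_errors are expected to return the raw edit distance together with the reference length, instead of the per-sentence rate that cer and wer returned. Below is a minimal sketch of such a helper, assuming a plain Levenshtein distance over words; the real utils.error_rate functions may differ in signature and options:

def word_errors(reference, hypothesis):
    """Return (edit_distance, reference_length) for two transcripts.
    Illustrative sketch only, not the repository's actual implementation."""
    ref, hyp = reference.split(), hypothesis.split()
    # dynamic-programming Levenshtein distance over word sequences
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (r != h)))   # substitution
        prev = cur
    return float(prev[-1]), len(ref)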

@@ -16,7 +16,7 @@ from data_utils.data import DataGenerator
 from decoders.swig_wrapper import Scorer
 from decoders.swig_wrapper import ctc_beam_search_decoder_batch
 from model_utils.model import deep_speech_v2_network
-from utils.error_rate import wer, cer
+from utils.error_rate import char_errors, word_errors
 from utils.utility import add_arguments, print_arguments
 parser = argparse.ArgumentParser(description=__doc__)
@@ -158,7 +158,7 @@ def tune():
                 " dict_size = %d" % ext_scorer.get_dict_size())
     logger.info("end initializing scorer. Start tuning ...")
-    error_rate_func = cer if args.error_rate_type == 'cer' else wer
+    errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
     # create grid for search
     cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
     cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
@@ -167,7 +167,7 @@ def tune():
     err_sum = [0.0 for i in xrange(len(params_grid))]
     err_ave = [0.0 for i in xrange(len(params_grid))]
-    num_ins, cur_batch = 0, 0
+    num_ins, len_refs, cur_batch = 0, 0, 0
     ## incremental tuning parameters over multiple batches
     for infer_data in batch_reader():
         if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
@@ -200,8 +200,14 @@ def tune():
             result_transcripts = [res[0][1] for res in beam_search_results]
             for target, result in zip(target_transcripts, result_transcripts):
-                err_sum[index] += error_rate_func(target, result)
-            err_ave[index] = err_sum[index] / num_ins
+                errors, len_ref = errors_func(target, result)
+                err_sum[index] += errors
+                # accumulate the length of references of every batch
+                # in the first iteration
+                if args.alpha_from == alpha and args.beta_from == beta:
+                    len_refs += len_ref
+            err_ave[index] = err_sum[index] / len_refs
             if index % 2 == 0:
                 sys.stdout.write('.')
                 sys.stdout.flush()
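One subtlety in the hunk above: every (alpha, beta) grid point decodes the same batch, so the reference length must be counted once per batch, not once per grid point; the guard on args.alpha_from and args.beta_from restricts the accumulation to the first grid point. A self-contained toy of the same accumulation pattern, where the grid, the batches, and errors_func are all made-up stand-ins:

import itertools

params_grid = list(itertools.product([1.0, 2.0], [0.1, 0.2]))  # (alpha, beta) stand-ins
batches = [[("a b c", "a b d")], [("x y", "x z")]]             # (target, result) pairs

def errors_func(target, result):
    # toy stand-in: count word mismatches, return (errors, reference_length)
    t, r = target.split(), result.split()
    return float(sum(a != b for a, b in zip(t, r))), len(t)

err_sum = [0.0] * len(params_grid)
err_ave = [0.0] * len(params_grid)
len_refs = 0
for batch in batches:
    for index, (alpha, beta) in enumerate(params_grid):
        for target, result in batch:
            errors, len_ref = errors_func(target, result)
            err_sum[index] += errors
            # count each reference once per batch: only while scanning
            # the first grid point, mirroring the diff's guard
            if (alpha, beta) == params_grid[0]:
                len_refs += len_ref
        err_ave[index] = err_sum[index] / len_refs
print(err_ave)  # every grid point sees the same toy errors here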
@@ -226,7 +232,7 @@ def tune():
     err_ave_min = min(err_ave)
     min_index = err_ave.index(err_ave_min)
     print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)"
-          % (args.num_batches, "%.3f" % params_grid[min_index][0],
+          % (cur_batch, "%.3f" % params_grid[min_index][0],
              "%.3f" % params_grid[min_index][1]))
     logger.info("finish tuning")
