From 4b3f768df7d165467fbdc44e6d91fae4a1715dea Mon Sep 17 00:00:00 2001
From: yangyaming
Date: Wed, 9 Aug 2017 20:03:53 +0800
Subject: [PATCH] Simplify description and codes.

---
 evaluate.py | 23 ++++++++---------------
 infer.py    | 19 ++++++-------------
 2 files changed, 14 insertions(+), 28 deletions(-)

diff --git a/evaluate.py b/evaluate.py
index 7406e0bdd..82dcec3c2 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -9,8 +9,7 @@ import multiprocessing
 import paddle.v2 as paddle
 from data_utils.data import DataGenerator
 from model import DeepSpeech2Model
-from error_rate import wer
-from error_rate import cer
+from error_rate import wer, cer
 import utils
 
 parser = argparse.ArgumentParser(description=__doc__)
@@ -117,8 +116,8 @@ parser.add_argument(
     default='wer',
     choices=['wer', 'cer'],
     type=str,
-    help="There are total two error rate types including wer and cer. wer "
-    "represents for word error rate while cer for character error rate. "
+    help="Error rate type for evaluation. 'wer' for word error rate and 'cer' "
+    "for character error rate. "
     "(default: %(default)s)")
 args = parser.parse_args()
 
@@ -145,13 +144,7 @@ def evaluate():
         rnn_layer_size=args.rnn_layer_size,
         pretrained_model_path=args.model_filepath)
 
-    if args.error_rate_type == 'wer':
-        error_rate_func = wer
-        error_rate_info = 'WER'
-    else:
-        error_rate_func = cer
-        error_rate_info = 'CER'
-
+    error_rate_func = cer if args.error_rate_type == 'cer' else wer
     error_sum, num_ins = 0.0, 0
     for infer_data in batch_reader():
         result_transcripts = ds2_model.infer_batch(
@@ -171,10 +164,10 @@ def evaluate():
         for target, result in zip(target_transcripts, result_transcripts):
             error_sum += error_rate_func(target, result)
             num_ins += 1
-        print("%s (%d/?) = %f" % \
-              (error_rate_info, num_ins, error_sum / num_ins))
-    print("Final %s (%d/%d) = %f" % \
-          (error_rate_info, num_ins, num_ins, error_sum / num_ins))
+        print("Error rate [%s] (%d/?) = %f" %
+              (args.error_rate_type, num_ins, error_sum / num_ins))
+    print("Final error rate [%s] (%d/%d) = %f" %
+          (args.error_rate_type, num_ins, num_ins, error_sum / num_ins))
 
 
 def main():
diff --git a/infer.py b/infer.py
index 3aba847e7..43643cde7 100644
--- a/infer.py
+++ b/infer.py
@@ -9,8 +9,7 @@ import multiprocessing
 import paddle.v2 as paddle
 from data_utils.data import DataGenerator
 from model import DeepSpeech2Model
-from error_rate import wer
-from error_rate import cer
+from error_rate import wer, cer
 import utils
 
 parser = argparse.ArgumentParser(description=__doc__)
@@ -117,8 +116,8 @@ parser.add_argument(
     default='wer',
     choices=['wer', 'cer'],
     type=str,
-    help="There are total two error rate types including wer and cer. wer "
-    "represents for word error rate while cer for character error rate. "
+    help="Error rate type for evaluation. 'wer' for word error rate and 'cer' "
+    "for character error rate. "
     "(default: %(default)s)")
 args = parser.parse_args()
 
@@ -156,13 +155,7 @@ def infer():
         language_model_path=args.language_model_path,
         num_processes=args.num_processes_beam_search)
 
-    if args.error_rate_type == 'wer':
-        error_rate_func = wer
-        error_rate_info = 'wer'
-    else:
-        error_rate_func = cer
-        error_rate_info = 'cer'
-
+    error_rate_func = cer if args.error_rate_type == 'cer' else wer
     target_transcripts = [
         ''.join([data_generator.vocab_list[token] for token in transcript])
         for _, transcript in infer_data
@@ -170,8 +163,8 @@ def infer():
     for target, result in zip(target_transcripts, result_transcripts):
         print("\nTarget Transcription: %s\nOutput Transcription: %s" %
               (target, result))
-        print("Current %s = %f" % \
-              (error_rate_info, error_rate_func(target, result)))
+        print("Current error rate [%s] = %f" %
+              (args.error_rate_type, error_rate_func(target, result)))
 
 
 def main():