From 92eacf548bf5ca278a2ad741dd9c901ca6d23a8f Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Mon, 31 Jul 2017 21:57:07 +0800 Subject: [PATCH] Update default config params and result display for evaluator.py and infer.py for DS2. --- evaluate.py | 26 ++++++++++++++++++-------- infer.py | 9 +++++++-- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/evaluate.py b/evaluate.py index 19eabf4e..1d758687 100644 --- a/evaluate.py +++ b/evaluate.py @@ -4,6 +4,7 @@ from __future__ import division from __future__ import print_function import distutils.util +import sys import argparse import gzip import paddle.v2 as paddle @@ -12,13 +13,19 @@ from model import deep_speech2 from decoder import * from lm.lm_scorer import LmScorer from error_rate import wer +import utils parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--batch_size", - default=100, + default=128, type=int, help="Minibatch size for evaluation. (default: %(default)s)") +parser.add_argument( + "--trainer_count", + default=8, + type=int, + help="Trainer number. (default: %(default)s)") parser.add_argument( "--num_conv_layers", default=2, @@ -58,8 +65,8 @@ parser.add_argument( "--decode_method", default='beam_search', type=str, - help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)" -) + help="Method for ctc decoding, best_path or beam_search. " + "(default: %(default)s)") parser.add_argument( "--language_model_path", default="lm/data/common_crawl_00.prune01111.trie.klm", @@ -67,12 +74,12 @@ parser.add_argument( help="Path for language model. (default: %(default)s)") parser.add_argument( "--alpha", - default=0.26, + default=0.36, type=float, help="Parameter associated with language model. (default: %(default)f)") parser.add_argument( "--beta", - default=0.1, + default=0.25, type=float, help="Parameter associated with word count. (default: %(default)f)") parser.add_argument( @@ -191,7 +198,7 @@ def evaluate(): blank_id=len(data_generator.vocab_list), num_processes=args.num_processes_beam_search, ext_scoring_func=ext_scorer, - cutoff_prob=args.cutoff_prob, ) + cutoff_prob=args.cutoff_prob) for i, beam_search_result in enumerate(beam_search_results): wer_sum += wer(target_transcription[i], beam_search_result[0][1]) @@ -199,12 +206,15 @@ def evaluate(): else: raise ValueError("Decoding method [%s] is not supported." % decode_method) + print("WER (%d/?) = %f" % (wer_counter, wer_sum / wer_counter)) - print("Final WER = %f" % (wer_sum / wer_counter)) + print("Final WER (%d/%d) = %f" % (wer_counter, wer_counter, + wer_sum / wer_counter)) def main(): - paddle.init(use_gpu=args.use_gpu, trainer_count=1) + utils.print_arguments(args) + paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) evaluate() diff --git a/infer.py b/infer.py index 81752630..ad3fdc4d 100644 --- a/infer.py +++ b/infer.py @@ -57,6 +57,11 @@ parser.add_argument( type=str, help="Feature type of audio data: 'linear' (power spectrum)" " or 'mfcc'. (default: %(default)s)") +parser.add_argument( + "--trainer_count", + default=8, + type=int, + help="Trainer number. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -208,7 +213,7 @@ def infer(): wer_cur = wer(target_transcription[i], beam_search_result[0][1]) wer_sum += wer_cur wer_counter += 1 - print("cur wer = %f , average wer = %f" % + print("Current WER = %f , Average WER = %f" % (wer_cur, wer_sum / wer_counter)) else: raise ValueError("Decoding method [%s] is not supported." % @@ -217,7 +222,7 @@ def infer(): def main(): utils.print_arguments(args) - paddle.init(use_gpu=args.use_gpu, trainer_count=1) + paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) infer()