From ae796a9dab6424cee9b29efed5223a96736f0611 Mon Sep 17 00:00:00 2001
From: Yibing Liu <liuyibing01@baidu.com>
Date: Mon, 4 Dec 2017 13:39:56 +0800
Subject: [PATCH] Correct the error rate's computation for multiple sentences

---
 test.py       | 14 ++++++++------
 tools/tune.py | 18 ++++++++++++------
 2 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/test.py b/test.py
index 53f7e17af..224cea9b6 100644
--- a/test.py
+++ b/test.py
@@ -8,7 +8,7 @@ import functools
 import paddle.v2 as paddle
 from data_utils.data import DataGenerator
 from model_utils.model import DeepSpeech2Model
-from utils.error_rate import wer, cer
+from utils.error_rate import char_errors, word_errors
 from utils.utility import add_arguments, print_arguments
 
 parser = argparse.ArgumentParser(description=__doc__)
@@ -91,8 +91,8 @@ def evaluate():
     # decoders only accept string encoded in utf-8
     vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
 
-    error_rate_func = cer if args.error_rate_type == 'cer' else wer
-    error_sum, num_ins = 0.0, 0
+    errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
+    errors_sum, len_refs, num_ins = 0.0, 0, 0
     for infer_data in batch_reader():
         result_transcripts = ds2_model.infer_batch(
             infer_data=infer_data,
@@ -108,12 +108,14 @@ def evaluate():
             feeding_dict=data_generator.feeding)
         target_transcripts = [data[1] for data in infer_data]
         for target, result in zip(target_transcripts, result_transcripts):
-            error_sum += error_rate_func(target, result)
+            errors, len_ref = errors_func(target, result)
+            errors_sum += errors
+            len_refs += len_ref
             num_ins += 1
         print("Error rate [%s] (%d/?) = %f" %
-              (args.error_rate_type, num_ins, error_sum / num_ins))
+              (args.error_rate_type, num_ins, errors_sum / len_refs))
     print("Final error rate [%s] (%d/%d) = %f" %
-          (args.error_rate_type, num_ins, num_ins, error_sum / num_ins))
+          (args.error_rate_type, num_ins, num_ins, errors_sum / len_refs))
 
     ds2_model.logger.info("finish evaluation")
 
diff --git a/tools/tune.py b/tools/tune.py
index 47abf1413..b13233195 100644
--- a/tools/tune.py
+++ b/tools/tune.py
@@ -16,7 +16,7 @@ from data_utils.data import DataGenerator
 from decoders.swig_wrapper import Scorer
 from decoders.swig_wrapper import ctc_beam_search_decoder_batch
 from model_utils.model import deep_speech_v2_network
-from utils.error_rate import wer, cer
+from utils.error_rate import char_errors, word_errors
 from utils.utility import add_arguments, print_arguments
 
 parser = argparse.ArgumentParser(description=__doc__)
@@ -158,7 +158,7 @@ def tune():
                 " dict_size = %d" % ext_scorer.get_dict_size())
     logger.info("end initializing scorer. Start tuning ...")
 
-    error_rate_func = cer if args.error_rate_type == 'cer' else wer
+    errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
     # create grid for search
     cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
     cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
@@ -167,7 +167,7 @@ def tune():
 
     err_sum = [0.0 for i in xrange(len(params_grid))]
     err_ave = [0.0 for i in xrange(len(params_grid))]
-    num_ins, cur_batch = 0, 0
+    num_ins, len_refs, cur_batch = 0, 0, 0
     ## incremental tuning parameters over multiple batches
     for infer_data in batch_reader():
         if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
@@ -200,8 +200,14 @@ def tune():
 
             result_transcripts = [res[0][1] for res in beam_search_results]
             for target, result in zip(target_transcripts, result_transcripts):
-                err_sum[index] += error_rate_func(target, result)
-            err_ave[index] = err_sum[index] / num_ins
+                errors, len_ref = errors_func(target, result)
+                err_sum[index] += errors
+                # accumulate the total reference length over all batches,
+                # but only once — during the first (alpha, beta) grid point
+                if args.alpha_from == alpha and args.beta_from == beta:
+                    len_refs += len_ref
+
+            err_ave[index] = err_sum[index] / len_refs
             if index % 2 == 0:
                 sys.stdout.write('.')
                 sys.stdout.flush()
@@ -226,7 +232,7 @@ def tune():
     err_ave_min = min(err_ave)
     min_index = err_ave.index(err_ave_min)
     print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)"
-            % (args.num_batches, "%.3f" % params_grid[min_index][0],
+            % (cur_batch, "%.3f" % params_grid[min_index][0],
               "%.3f" % params_grid[min_index][1]))
 
     logger.info("finish tuning")