diff --git a/evaluate.py b/evaluate.py
index 592b7b52..7406e0bd 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -10,6 +10,7 @@ import paddle.v2 as paddle
 from data_utils.data import DataGenerator
 from model import DeepSpeech2Model
 from error_rate import wer
+from error_rate import cer
 import utils
 
 parser = argparse.ArgumentParser(description=__doc__)
@@ -111,6 +112,14 @@ parser.add_argument(
     default='datasets/vocab/eng_vocab.txt',
     type=str,
     help="Vocabulary filepath. (default: %(default)s)")
+parser.add_argument(
+    "--error_rate_type",
+    default='wer',
+    choices=['wer', 'cer'],
+    type=str,
+    help="Error rate type for evaluation. 'wer' stands for word error "
+    "rate and 'cer' for character error rate. "
+    "(default: %(default)s)")
 args = parser.parse_args()
 
 
@@ -136,7 +145,14 @@ def evaluate():
         rnn_layer_size=args.rnn_layer_size,
         pretrained_model_path=args.model_filepath)
 
-    wer_sum, num_ins = 0.0, 0
+    if args.error_rate_type == 'wer':
+        error_rate_func = wer
+        error_rate_info = 'WER'
+    else:
+        error_rate_func = cer
+        error_rate_info = 'CER'
+
+    error_sum, num_ins = 0.0, 0
     for infer_data in batch_reader():
         result_transcripts = ds2_model.infer_batch(
             infer_data=infer_data,
@@ -153,10 +169,12 @@ def evaluate():
             for _, transcript in infer_data
         ]
         for target, result in zip(target_transcripts, result_transcripts):
-            wer_sum += wer(target, result)
+            error_sum += error_rate_func(target, result)
             num_ins += 1
-        print("WER (%d/?) = %f" % (num_ins, wer_sum / num_ins))
-    print("Final WER (%d/%d) = %f" % (num_ins, num_ins, wer_sum / num_ins))
+        print("%s (%d/?) = %f" % \
+            (error_rate_info, num_ins, error_sum / num_ins))
+    print("Final %s (%d/%d) = %f" % \
+        (error_rate_info, num_ins, num_ins, error_sum / num_ins))
 
 
 def main():
diff --git a/infer.py b/infer.py
index df5953e5..3aba847e 100644
--- a/infer.py
+++ b/infer.py
@@ -10,6 +10,7 @@ import paddle.v2 as paddle
 from data_utils.data import DataGenerator
 from model import DeepSpeech2Model
 from error_rate import wer
+from error_rate import cer
 import utils
 
 parser = argparse.ArgumentParser(description=__doc__)
@@ -111,6 +112,14 @@ parser.add_argument(
     type=float,
     help="The cutoff probability of pruning"
     "in beam search. (default: %(default)f)")
+parser.add_argument(
+    "--error_rate_type",
+    default='wer',
+    choices=['wer', 'cer'],
+    type=str,
+    help="Error rate type for evaluation. 'wer' stands for word error "
+    "rate and 'cer' for character error rate. "
" + "(default: %(default)s)") args = parser.parse_args() @@ -147,6 +156,13 @@ def infer(): language_model_path=args.language_model_path, num_processes=args.num_processes_beam_search) + if args.error_rate_type == 'wer': + error_rate_func = wer + error_rate_info = 'wer' + else: + error_rate_func = cer + error_rate_info = 'cer' + target_transcripts = [ ''.join([data_generator.vocab_list[token] for token in transcript]) for _, transcript in infer_data @@ -154,7 +170,8 @@ def infer(): for target, result in zip(target_transcripts, result_transcripts): print("\nTarget Transcription: %s\nOutput Transcription: %s" % (target, result)) - print("Current wer = %f" % wer(target, result)) + print("Current %s = %f" % \ + (error_rate_info, error_rate_func(target, result))) def main(): diff --git a/model.py b/model.py index 2eb7c359..e2f2903b 100644 --- a/model.py +++ b/model.py @@ -185,7 +185,7 @@ class DeepSpeech2Model(object): # best path decode for i, probs in enumerate(probs_split): output_transcription = ctc_best_path_decoder( - probs_seq=probs, vocabulary=data_generator.vocab_list) + probs_seq=probs, vocabulary=vocab_list) results.append(output_transcription) elif decode_method == "beam_search": # initialize external scorer