Simplify description and code.

pull/2/head
yangyaming 7 years ago
parent 5ef300f3f0
commit 4b3f768df7

@ -9,8 +9,7 @@ import multiprocessing
import paddle.v2 as paddle import paddle.v2 as paddle
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model import DeepSpeech2Model from model import DeepSpeech2Model
from error_rate import wer from error_rate import wer, cer
from error_rate import cer
import utils import utils
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
@ -117,8 +116,8 @@ parser.add_argument(
default='wer', default='wer',
choices=['wer', 'cer'], choices=['wer', 'cer'],
type=str, type=str,
help="There are total two error rate types including wer and cer. wer " help="Error rate type for evaluation. 'wer' for word error rate and 'cer' "
"represents for word error rate while cer for character error rate. " "for character error rate. "
"(default: %(default)s)") "(default: %(default)s)")
args = parser.parse_args() args = parser.parse_args()
@ -145,13 +144,7 @@ def evaluate():
rnn_layer_size=args.rnn_layer_size, rnn_layer_size=args.rnn_layer_size,
pretrained_model_path=args.model_filepath) pretrained_model_path=args.model_filepath)
if args.error_rate_type == 'wer': error_rate_func = cer if args.error_rate_type == 'cer' else wer
error_rate_func = wer
error_rate_info = 'WER'
else:
error_rate_func = cer
error_rate_info = 'CER'
error_sum, num_ins = 0.0, 0 error_sum, num_ins = 0.0, 0
for infer_data in batch_reader(): for infer_data in batch_reader():
result_transcripts = ds2_model.infer_batch( result_transcripts = ds2_model.infer_batch(
@ -171,10 +164,10 @@ def evaluate():
for target, result in zip(target_transcripts, result_transcripts): for target, result in zip(target_transcripts, result_transcripts):
error_sum += error_rate_func(target, result) error_sum += error_rate_func(target, result)
num_ins += 1 num_ins += 1
print("%s (%d/?) = %f" % \ print("Error rate [%s] (%d/?) = %f" %
(error_rate_info, num_ins, error_sum / num_ins)) (args.error_rate_type, num_ins, error_sum / num_ins))
print("Final %s (%d/%d) = %f" % \ print("Final error rate [%s] (%d/%d) = %f" %
(error_rate_info, num_ins, num_ins, error_sum / num_ins)) (args.error_rate_type, num_ins, num_ins, error_sum / num_ins))
def main(): def main():

@ -9,8 +9,7 @@ import multiprocessing
import paddle.v2 as paddle import paddle.v2 as paddle
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model import DeepSpeech2Model from model import DeepSpeech2Model
from error_rate import wer from error_rate import wer, cer
from error_rate import cer
import utils import utils
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
@ -117,8 +116,8 @@ parser.add_argument(
default='wer', default='wer',
choices=['wer', 'cer'], choices=['wer', 'cer'],
type=str, type=str,
help="There are total two error rate types including wer and cer. wer " help="Error rate type for evaluation. 'wer' for word error rate and 'cer' "
"represents for word error rate while cer for character error rate. " "for character error rate. "
"(default: %(default)s)") "(default: %(default)s)")
args = parser.parse_args() args = parser.parse_args()
@ -156,13 +155,7 @@ def infer():
language_model_path=args.language_model_path, language_model_path=args.language_model_path,
num_processes=args.num_processes_beam_search) num_processes=args.num_processes_beam_search)
if args.error_rate_type == 'wer': error_rate_func = cer if args.error_rate_type == 'cer' else wer
error_rate_func = wer
error_rate_info = 'wer'
else:
error_rate_func = cer
error_rate_info = 'cer'
target_transcripts = [ target_transcripts = [
''.join([data_generator.vocab_list[token] for token in transcript]) ''.join([data_generator.vocab_list[token] for token in transcript])
for _, transcript in infer_data for _, transcript in infer_data
@ -170,8 +163,8 @@ def infer():
for target, result in zip(target_transcripts, result_transcripts): for target, result in zip(target_transcripts, result_transcripts):
print("\nTarget Transcription: %s\nOutput Transcription: %s" % print("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result)) (target, result))
print("Current %s = %f" % \ print("Current error rate [%s] = %f" %
(error_rate_info, error_rate_func(target, result))) (args.error_rate_type, error_rate_func(target, result)))
def main(): def main():

Loading…
Cancel
Save