Simplify description and codes.

pull/2/head
yangyaming 7 years ago
parent 5ef300f3f0
commit 4b3f768df7

@ -9,8 +9,7 @@ import multiprocessing
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import DeepSpeech2Model
from error_rate import wer
from error_rate import cer
from error_rate import wer, cer
import utils
parser = argparse.ArgumentParser(description=__doc__)
@ -117,8 +116,8 @@ parser.add_argument(
default='wer',
choices=['wer', 'cer'],
type=str,
help="There are total two error rate types including wer and cer. wer "
"represents for word error rate while cer for character error rate. "
help="Error rate type for evaluation. 'wer' for word error rate and 'cer' "
"for character error rate. "
"(default: %(default)s)")
args = parser.parse_args()
@ -145,13 +144,7 @@ def evaluate():
rnn_layer_size=args.rnn_layer_size,
pretrained_model_path=args.model_filepath)
if args.error_rate_type == 'wer':
error_rate_func = wer
error_rate_info = 'WER'
else:
error_rate_func = cer
error_rate_info = 'CER'
error_rate_func = cer if args.error_rate_type == 'cer' else wer
error_sum, num_ins = 0.0, 0
for infer_data in batch_reader():
result_transcripts = ds2_model.infer_batch(
@ -171,10 +164,10 @@ def evaluate():
for target, result in zip(target_transcripts, result_transcripts):
error_sum += error_rate_func(target, result)
num_ins += 1
print("%s (%d/?) = %f" % \
(error_rate_info, num_ins, error_sum / num_ins))
print("Final %s (%d/%d) = %f" % \
(error_rate_info, num_ins, num_ins, error_sum / num_ins))
print("Error rate [%s] (%d/?) = %f" %
(args.error_rate_type, num_ins, error_sum / num_ins))
print("Final error rate [%s] (%d/%d) = %f" %
(args.error_rate_type, num_ins, num_ins, error_sum / num_ins))
def main():

@ -9,8 +9,7 @@ import multiprocessing
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import DeepSpeech2Model
from error_rate import wer
from error_rate import cer
from error_rate import wer, cer
import utils
parser = argparse.ArgumentParser(description=__doc__)
@ -117,8 +116,8 @@ parser.add_argument(
default='wer',
choices=['wer', 'cer'],
type=str,
help="There are total two error rate types including wer and cer. wer "
"represents for word error rate while cer for character error rate. "
help="Error rate type for evaluation. 'wer' for word error rate and 'cer' "
"for character error rate. "
"(default: %(default)s)")
args = parser.parse_args()
@ -156,13 +155,7 @@ def infer():
language_model_path=args.language_model_path,
num_processes=args.num_processes_beam_search)
if args.error_rate_type == 'wer':
error_rate_func = wer
error_rate_info = 'wer'
else:
error_rate_func = cer
error_rate_info = 'cer'
error_rate_func = cer if args.error_rate_type == 'cer' else wer
target_transcripts = [
''.join([data_generator.vocab_list[token] for token in transcript])
for _, transcript in infer_data
@ -170,8 +163,8 @@ def infer():
for target, result in zip(target_transcripts, result_transcripts):
print("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
print("Current %s = %f" % \
(error_rate_info, error_rate_func(target, result)))
print("Current error rate [%s] = %f" %
(args.error_rate_type, error_rate_func(target, result)))
def main():

Loading…
Cancel
Save