Make the error rate type selectable (WER or CER) via an optional --error_rate_type argument.

pull/2/head
yangyaming 7 years ago
parent 7e39debcb0
commit 5ef300f3f0

@ -10,6 +10,7 @@ import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import DeepSpeech2Model
from error_rate import wer
from error_rate import cer
import utils
parser = argparse.ArgumentParser(description=__doc__)
@ -111,6 +112,14 @@ parser.add_argument(
default='datasets/vocab/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
"--error_rate_type",
default='wer',
choices=['wer', 'cer'],
type=str,
help="There are total two error rate types including wer and cer. wer "
"represents for word error rate while cer for character error rate. "
"(default: %(default)s)")
args = parser.parse_args()
@ -136,7 +145,14 @@ def evaluate():
rnn_layer_size=args.rnn_layer_size,
pretrained_model_path=args.model_filepath)
wer_sum, num_ins = 0.0, 0
if args.error_rate_type == 'wer':
error_rate_func = wer
error_rate_info = 'WER'
else:
error_rate_func = cer
error_rate_info = 'CER'
error_sum, num_ins = 0.0, 0
for infer_data in batch_reader():
result_transcripts = ds2_model.infer_batch(
infer_data=infer_data,
@ -153,10 +169,12 @@ def evaluate():
for _, transcript in infer_data
]
for target, result in zip(target_transcripts, result_transcripts):
wer_sum += wer(target, result)
error_sum += error_rate_func(target, result)
num_ins += 1
print("WER (%d/?) = %f" % (num_ins, wer_sum / num_ins))
print("Final WER (%d/%d) = %f" % (num_ins, num_ins, wer_sum / num_ins))
print("%s (%d/?) = %f" % \
(error_rate_info, num_ins, error_sum / num_ins))
print("Final %s (%d/%d) = %f" % \
(error_rate_info, num_ins, num_ins, error_sum / num_ins))
def main():

@ -10,6 +10,7 @@ import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import DeepSpeech2Model
from error_rate import wer
from error_rate import cer
import utils
parser = argparse.ArgumentParser(description=__doc__)
@ -111,6 +112,14 @@ parser.add_argument(
type=float,
help="The cutoff probability of pruning"
"in beam search. (default: %(default)f)")
parser.add_argument(
"--error_rate_type",
default='wer',
choices=['wer', 'cer'],
type=str,
help="There are total two error rate types including wer and cer. wer "
"represents for word error rate while cer for character error rate. "
"(default: %(default)s)")
args = parser.parse_args()
@ -147,6 +156,13 @@ def infer():
language_model_path=args.language_model_path,
num_processes=args.num_processes_beam_search)
if args.error_rate_type == 'wer':
error_rate_func = wer
error_rate_info = 'wer'
else:
error_rate_func = cer
error_rate_info = 'cer'
target_transcripts = [
''.join([data_generator.vocab_list[token] for token in transcript])
for _, transcript in infer_data
@ -154,7 +170,8 @@ def infer():
for target, result in zip(target_transcripts, result_transcripts):
print("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
print("Current wer = %f" % wer(target, result))
print("Current %s = %f" % \
(error_rate_info, error_rate_func(target, result)))
def main():

@ -185,7 +185,7 @@ class DeepSpeech2Model(object):
# best path decode
for i, probs in enumerate(probs_split):
output_transcription = ctc_best_path_decoder(
probs_seq=probs, vocabulary=data_generator.vocab_list)
probs_seq=probs, vocabulary=vocab_list)
results.append(output_transcription)
elif decode_method == "beam_search":
# initialize external scorer

Loading…
Cancel
Save