|
|
@ -10,6 +10,7 @@ import paddle.v2 as paddle
|
|
|
|
from data_utils.data import DataGenerator
|
|
|
|
from data_utils.data import DataGenerator
|
|
|
|
from model import DeepSpeech2Model
|
|
|
|
from model import DeepSpeech2Model
|
|
|
|
from error_rate import wer
|
|
|
|
from error_rate import wer
|
|
|
|
|
|
|
|
from error_rate import cer
|
|
|
|
import utils
|
|
|
|
import utils
|
|
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
|
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
|
@ -111,6 +112,14 @@ parser.add_argument(
|
|
|
|
default='datasets/vocab/eng_vocab.txt',
|
|
|
|
default='datasets/vocab/eng_vocab.txt',
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
|
help="Vocabulary filepath. (default: %(default)s)")
|
|
|
|
help="Vocabulary filepath. (default: %(default)s)")
|
|
|
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
|
|
|
"--error_rate_type",
|
|
|
|
|
|
|
|
default='wer',
|
|
|
|
|
|
|
|
choices=['wer', 'cer'],
|
|
|
|
|
|
|
|
type=str,
|
|
|
|
|
|
|
|
help="There are total two error rate types including wer and cer. wer "
|
|
|
|
|
|
|
|
"represents for word error rate while cer for character error rate. "
|
|
|
|
|
|
|
|
"(default: %(default)s)")
|
|
|
|
args = parser.parse_args()
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -136,7 +145,14 @@ def evaluate():
|
|
|
|
rnn_layer_size=args.rnn_layer_size,
|
|
|
|
rnn_layer_size=args.rnn_layer_size,
|
|
|
|
pretrained_model_path=args.model_filepath)
|
|
|
|
pretrained_model_path=args.model_filepath)
|
|
|
|
|
|
|
|
|
|
|
|
wer_sum, num_ins = 0.0, 0
|
|
|
|
if args.error_rate_type == 'wer':
|
|
|
|
|
|
|
|
error_rate_func = wer
|
|
|
|
|
|
|
|
error_rate_info = 'WER'
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
error_rate_func = cer
|
|
|
|
|
|
|
|
error_rate_info = 'CER'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
error_sum, num_ins = 0.0, 0
|
|
|
|
for infer_data in batch_reader():
|
|
|
|
for infer_data in batch_reader():
|
|
|
|
result_transcripts = ds2_model.infer_batch(
|
|
|
|
result_transcripts = ds2_model.infer_batch(
|
|
|
|
infer_data=infer_data,
|
|
|
|
infer_data=infer_data,
|
|
|
@ -153,10 +169,12 @@ def evaluate():
|
|
|
|
for _, transcript in infer_data
|
|
|
|
for _, transcript in infer_data
|
|
|
|
]
|
|
|
|
]
|
|
|
|
for target, result in zip(target_transcripts, result_transcripts):
|
|
|
|
for target, result in zip(target_transcripts, result_transcripts):
|
|
|
|
wer_sum += wer(target, result)
|
|
|
|
error_sum += error_rate_func(target, result)
|
|
|
|
num_ins += 1
|
|
|
|
num_ins += 1
|
|
|
|
print("WER (%d/?) = %f" % (num_ins, wer_sum / num_ins))
|
|
|
|
print("%s (%d/?) = %f" % \
|
|
|
|
print("Final WER (%d/%d) = %f" % (num_ins, num_ins, wer_sum / num_ins))
|
|
|
|
(error_rate_info, num_ins, error_sum / num_ins))
|
|
|
|
|
|
|
|
print("Final %s (%d/%d) = %f" % \
|
|
|
|
|
|
|
|
(error_rate_info, num_ins, num_ins, error_sum / num_ins))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
def main():
|
|
|
|