|
|
@ -9,8 +9,7 @@ import multiprocessing
|
|
|
|
import paddle.v2 as paddle
|
|
|
|
import paddle.v2 as paddle
|
|
|
|
from data_utils.data import DataGenerator
|
|
|
|
from data_utils.data import DataGenerator
|
|
|
|
from model import DeepSpeech2Model
|
|
|
|
from model import DeepSpeech2Model
|
|
|
|
from error_rate import wer
|
|
|
|
from error_rate import wer, cer
|
|
|
|
from error_rate import cer
|
|
|
|
|
|
|
|
import utils
|
|
|
|
import utils
|
|
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
|
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
|
@ -117,8 +116,8 @@ parser.add_argument(
|
|
|
|
default='wer',
|
|
|
|
default='wer',
|
|
|
|
choices=['wer', 'cer'],
|
|
|
|
choices=['wer', 'cer'],
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
|
help="There are total two error rate types including wer and cer. wer "
|
|
|
|
help="Error rate type for evaluation. 'wer' for word error rate and 'cer' "
|
|
|
|
"represents for word error rate while cer for character error rate. "
|
|
|
|
"for character error rate. "
|
|
|
|
"(default: %(default)s)")
|
|
|
|
"(default: %(default)s)")
|
|
|
|
args = parser.parse_args()
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
@ -145,13 +144,7 @@ def evaluate():
|
|
|
|
rnn_layer_size=args.rnn_layer_size,
|
|
|
|
rnn_layer_size=args.rnn_layer_size,
|
|
|
|
pretrained_model_path=args.model_filepath)
|
|
|
|
pretrained_model_path=args.model_filepath)
|
|
|
|
|
|
|
|
|
|
|
|
if args.error_rate_type == 'wer':
|
|
|
|
error_rate_func = cer if args.error_rate_type == 'cer' else wer
|
|
|
|
error_rate_func = wer
|
|
|
|
|
|
|
|
error_rate_info = 'WER'
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
error_rate_func = cer
|
|
|
|
|
|
|
|
error_rate_info = 'CER'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
error_sum, num_ins = 0.0, 0
|
|
|
|
error_sum, num_ins = 0.0, 0
|
|
|
|
for infer_data in batch_reader():
|
|
|
|
for infer_data in batch_reader():
|
|
|
|
result_transcripts = ds2_model.infer_batch(
|
|
|
|
result_transcripts = ds2_model.infer_batch(
|
|
|
@ -171,10 +164,10 @@ def evaluate():
|
|
|
|
for target, result in zip(target_transcripts, result_transcripts):
|
|
|
|
for target, result in zip(target_transcripts, result_transcripts):
|
|
|
|
error_sum += error_rate_func(target, result)
|
|
|
|
error_sum += error_rate_func(target, result)
|
|
|
|
num_ins += 1
|
|
|
|
num_ins += 1
|
|
|
|
print("%s (%d/?) = %f" % \
|
|
|
|
print("Error rate [%s] (%d/?) = %f" %
|
|
|
|
(error_rate_info, num_ins, error_sum / num_ins))
|
|
|
|
(args.error_rate_type, num_ins, error_sum / num_ins))
|
|
|
|
print("Final %s (%d/%d) = %f" % \
|
|
|
|
print("Final error rate [%s] (%d/%d) = %f" %
|
|
|
|
(error_rate_info, num_ins, num_ins, error_sum / num_ins))
|
|
|
|
(args.error_rate_type, num_ins, num_ins, error_sum / num_ins))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
def main():
|
|
|
|