@@ -24,58 +24,57 @@ from utils.utility import print_arguments
 from training.cli import default_argument_parser
 
 from model_utils.config import get_cfg_defaults
-from model_utils.model import DeepSpeech2Trainer as Trainer
+from model_utils.model import DeepSpeech2Tester as Tester
 from utils.error_rate import char_errors, word_errors
 
 
-# def evaluate():
-#     """Evaluate on whole test data for DeepSpeech2."""
-#     # decoders only accept string encoded in utf-8
-#     vocab_list = [chars for chars in data_generator.vocab_list]
-#     errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
-#     errors_sum, len_refs, num_ins = 0.0, 0, 0
-#     ds2_model.logger.info("start evaluation ...")
-#     for infer_data in batch_reader():
-#         probs_split = ds2_model.infer_batch_probs(
-#             infer_data=infer_data, feeding_dict=data_generator.feeding)
-#         if args.decoding_method == "ctc_greedy":
-#             result_transcripts = ds2_model.decode_batch_greedy(
-#                 probs_split=probs_split, vocab_list=vocab_list)
-#         else:
-#             result_transcripts = ds2_model.decode_batch_beam_search(
-#                 probs_split=probs_split,
-#                 beam_alpha=args.alpha,
-#                 beam_beta=args.beta,
-#                 beam_size=args.beam_size,
-#                 cutoff_prob=args.cutoff_prob,
-#                 cutoff_top_n=args.cutoff_top_n,
-#                 vocab_list=vocab_list,
-#                 num_processes=args.num_proc_bsearch)
-#         target_transcripts = infer_data[1]
-#         for target, result in zip(target_transcripts, result_transcripts):
-#             errors, len_ref = errors_func(target, result)
-#             errors_sum += errors
-#             len_refs += len_ref
-#             num_ins += 1
-#         print("Error rate [%s] (%d/?) = %f" %
-#               (args.error_rate_type, num_ins, errors_sum / len_refs))
-#     print("Final error rate [%s] (%d/%d) = %f" %
-#           (args.error_rate_type, num_ins, num_ins, errors_sum / len_refs))
-#     ds2_model.logger.info("finish evaluation")
+def evaluate():
+    """Evaluate on whole test data for DeepSpeech2."""
+    # decoders only accept string encoded in utf-8
+    vocab_list = [chars for chars in data_generator.vocab_list]
+    errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
+    errors_sum, len_refs, num_ins = 0.0, 0, 0
+    ds2_model.logger.info("start evaluation ...")
+    for infer_data in batch_reader():
+        probs_split = ds2_model.infer_batch_probs(
+            infer_data=infer_data, feeding_dict=data_generator.feeding)
+        if args.decoding_method == "ctc_greedy":
+            result_transcripts = ds2_model.decode_batch_greedy(
+                probs_split=probs_split, vocab_list=vocab_list)
+        else:
+            result_transcripts = ds2_model.decode_batch_beam_search(
+                probs_split=probs_split,
+                beam_alpha=args.alpha,
+                beam_beta=args.beta,
+                beam_size=args.beam_size,
+                cutoff_prob=args.cutoff_prob,
+                cutoff_top_n=args.cutoff_top_n,
+                vocab_list=vocab_list,
+                num_processes=args.num_proc_bsearch)
+        target_transcripts = infer_data[1]
+        for target, result in zip(target_transcripts, result_transcripts):
+            errors, len_ref = errors_func(target, result)
+            errors_sum += errors
+            len_refs += len_ref
+            num_ins += 1
+        print("Error rate [%s] (%d/?) = %f" %
+              (args.error_rate_type, num_ins, errors_sum / len_refs))
+    print("Final error rate [%s] (%d/%d) = %f" %
+          (args.error_rate_type, num_ins, num_ins, errors_sum / len_refs))
+    ds2_model.logger.info("finish evaluation")
 
 
 def main_sp(config, args):
-    exp = Trainer(config, args)
+    exp = Tester(config, args)
     exp.setup()
-    exp.run()
+    exp.run_test()
 
 
 def main(config, args):
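For readers following the evaluate() body in the hunk above, the logic rests on two pieces: decoding per-frame probabilities into a transcript, and corpus-level error accumulation, where per-utterance edit distances and reference lengths are summed and the ratio errors_sum / len_refs is reported. The sketch below is a minimal, self-contained illustration of that pattern only; greedy_ctc_decode, edit_distance, and the toy batch are invented stand-ins for this example and are not the repo's decode_batch_greedy or char_errors/word_errors helpers.

# Minimal sketch of the evaluation pattern in evaluate() above.
# greedy_ctc_decode and edit_distance are illustrative helpers, not the
# repository's decode_batch_greedy / char_errors functions.
import numpy as np


def greedy_ctc_decode(probs, vocab_list, blank_id=0):
    """Greedy CTC decoding: argmax per frame, collapse repeats, drop blanks."""
    best_path = np.argmax(probs, axis=1)
    tokens = []
    prev = blank_id
    for idx in best_path:
        if idx != blank_id and idx != prev:
            tokens.append(vocab_list[idx - 1])  # vocab_list excludes the blank symbol
        prev = idx
    return "".join(tokens)


def edit_distance(ref, hyp):
    """Levenshtein distance between two character sequences (one-row DP)."""
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (r != h))
    return dp[len(hyp)]


# Corpus-level CER: sum edit distances and reference lengths, divide once.
errors_sum, len_refs, num_ins = 0.0, 0, 0
vocab_list = list("abcdefgh ")
rng = np.random.default_rng(0)
# One toy utterance: (reference text, per-frame scores); argmax does not
# require the scores to be normalized probabilities.
toy_batch = [("a bad cab", rng.random((20, len(vocab_list) + 1)))]

for target, probs in toy_batch:
    result = greedy_ctc_decode(probs, vocab_list)
    errors = edit_distance(target, result)
    errors_sum += errors
    len_refs += len(target)
    num_ins += 1
    print("Error rate [cer] (%d/?) = %f" % (num_ins, errors_sum / len_refs))

print("Final error rate [cer] (%d/%d) = %f" %
      (num_ins, num_ins, errors_sum / len_refs))

Summing edit distances and reference lengths before dividing gives a corpus-level rate; averaging per-utterance rates instead would weight short utterances disproportionately.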