@@ -9,6 +9,7 @@ import gzip
 from audio_data_utils import DataGenerator
 from model import deep_speech2
 from decoder import *
+from error_rate import wer
 
 parser = argparse.ArgumentParser(
     description='Simplified version of DeepSpeech2 inference.')
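Note on the new import: assuming error_rate.wer follows the standard definition, it returns the word-level Levenshtein (edit) distance between reference and hypothesis, divided by the number of reference words. A minimal self-contained sketch of that computation, for reviewers unfamiliar with the metric (the real module may differ in tokenization and edge-case handling; this sketch assumes a non-empty reference):

def wer_sketch(reference, hypothesis):
    """Word error rate: word-level edit distance / number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    # dp[i][j] = edit distance between ref[:i] and hyp[:j]
    dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dp[i][0] = i
    for j in range(len(hyp) + 1):
        dp[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,         # deletion
                           dp[i][j - 1] + 1,         # insertion
                           dp[i - 1][j - 1] + cost)  # substitution
    return float(dp[len(ref)][len(hyp)]) / len(ref)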
@@ -59,9 +60,9 @@ parser.add_argument(
     help="Vocabulary filepath. (default: %(default)s)")
 parser.add_argument(
     "--decode_method",
-    default='beam_search',
+    default='beam_search_nproc',
     type=str,
-    help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)"
+    help="Method for ctc decoding, best_path, beam_search or beam_search_nproc. (default: %(default)s)"
 )
 parser.add_argument(
     "--beam_size",
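Since the default decode method silently changes here, a quick hypothetical smoke test of the flag wiring (mirrors the argument defined above; not part of this patch):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--decode_method", default='beam_search_nproc', type=str)
args = parser.parse_args([])                        # no flag given
assert args.decode_method == 'beam_search_nproc'    # new default
args = parser.parse_args(['--decode_method', 'best_path'])
assert args.decode_method == 'best_path'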
@@ -151,6 +152,7 @@ def infer():
 
     ## decode and print
     # best path decode
+    wer_sum, wer_counter = 0, 0
     if args.decode_method == "best_path":
        for i, probs in enumerate(probs_split):
            target_transcription = ''.join(
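One observation on the accumulators introduced here: the reported "average wer" is the unweighted mean of per-utterance WERs, which is not the same as a corpus-level WER weighted by reference length. Illustrative numbers (hypothetical, not from a run):

# Per-utterance mean vs corpus-level WER.
wers = [0.5, 0.0]       # utt 1: 1 error / 2 words; utt 2: 0 errors / 10 words
ref_lens = [2, 10]
mean_wer = sum(wers) / len(wers)                                         # 0.25
corpus_wer = sum(w * n for w, n in zip(wers, ref_lens)) / sum(ref_lens)  # ~0.083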
@@ -159,12 +161,17 @@ def infer():
                 probs_seq=probs, vocabulary=vocab_list)
             print("\nTarget Transcription: %s\nOutput Transcription: %s" %
                   (target_transcription, best_path_transcription))
+            wer_cur = wer(target_transcription, best_path_transcription)
+            wer_sum += wer_cur
+            wer_counter += 1
+            print("cur wer = %f, average wer = %f" %
+                  (wer_cur, wer_sum / wer_counter))
     # beam search decode
     elif args.decode_method == "beam_search":
+        ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
         for i, probs in enumerate(probs_split):
             target_transcription = ''.join(
                 [vocab_list[index] for index in infer_data[i][1]])
-            ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
             beam_search_result = ctc_beam_search_decoder(
                 probs_seq=probs,
                 vocabulary=vocab_list,
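Hoisting the Scorer construction out of the per-utterance loop is the right call: it avoids reloading the language model on every iteration. For context, an external scorer of this shape typically weights a language-model probability by alpha and a word-count bonus by beta, as in the DeepSpeech2 paper. The sketch below shows the interface assumed for ext_scoring_func, with a toy stand-in LM so it runs on its own (the real decoder.Scorer may differ):

class ScorerSketch(object):
    """Illustrative external scorer in the DeepSpeech2 style:
    score(sentence) = p_lm(sentence)**alpha * word_count(sentence)**beta."""

    def __init__(self, alpha, beta, lm_prob_fn):
        self.alpha = alpha
        self.beta = beta
        self.lm_prob_fn = lm_prob_fn  # sentence -> probability in (0, 1]

    def evaluate(self, sentence):
        words = sentence.split()
        return (self.lm_prob_fn(sentence) ** self.alpha) * \
               (len(words) ** self.beta)

# Toy stand-in LM: every word equally likely in a 1000-word vocabulary.
toy_lm = lambda s: (1.0 / 1000) ** max(len(s.split()), 1)
scorer = ScorerSketch(alpha=0.26, beta=0.1, lm_prob_fn=toy_lm)  # illustrative weights
print(scorer.evaluate("hello world"))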
@@ -172,10 +179,40 @@ def infer():
                 ext_scoring_func=ext_scorer.evaluate,
                 blank_id=len(vocab_list))
             print("\nTarget Transcription:\t%s" % target_transcription)
 
             for index in range(args.num_results_per_sample):
                 result = beam_search_result[index]
                 #output: index, log prob, beam result
                 print("Beam %d: %f \t%s" % (index, result[0], result[1]))
+            wer_cur = wer(target_transcription, beam_search_result[0][1])
+            wer_sum += wer_cur
+            wer_counter += 1
+            print("cur wer = %f, average wer = %f" %
+                  (wer_cur, wer_sum / wer_counter))
+    # beam search in multiple processes
+    elif args.decode_method == "beam_search_nproc":
+        ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
+        beam_search_nproc_results = ctc_beam_search_decoder_nproc(
+            probs_split=probs_split,
+            vocabulary=vocab_list,
+            beam_size=args.beam_size,
+            #ext_scoring_func=ext_scorer.evaluate,
+            ext_scoring_func=None,
+            blank_id=len(vocab_list))
+        for i, beam_search_result in enumerate(beam_search_nproc_results):
+            target_transcription = ''.join(
+                [vocab_list[index] for index in infer_data[i][1]])
+            print("\nTarget Transcription:\t%s" % target_transcription)
+
+            for index in range(args.num_results_per_sample):
+                result = beam_search_result[index]
+                #output: index, log prob, beam result
+                print("Beam %d: %f \t%s" % (index, result[0], result[1]))
+            wer_cur = wer(target_transcription, beam_search_result[0][1])
+            wer_sum += wer_cur
+            wer_counter += 1
+            print("cur wer = %f, average wer = %f" %
+                  (wer_cur, wer_sum / wer_counter))
     else:
         raise ValueError("Decoding method [%s] is not supported." % args.decode_method)
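The new beam_search_nproc path hands all utterances to a single call so decoding can fan out across worker processes; that also explains why ext_scoring_func is left as None here, since a bound method holding a language-model handle may not pickle cleanly for workers. A hedged sketch of how such a wrapper could be built on multiprocessing.Pool (ctc_beam_search_decoder_nproc_sketch and _decode_one are illustrative names, not the actual decoder module; ctc_beam_search_decoder is the single-utterance decoder used above):

import multiprocessing

def _decode_one(job):
    # Unpack one utterance's job; runs in a worker process.
    probs_seq, vocabulary, beam_size, blank_id = job
    return ctc_beam_search_decoder(
        probs_seq=probs_seq,
        vocabulary=vocabulary,
        beam_size=beam_size,
        ext_scoring_func=None,  # plain args pickle; LM-bound methods may not
        blank_id=blank_id)

def ctc_beam_search_decoder_nproc_sketch(probs_split, vocabulary, beam_size,
                                         blank_id, num_processes=None):
    """Illustrative multiprocess wrapper: one utterance per pool task,
    results returned in input order."""
    pool = multiprocessing.Pool(processes=num_processes)
    jobs = [(probs, vocabulary, beam_size, blank_id) for probs in probs_split]
    try:
        return pool.map(_decode_one, jobs)
    finally:
        pool.close()
        pool.join()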