@ -18,7 +18,7 @@ import time
# Command-line interface for the inference script; the module docstring is
# reused as the program description shown by --help.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--num_samples",
    default=32,
    type=int,
    help="Number of samples for inference. (default: %(default)s)")
parser . add_argument (
parser . add_argument (
@ -46,6 +46,11 @@ parser.add_argument(
default = multiprocessing . cpu_count ( ) ,
default = multiprocessing . cpu_count ( ) ,
type = int ,
type = int ,
help = " Number of cpu threads for preprocessing data. (default: %(default)s ) " )
help = " Number of cpu threads for preprocessing data. (default: %(default)s ) " )
# Beam search decoding is CPU-bound, so default to one worker process per
# available core.
parser.add_argument(
    "--num_processes_beam_search",
    default=multiprocessing.cpu_count(),
    type=int,
    help="Number of cpu processes for beam search. (default: %(default)s)")
parser . add_argument (
parser . add_argument (
" --mean_std_filepath " ,
" --mean_std_filepath " ,
default = ' mean_std.npz ' ,
default = ' mean_std.npz ' ,
@ -70,8 +75,8 @@ parser.add_argument(
" --decode_method " ,
" --decode_method " ,
default = ' beam_search ' ,
default = ' beam_search ' ,
type = str ,
type = str ,
help = " Method for ctc decoding: be st_path or beam_search. (default: %(default)s ) "
help = " Method for ctc decoding: be am_search or beam_search_batch. "
)
" (default: %(default)s ) " )
parser . add_argument (
parser . add_argument (
" --beam_size " ,
" --beam_size " ,
default = 200 ,
default = 200 ,
@ -169,15 +174,28 @@ def infer():
## decode and print
## decode and print
time_begin = time . time ( )
time_begin = time . time ( )
wer_sum , wer_counter = 0 , 0
wer_sum , wer_counter = 0 , 0
for i , probs in enumerate ( probs_split ) :
batch_beam_results = [ ]
beam_result = ctc_beam_search_decoder (
if args . decode_method == ' beam_search ' :
probs_seq = probs ,
for i , probs in enumerate ( probs_split ) :
beam_result = ctc_beam_search_decoder (
probs_seq = probs ,
beam_size = args . beam_size ,
vocabulary = data_generator . vocab_list ,
blank_id = len ( data_generator . vocab_list ) ,
cutoff_prob = args . cutoff_prob ,
ext_scoring_func = ext_scorer , )
batch_beam_results + = [ beam_result ]
else :
batch_beam_results = ctc_beam_search_decoder_batch (
probs_split = probs_split ,
beam_size = args . beam_size ,
beam_size = args . beam_size ,
vocabulary = data_generator . vocab_list ,
vocabulary = data_generator . vocab_list ,
blank_id = len ( data_generator . vocab_list ) ,
blank_id = len ( data_generator . vocab_list ) ,
num_processes = args . num_processes_beam_search ,
cutoff_prob = args . cutoff_prob ,
cutoff_prob = args . cutoff_prob ,
ext_scoring_func = ext_scorer , )
ext_scoring_func = ext_scorer , )
for i , beam_result in enumerate ( batch_beam_results ) :
print ( " \n Target Transcription: \t %s " % target_transcription [ i ] )
print ( " \n Target Transcription: \t %s " % target_transcription [ i ] )
print ( " Beam %d : %f \t %s " % ( 0 , beam_result [ 0 ] [ 0 ] , beam_result [ 0 ] [ 1 ] ) )
print ( " Beam %d : %f \t %s " % ( 0 , beam_result [ 0 ] [ 0 ] , beam_result [ 0 ] [ 1 ] ) )
wer_cur = wer ( target_transcription [ i ] , beam_result [ 0 ] [ 1 ] )
wer_cur = wer ( target_transcription [ i ] , beam_result [ 0 ] [ 1 ] )
@ -185,6 +203,7 @@ def infer():
wer_counter + = 1
wer_counter + = 1
print ( " cur wer = %f , average wer = %f " %
print ( " cur wer = %f , average wer = %f " %
( wer_cur , wer_sum / wer_counter ) )
( wer_cur , wer_sum / wer_counter ) )
time_end = time . time ( )
time_end = time . time ( )
print ( " total time = %f " % ( time_end - time_begin ) )
print ( " total time = %f " % ( time_end - time_begin ) )