|
|
|
@ -17,15 +17,15 @@ add_arg = functools.partial(add_arguments, argparser=parser)
|
|
|
|
|
add_arg('batch_size', int, 128, "Minibatch size.")
|
|
|
|
|
add_arg('trainer_count', int, 8, "# of Trainers (CPUs or GPUs).")
|
|
|
|
|
add_arg('beam_size', int, 500, "Beam search width.")
|
|
|
|
|
add_arg('parallels_bsearch',int, 12, "# of CPUs for beam search.")
|
|
|
|
|
add_arg('parallels_data', int, 12, "# of CPUs for data preprocessing.")
|
|
|
|
|
add_arg('num_proc_bsearch', int, 12, "# of CPUs for beam search.")
|
|
|
|
|
add_arg('num_proc_data', int, 12, "# of CPUs for data preprocessing.")
|
|
|
|
|
add_arg('num_conv_layers', int, 2, "# of convolution layers.")
|
|
|
|
|
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
|
|
|
|
|
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
|
|
|
|
|
add_arg('alpha', float, 0.36, "Coef of LM for beam search.")
|
|
|
|
|
add_arg('beta', float, 0.25, "Coef of WC for beam search.")
|
|
|
|
|
add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.")
|
|
|
|
|
add_arg('use_gru', bool, False, "Use GRUs instead of Simple RNNs.")
|
|
|
|
|
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
|
|
|
|
|
add_arg('use_gpu', bool, True, "Use GPU or not.")
|
|
|
|
|
add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
|
|
|
|
|
"bi-directional RNNs. Not for GRU.")
|
|
|
|
@ -45,9 +45,9 @@ add_arg('model_path', str,
|
|
|
|
|
add_arg('lang_model_path', str,
|
|
|
|
|
'lm/data/common_crawl_00.prune01111.trie.klm',
|
|
|
|
|
"Filepath for language model.")
|
|
|
|
|
add_arg('decoder_method', str,
|
|
|
|
|
add_arg('decoding_method', str,
|
|
|
|
|
'ctc_beam_search',
|
|
|
|
|
"Decoder method. Options: ctc_beam_search, ctc_greedy",
|
|
|
|
|
"Decoding method. Options: ctc_beam_search, ctc_greedy",
|
|
|
|
|
choices = ['ctc_beam_search', 'ctc_greedy'])
|
|
|
|
|
add_arg('error_rate_type', str,
|
|
|
|
|
'wer',
|
|
|
|
@ -68,7 +68,7 @@ def evaluate():
|
|
|
|
|
mean_std_filepath=args.mean_std_path,
|
|
|
|
|
augmentation_config='{}',
|
|
|
|
|
specgram_type=args.specgram_type,
|
|
|
|
|
num_threads=args.parallels_data)
|
|
|
|
|
num_threads=args.num_proc_data)
|
|
|
|
|
batch_reader = data_generator.batch_reader_creator(
|
|
|
|
|
manifest_path=args.test_manifest,
|
|
|
|
|
batch_size=args.batch_size,
|
|
|
|
@ -90,14 +90,14 @@ def evaluate():
|
|
|
|
|
for infer_data in batch_reader():
|
|
|
|
|
result_transcripts = ds2_model.infer_batch(
|
|
|
|
|
infer_data=infer_data,
|
|
|
|
|
decoder_method=args.decoder_method,
|
|
|
|
|
decoding_method=args.decoding_method,
|
|
|
|
|
beam_alpha=args.alpha,
|
|
|
|
|
beam_beta=args.beta,
|
|
|
|
|
beam_size=args.beam_size,
|
|
|
|
|
cutoff_prob=args.cutoff_prob,
|
|
|
|
|
vocab_list=data_generator.vocab_list,
|
|
|
|
|
language_model_path=args.lang_model_path,
|
|
|
|
|
num_processes=args.parallels_bsearch)
|
|
|
|
|
num_processes=args.num_proc_bsearch)
|
|
|
|
|
target_transcripts = [
|
|
|
|
|
''.join([data_generator.vocab_list[token] for token in transcript])
|
|
|
|
|
for _, transcript in infer_data
|
|
|
|
|