test process can run

pull/522/head
Hui Zhang 5 years ago
parent c18162ca90
commit f6eafe85f1

@@ -37,3 +37,15 @@ training:
   plot_interval: 1000
   save_interval: 1000
   valid_interval: 1000
+
+decoding:
+  alpha: 2.5
+  batch_size: 128
+  beam_size: 500
+  beta: 0.3
+  cutoff_prob: 1.0
+  cutoff_top_n: 40
+  decoding_method: ctc_beam_search
+  error_rate_type: wer
+  lang_model_path: models/lm/common_crawl_00.prune01111.trie.klm
+  num_proc_bsearch: 8

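The new `decoding:` block moves the decoder options that test.sh used to pass as CLI flags into the yaml config. A minimal sketch of how such a block is consumed via the yacs-style helper that test.py imports below (`get_cfg_defaults`); that the registered defaults include a `decoding` node with these keys is an assumption, not something this diff shows:

```python
# Sketch only: assumes get_cfg_defaults() already registers a `decoding`
# node with these keys; merge_from_file() rejects unknown keys otherwise.
from model_utils.config import get_cfg_defaults

config = get_cfg_defaults()
config.merge_from_file("conf/deepspeech2.yaml")  # merges the decoding: block above
config.freeze()

print(config.decoding.decoding_method)  # ctc_beam_search
print(config.decoding.beam_size)        # 500
```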
@@ -10,29 +10,37 @@ cd - > /dev/null
 # evaluate model
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
-python3 -u $MAIN_ROOT/test.py \
---batch_size=128 \
---beam_size=500 \
---num_proc_bsearch=8 \
---num_conv_layers=2 \
---num_rnn_layers=3 \
---rnn_layer_size=2048 \
---alpha=2.5 \
---beta=0.3 \
---cutoff_prob=1.0 \
---cutoff_top_n=40 \
---use_gru=False \
---use_gpu=True \
---share_rnn_weights=True \
---test_manifest="data/manifest.test-clean" \
---mean_std_path="data/mean_std.npz" \
---vocab_path="data/vocab.txt" \
---model_path="checkpoints/step_final" \
---lang_model_path="$MAIN_ROOT/models/lm/common_crawl_00.prune01111.trie.klm" \
---decoding_method="ctc_beam_search" \
---error_rate_type="wer" \
---specgram_type="linear"
+#CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+#python3 -u $MAIN_ROOT/test.py \
+#--batch_size=128 \
+#--beam_size=500 \
+#--num_proc_bsearch=8 \
+#--num_conv_layers=2 \
+#--num_rnn_layers=3 \
+#--rnn_layer_size=2048 \
+#--alpha=2.5 \
+#--beta=0.3 \
+#--cutoff_prob=1.0 \
+#--cutoff_top_n=40 \
+#--use_gru=False \
+#--use_gpu=True \
+#--share_rnn_weights=True \
+#--test_manifest="data/manifest.test-clean" \
+#--mean_std_path="data/mean_std.npz" \
+#--vocab_path="data/vocab.txt" \
+#--model_path="checkpoints/step_final" \
+#--lang_model_path="$MAIN_ROOT/models/lm/common_crawl_00.prune01111.trie.klm" \
+#--decoding_method="ctc_beam_search" \
+#--error_rate_type="wer" \
+#--specgram_type="linear"
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+python3 -u ${MAIN_ROOT}/test.py \
+--device 'gpu' \
+--nproc 1 \
+--config conf/deepspeech2.yaml \
+--output ckpt

 if [ $? -ne 0 ]; then
     echo "Failed in evaluation!"

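The new invocation drops all per-model flags in favor of `--config` plus a handful of runtime options. For this command line to parse, `default_argument_parser()` (imported in test.py further down) has to expose `--device`, `--nproc`, `--config`, and `--output`; the sketch below is a hypothetical version of that parser, since the diff does not show training/cli.py itself:

```python
# Hypothetical sketch of training/cli.py's default_argument_parser; only the
# flag names are taken from the test.sh command above, defaults are guesses.
import argparse

def default_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="gpu", choices=["cpu", "gpu"],
                        help="device string passed to paddle.set_device in setup()")
    parser.add_argument("--nproc", type=int, default=1,
                        help="number of worker processes")
    parser.add_argument("--config", metavar="FILE",
                        help="yaml config, e.g. conf/deepspeech2.yaml")
    parser.add_argument("--output", metavar="DIR",
                        help="output directory for checkpoints and logs")
    return parser
```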
@@ -0,0 +1 @@
+../../models/

@@ -88,7 +88,7 @@ class DeepSpeech2Trainer(Trainer):
                          for k, v in losses_np.items())
         self.logger.info(msg)
-        if dist.get_rank() == 0:
+        if dist.get_rank() == 0 and self.visualizer:
             for k, v in losses_np.items():
                 self.visualizer.add_scalar("train/{}".format(k), v,
                                            self.iteration)
@@ -143,8 +143,10 @@ class DeepSpeech2Trainer(Trainer):
                          for k, v in valid_losses.items())
         self.logger.info(msg)
-        for k, v in valid_losses.items():
-            self.visualizer.add_scalar("valid/{}".format(k), v, self.iteration)
+        if self.visualizer:
+            for k, v in valid_losses.items():
+                self.visualizer.add_scalar("valid/{}".format(k), v,
+                                           self.iteration)

     def setup_model(self):
         config = self.config
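Both hunks above add the same guard for the same reason: the visualizer only exists on rank 0. A sketch of the interaction that makes `self.visualizer` None elsewhere; the VisualDL writer and the helper shapes below are stand-ins for the repo's `setup_visualizer`/`mp_tools` machinery, not code from this diff:

```python
# Sketch: on every rank except 0 the writer is never created, so
# trainer.visualizer keeps the None assigned in Trainer.__init__ (see the
# hunk near the end of this commit) and unguarded calls would raise.
import paddle.distributed as dist
from visualdl import LogWriter

def setup_visualizer(trainer, output_dir):
    # the repo wraps this in @mp_tools.rank_zero_only with the same effect
    if dist.get_rank() == 0:
        trainer.visualizer = LogWriter(logdir=str(output_dir))

def log_losses(trainer, losses_np):
    # same guard as the diff: rank 0 AND a writer that actually exists
    if dist.get_rank() == 0 and trainer.visualizer:
        for k, v in losses_np.items():
            trainer.visualizer.add_scalar("train/{}".format(k), v,
                                          trainer.iteration)
```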
@@ -289,14 +291,37 @@ class DeepSpeech2Tester(Trainer):
         msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses.items())
         self.logger.info(msg)
-        for k, v in losses.items():
-            self.visualizer.add_scalar("test/{}".format(k), v, self.iteration)
+
+    def setup(self):
+        """Setup the experiment.
+        """
+        paddle.set_device(self.args.device)
+        if self.parallel:
+            self.init_parallel()
+
+        self.setup_output_dir()
+        self.setup_logger()
+        self.setup_checkpointer()
+        self.setup_dataloader()
+        self.setup_model()
+
+        self.iteration = 0
+        self.epoch = 0
+
+    def run_test(self):
+        self.resume_or_load()
+        try:
+            self.test()
+        except KeyboardInterrupt:
+            exit(-1)
+        finally:
+            self.destory()
+
     def setup_model(self):
         config = self.config
         model = DeepSpeech2(
-            feat_size=self.train_loader.dataset.feature_size,
-            dict_size=self.train_loader.dataset.vocab_size,
+            feat_size=self.test_loader.dataset.feature_size,
+            dict_size=self.test_loader.dataset.vocab_size,
             num_conv_layers=config.model.num_conv_layers,
             num_rnn_layers=config.model.num_rnn_layers,
             rnn_size=config.model.rnn_layer_size,
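A usage sketch of the Tester entry points added above, matching how `main_sp` (in the test.py hunk further down) drives them:

```python
# Names come from this diff; no extra behavior is assumed.
exp = Tester(config, args)
exp.setup()      # device, output dir, logger, checkpointer, dataloader, model
exp.run_test()   # resume_or_load() then test(); destory() runs even on Ctrl-C
```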
@@ -305,7 +330,7 @@ class DeepSpeech2Tester(Trainer):
         if self.parallel:
             model = paddle.DataParallel(model)

-        criterion = DeepSpeech2Loss(self.train_loader.dataset.vocab_size)
+        criterion = DeepSpeech2Loss(self.test_loader.dataset.vocab_size)

         self.model = model
         self.criterion = criterion
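The same `train_loader` to `test_loader` substitution as in the previous hunk; a brief sketch of the failure mode the rename avoids:

```python
# Sketch: DeepSpeech2Tester.setup_dataloader (next hunk) assigns only
# self.test_loader; self.train_loader is never created in a Tester, so
#     criterion = DeepSpeech2Loss(self.train_loader.dataset.vocab_size)
# would raise AttributeError the first time setup_model runs at test time.
criterion = DeepSpeech2Loss(self.test_loader.dataset.vocab_size)
```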
@@ -331,6 +356,7 @@ class DeepSpeech2Tester(Trainer):
             random_seed=config.data.random_seed,
             keep_transcription_text=False)
+        collate_fn = SpeechCollator()
         self.test_loader = DataLoader(
             test_dataset,
             batch_size=config.data.batch_size,

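A sketch of how the test-loader pieces above fit together; the keyword names follow `paddle.io.DataLoader`, while the `SpeechCollator` import path and the `shuffle`/`drop_last` values are assumptions beyond what the hunk shows:

```python
from paddle.io import DataLoader
from data_utils.dataset import SpeechCollator  # assumed import path

test_loader = DataLoader(
    test_dataset,                       # dataset built from config.data above
    batch_size=config.data.batch_size,
    shuffle=False,                      # assumption: evaluate in manifest order
    drop_last=False,                    # assumption: keep the final short batch
    collate_fn=SpeechCollator())        # pads variable-length audio/text
```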
@@ -24,58 +24,57 @@ from utils.utility import print_arguments
 from training.cli import default_argument_parser
 from model_utils.config import get_cfg_defaults
-from model_utils.model import DeepSpeech2Trainer as Trainer
+from model_utils.model import DeepSpeech2Tester as Tester
 from utils.error_rate import char_errors, word_errors

-# def evaluate():
-#     """Evaluate on whole test data for DeepSpeech2."""
+def evaluate():
+    """Evaluate on whole test data for DeepSpeech2."""
-#     # decoders only accept string encoded in utf-8
-#     vocab_list = [chars for chars in data_generator.vocab_list]
+    # decoders only accept string encoded in utf-8
+    vocab_list = [chars for chars in data_generator.vocab_list]
-#     errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
-#     errors_sum, len_refs, num_ins = 0.0, 0, 0
-#     ds2_model.logger.info("start evaluation ...")
+    errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
+    errors_sum, len_refs, num_ins = 0.0, 0, 0
+    ds2_model.logger.info("start evaluation ...")
-#     for infer_data in batch_reader():
-#         probs_split = ds2_model.infer_batch_probs(
-#             infer_data=infer_data, feeding_dict=data_generator.feeding)
+    for infer_data in batch_reader():
+        probs_split = ds2_model.infer_batch_probs(
+            infer_data=infer_data, feeding_dict=data_generator.feeding)
-#         if args.decoding_method == "ctc_greedy":
-#             result_transcripts = ds2_model.decode_batch_greedy(
-#                 probs_split=probs_split, vocab_list=vocab_list)
-#         else:
-#             result_transcripts = ds2_model.decode_batch_beam_search(
-#                 probs_split=probs_split,
-#                 beam_alpha=args.alpha,
-#                 beam_beta=args.beta,
-#                 beam_size=args.beam_size,
-#                 cutoff_prob=args.cutoff_prob,
-#                 cutoff_top_n=args.cutoff_top_n,
-#                 vocab_list=vocab_list,
-#                 num_processes=args.num_proc_bsearch)
+        if args.decoding_method == "ctc_greedy":
+            result_transcripts = ds2_model.decode_batch_greedy(
+                probs_split=probs_split, vocab_list=vocab_list)
+        else:
+            result_transcripts = ds2_model.decode_batch_beam_search(
+                probs_split=probs_split,
+                beam_alpha=args.alpha,
+                beam_beta=args.beta,
+                beam_size=args.beam_size,
+                cutoff_prob=args.cutoff_prob,
+                cutoff_top_n=args.cutoff_top_n,
+                vocab_list=vocab_list,
+                num_processes=args.num_proc_bsearch)
-#         target_transcripts = infer_data[1]
+        target_transcripts = infer_data[1]
-#         for target, result in zip(target_transcripts, result_transcripts):
-#             errors, len_ref = errors_func(target, result)
-#             errors_sum += errors
-#             len_refs += len_ref
-#             num_ins += 1
-#         print("Error rate [%s] (%d/?) = %f" %
-#               (args.error_rate_type, num_ins, errors_sum / len_refs))
+        for target, result in zip(target_transcripts, result_transcripts):
+            errors, len_ref = errors_func(target, result)
+            errors_sum += errors
+            len_refs += len_ref
+            num_ins += 1
+        print("Error rate [%s] (%d/?) = %f" %
+              (args.error_rate_type, num_ins, errors_sum / len_refs))
-#     print("Final error rate [%s] (%d/%d) = %f" %
-#           (args.error_rate_type, num_ins, num_ins, errors_sum / len_refs))
+    print("Final error rate [%s] (%d/%d) = %f" %
+          (args.error_rate_type, num_ins, num_ins, errors_sum / len_refs))
-    ds2_model.logger.info("finish evaluation")
+    # ds2_model.logger.info("finish evaluation")

 def main_sp(config, args):
-    exp = Trainer(config, args)
+    exp = Tester(config, args)
     exp.setup()
-    exp.run()
+    exp.run_test()

 def main(config, args):

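The restored `evaluate()` only relies on one contract from `utils.error_rate`: `errors_func(target, result)` returns `(edit_distance, reference_length)`, which the loop accumulates into `errors_sum / len_refs`. An illustrative, self-contained implementation of the word-level variant; the repo's own version may differ in details:

```python
# Illustrative only: satisfies the (errors, len_ref) contract evaluate() uses.
def word_errors(reference, hypothesis):
    """Return (edit_distance, number_of_reference_words)."""
    ref, hyp = reference.split(), hypothesis.split()
    # rolling one-row Levenshtein distance over words
    dist = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dist[0] = dist[0], i
        for j, h in enumerate(hyp, 1):
            prev, dist[j] = dist[j], min(
                dist[j] + 1,        # deletion of r
                dist[j - 1] + 1,    # insertion of h
                prev + (r != h))    # substitution (free when words match)
    return float(dist[-1]), len(ref)

# e.g. word_errors("a b", "a x") -> (1.0, 2), i.e. WER 0.5 for this pair;
# corpus-level WER sums both values across pairs, as the loop above does.
```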
@@ -89,6 +89,8 @@ class Trainer():
     def __init__(self, config, args):
         self.config = config
         self.args = args
+        self.optimizer = None
+        self.visualizer = None

     def setup(self):
         """Setup the experiment.
@@ -217,7 +219,8 @@ class Trainer():
     @mp_tools.rank_zero_only
     def destory(self):
         # https://github.com/pytorch/fairseq/issues/2357
-        self.visualizer.close()
+        if self.visualizer:
+            self.visualizer.close()

     @mp_tools.rank_zero_only
     def setup_visualizer(self):
