From 75719fea22677d46b44fce1aa0beb05dae377ccb Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Mon, 14 Aug 2017 20:21:09 +0800 Subject: [PATCH] Fix an incorrect usage of is_local argument. --- cloud/pcloud_submit.sh | 2 +- model.py | 6 +++++- train.py | 8 +++----- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/cloud/pcloud_submit.sh b/cloud/pcloud_submit.sh index 2fb80d66..3a64f32e 100644 --- a/cloud/pcloud_submit.sh +++ b/cloud/pcloud_submit.sh @@ -7,7 +7,7 @@ MEAN_STD_FILE="../mean_std.npz" CLOUD_DATA_DIR="/pfs/dlnel/home/sunxinghai@baidu.com/deepspeech2/data" CLOUD_MODEL_DIR="/pfs/dlnel/home/sunxinghai@baidu.com/deepspeech2/model" # Configure cloud resources -NUM_CPU=12 +NUM_CPU=8 NUM_GPU=8 NUM_NODE=1 MEMORY="10Gi" diff --git a/model.py b/model.py index e2f2903b..99412e59 100644 --- a/model.py +++ b/model.py @@ -46,6 +46,7 @@ class DeepSpeech2Model(object): gradient_clipping, num_passes, output_model_dir, + is_local=True, num_iterations_print=100): """Train the model. @@ -65,6 +66,8 @@ class DeepSpeech2Model(object): :param num_iterations_print: Number of training iterations for printing a training loss. :type rnn_iteratons_print: int + :param is_local: Set to False if running with pserver with multi-nodes. + :type is_local: bool :param output_model_dir: Directory for saving the model (every pass). :type output_model_dir: basestring """ @@ -79,7 +82,8 @@ class DeepSpeech2Model(object): trainer = paddle.trainer.SGD( cost=self._loss, parameters=self._parameters, - update_equation=optimizer) + update_equation=optimizer, + is_local=is_local) # create event handler def event_handler(event): diff --git a/train.py b/train.py index 379e364c..262d8bf0 100644 --- a/train.py +++ b/train.py @@ -179,15 +179,13 @@ def train(): gradient_clipping=400, num_passes=args.num_passes, num_iterations_print=args.num_iterations_print, - output_model_dir=args.output_model_dir) + output_model_dir=args.output_model_dir, + is_local=args.is_local) def main(): utils.print_arguments(args) - paddle.init( - use_gpu=args.use_gpu, - trainer_count=args.trainer_count, - is_local=args.is_local) + paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) train()