Fix an incorrect usage of is_local argument.

pull/2/head
Xinghai Sun 7 years ago
parent c767f201b2
commit 75719fea22

@ -7,7 +7,7 @@ MEAN_STD_FILE="../mean_std.npz"
CLOUD_DATA_DIR="/pfs/dlnel/home/sunxinghai@baidu.com/deepspeech2/data"
CLOUD_MODEL_DIR="/pfs/dlnel/home/sunxinghai@baidu.com/deepspeech2/model"
# Configure cloud resources
NUM_CPU=12
NUM_CPU=8
NUM_GPU=8
NUM_NODE=1
MEMORY="10Gi"

@ -46,6 +46,7 @@ class DeepSpeech2Model(object):
gradient_clipping,
num_passes,
output_model_dir,
is_local=True,
num_iterations_print=100):
"""Train the model.
@ -65,6 +66,8 @@ class DeepSpeech2Model(object):
:param num_iterations_print: Number of training iterations for printing
a training loss.
:type rnn_iteratons_print: int
:param is_local: Set to False if running with pserver with multi-nodes.
:type is_local: bool
:param output_model_dir: Directory for saving the model (every pass).
:type output_model_dir: basestring
"""
@ -79,7 +82,8 @@ class DeepSpeech2Model(object):
trainer = paddle.trainer.SGD(
cost=self._loss,
parameters=self._parameters,
update_equation=optimizer)
update_equation=optimizer,
is_local=is_local)
# create event handler
def event_handler(event):

@ -179,15 +179,13 @@ def train():
gradient_clipping=400,
num_passes=args.num_passes,
num_iterations_print=args.num_iterations_print,
output_model_dir=args.output_model_dir)
output_model_dir=args.output_model_dir,
is_local=args.is_local)
def main():
utils.print_arguments(args)
paddle.init(
use_gpu=args.use_gpu,
trainer_count=args.trainer_count,
is_local=args.is_local)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
train()

Loading…
Cancel
Save