Fix an incorrect usage of is_local argument.

pull/2/head
Xinghai Sun 7 years ago
parent c767f201b2
commit 75719fea22

@ -7,7 +7,7 @@ MEAN_STD_FILE="../mean_std.npz"
CLOUD_DATA_DIR="/pfs/dlnel/home/sunxinghai@baidu.com/deepspeech2/data" CLOUD_DATA_DIR="/pfs/dlnel/home/sunxinghai@baidu.com/deepspeech2/data"
CLOUD_MODEL_DIR="/pfs/dlnel/home/sunxinghai@baidu.com/deepspeech2/model" CLOUD_MODEL_DIR="/pfs/dlnel/home/sunxinghai@baidu.com/deepspeech2/model"
# Configure cloud resources # Configure cloud resources
NUM_CPU=12 NUM_CPU=8
NUM_GPU=8 NUM_GPU=8
NUM_NODE=1 NUM_NODE=1
MEMORY="10Gi" MEMORY="10Gi"

@ -46,6 +46,7 @@ class DeepSpeech2Model(object):
gradient_clipping, gradient_clipping,
num_passes, num_passes,
output_model_dir, output_model_dir,
is_local=True,
num_iterations_print=100): num_iterations_print=100):
"""Train the model. """Train the model.
@ -65,6 +66,8 @@ class DeepSpeech2Model(object):
:param num_iterations_print: Number of training iterations for printing :param num_iterations_print: Number of training iterations for printing
a training loss. a training loss.
:type rnn_iteratons_print: int :type rnn_iteratons_print: int
:param is_local: Set to False if running with pserver with multi-nodes.
:type is_local: bool
:param output_model_dir: Directory for saving the model (every pass). :param output_model_dir: Directory for saving the model (every pass).
:type output_model_dir: basestring :type output_model_dir: basestring
""" """
@ -79,7 +82,8 @@ class DeepSpeech2Model(object):
trainer = paddle.trainer.SGD( trainer = paddle.trainer.SGD(
cost=self._loss, cost=self._loss,
parameters=self._parameters, parameters=self._parameters,
update_equation=optimizer) update_equation=optimizer,
is_local=is_local)
# create event handler # create event handler
def event_handler(event): def event_handler(event):

@ -179,15 +179,13 @@ def train():
gradient_clipping=400, gradient_clipping=400,
num_passes=args.num_passes, num_passes=args.num_passes,
num_iterations_print=args.num_iterations_print, num_iterations_print=args.num_iterations_print,
output_model_dir=args.output_model_dir) output_model_dir=args.output_model_dir,
is_local=args.is_local)
def main(): def main():
utils.print_arguments(args) utils.print_arguments(args)
paddle.init( paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
use_gpu=args.use_gpu,
trainer_count=args.trainer_count,
is_local=args.is_local)
train() train()

Loading…
Cancel
Save