|
|
@ -46,6 +46,7 @@ class DeepSpeech2Model(object):
|
|
|
|
gradient_clipping,
|
|
|
|
gradient_clipping,
|
|
|
|
num_passes,
|
|
|
|
num_passes,
|
|
|
|
output_model_dir,
|
|
|
|
output_model_dir,
|
|
|
|
|
|
|
|
is_local=True,
|
|
|
|
num_iterations_print=100):
|
|
|
|
num_iterations_print=100):
|
|
|
|
"""Train the model.
|
|
|
|
"""Train the model.
|
|
|
|
|
|
|
|
|
|
|
@ -65,6 +66,8 @@ class DeepSpeech2Model(object):
|
|
|
|
:param num_iterations_print: Number of training iterations for printing
|
|
|
|
:param num_iterations_print: Number of training iterations for printing
|
|
|
|
a training loss.
|
|
|
|
a training loss.
|
|
|
|
:type rnn_iteratons_print: int
|
|
|
|
:type rnn_iteratons_print: int
|
|
|
|
|
|
|
|
:param is_local: Set to False if running with pserver with multi-nodes.
|
|
|
|
|
|
|
|
:type is_local: bool
|
|
|
|
:param output_model_dir: Directory for saving the model (every pass).
|
|
|
|
:param output_model_dir: Directory for saving the model (every pass).
|
|
|
|
:type output_model_dir: basestring
|
|
|
|
:type output_model_dir: basestring
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -79,7 +82,8 @@ class DeepSpeech2Model(object):
|
|
|
|
trainer = paddle.trainer.SGD(
|
|
|
|
trainer = paddle.trainer.SGD(
|
|
|
|
cost=self._loss,
|
|
|
|
cost=self._loss,
|
|
|
|
parameters=self._parameters,
|
|
|
|
parameters=self._parameters,
|
|
|
|
update_equation=optimizer)
|
|
|
|
update_equation=optimizer,
|
|
|
|
|
|
|
|
is_local=is_local)
|
|
|
|
|
|
|
|
|
|
|
|
# create event handler
|
|
|
|
# create event handler
|
|
|
|
def event_handler(event):
|
|
|
|
def event_handler(event):
|
|
|
|