Add model loading to train.py.

pull/2/head
yangyaming 8 years ago
parent 730d5c4dd3
commit d2e467385d

@@ -11,6 +11,7 @@ import sys
 from model import deep_speech2
 from audio_data_utils import DataGenerator
 import numpy as np
+import os
 #TODO: add WER metric
@@ -78,6 +79,11 @@ parser.add_argument(
     default='data/eng_vocab.txt',
     type=str,
     help="Vocabulary filepath. (default: %(default)s)")
+parser.add_argument(
+    "--init_model_path",
+    default='models/params.tar.gz',
+    type=str,
+    help="Model path for initialization. (default: %(default)s)")
 args = parser.parse_args()
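
Note that the new flag defaults to a concrete file path rather than None, so running train.py without --init_model_path will attempt to load 'models/params.tar.gz' instead of creating fresh parameters. A minimal standalone sketch of that parsing behaviour (illustrative only, not code from the repository):

# Hypothetical, self-contained sketch of how the new argument parses.
import argparse

parser = argparse.ArgumentParser(description='sketch of the new train.py flag')
parser.add_argument(
    "--init_model_path",
    default='models/params.tar.gz',
    type=str,
    help="Model path for initialization. (default: %(default)s)")

# Parsing an empty argv stands in for running train.py with no flags.
args = parser.parse_args([])
print(args.init_model_path)   # -> 'models/params.tar.gz', never None
# So the `if args.init_model_path is None` branch in the next hunk is only
# reached if the default above is changed (e.g. to None) in code.
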
@@ -114,8 +120,13 @@ def train():
         rnn_size=args.rnn_layer_size,
         is_inference=False)
-    # create parameters and optimizer
-    parameters = paddle.parameters.create(cost)
+    # create/load parameters and optimizer
+    if args.init_model_path is None:
+        parameters = paddle.parameters.create(cost)
+    else:
+        assert os.path.isfile(args.init_model_path), "Invalid model."
+        parameters = paddle.parameters.Parameters.from_tar(
+            gzip.open(args.init_model_path))
     optimizer = paddle.optimizer.Adam(
         learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400)
     trainer = paddle.trainer.SGD(
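
The else branch restores parameters from a gzipped tar archive instead of creating them from the cost layer. A minimal sketch of the intended save/load round trip, assuming the paddle.v2 Parameters API the diff relies on (to_tar/from_tar); the helper names are illustrative, not part of the repository:

# Illustrative helpers for the parameter round trip (assumed paddle.v2 API).
import gzip

import paddle.v2 as paddle

def save_parameters(parameters, path='models/params.tar.gz'):
    # Serialize trained parameters into a gzipped tar archive,
    # matching the format expected by --init_model_path.
    with gzip.open(path, 'w') as f:
        parameters.to_tar(f)

def load_parameters(path='models/params.tar.gz'):
    # Restore parameters for continued training, as the new branch does.
    with gzip.open(path, 'rb') as f:
        return paddle.parameters.Parameters.from_tar(f)

Using a with-block (rather than the bare gzip.open(...) call in the diff) also ensures the file handle is closed after loading.
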
