From d2e467385d8367ac072a7d98688466d74661cc4b Mon Sep 17 00:00:00 2001 From: yangyaming Date: Mon, 5 Jun 2017 21:00:15 +0800 Subject: [PATCH] Add loading model function for train.py. --- train.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index e6a7d076..14c7cf63 100644 --- a/train.py +++ b/train.py @@ -11,6 +11,7 @@ import sys from model import deep_speech2 from audio_data_utils import DataGenerator import numpy as np +import os #TODO: add WER metric @@ -78,6 +79,11 @@ parser.add_argument( default='data/eng_vocab.txt', type=str, help="Vocabulary filepath. (default: %(default)s)") +parser.add_argument( + "--init_model_path", + default='models/params.tar.gz', + type=str, + help="Model path for initialization. (default: %(default)s)") args = parser.parse_args() @@ -114,8 +120,13 @@ def train(): rnn_size=args.rnn_layer_size, is_inference=False) - # create parameters and optimizer - parameters = paddle.parameters.create(cost) + # create/load parameters and optimizer + if args.init_model_path is None: + parameters = paddle.parameters.create(cost) + else: + assert os.path.isfile(args.init_model_path), "Invalid model." + parameters = paddle.parameters.Parameters.from_tar( + gzip.open(args.init_model_path)) optimizer = paddle.optimizer.Adam( learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400) trainer = paddle.trainer.SGD(