diff --git a/train.py b/train.py index 14c7cf63..89ab23c6 100644 --- a/train.py +++ b/train.py @@ -81,9 +81,11 @@ parser.add_argument( help="Vocabulary filepath. (default: %(default)s)") parser.add_argument( "--init_model_path", - default='models/params.tar.gz', + default=None, type=str, - help="Model path for initialization. (default: %(default)s)") + help="If set None, the training will start from scratch. " + "Otherwise, the training will resume from " + "the existing model of this path. (default: %(default)s)") args = parser.parse_args() @@ -124,7 +126,8 @@ def train(): if args.init_model_path is None: parameters = paddle.parameters.create(cost) else: - assert os.path.isfile(args.init_model_path), "Invalid model." + if not os.path.isfile(args.init_model_path): + raise IOError("Invalid model!") parameters = paddle.parameters.Parameters.from_tar( gzip.open(args.init_model_path)) optimizer = paddle.optimizer.Adam(