|
|
|
@ -11,6 +11,7 @@ import sys
|
|
|
|
|
from model import deep_speech2
|
|
|
|
|
from audio_data_utils import DataGenerator
|
|
|
|
|
import numpy as np
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
#TODO: add WER metric
|
|
|
|
|
|
|
|
|
@ -78,6 +79,11 @@ parser.add_argument(
|
|
|
|
|
default='data/eng_vocab.txt',
|
|
|
|
|
type=str,
|
|
|
|
|
help="Vocabulary filepath. (default: %(default)s)")
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--init_model_path",
|
|
|
|
|
default='models/params.tar.gz',
|
|
|
|
|
type=str,
|
|
|
|
|
help="Model path for initialization. (default: %(default)s)")
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -114,8 +120,13 @@ def train():
|
|
|
|
|
rnn_size=args.rnn_layer_size,
|
|
|
|
|
is_inference=False)
|
|
|
|
|
|
|
|
|
|
# create parameters and optimizer
|
|
|
|
|
parameters = paddle.parameters.create(cost)
|
|
|
|
|
# create/load parameters and optimizer
|
|
|
|
|
if args.init_model_path is None:
|
|
|
|
|
parameters = paddle.parameters.create(cost)
|
|
|
|
|
else:
|
|
|
|
|
assert os.path.isfile(args.init_model_path), "Invalid model."
|
|
|
|
|
parameters = paddle.parameters.Parameters.from_tar(
|
|
|
|
|
gzip.open(args.init_model_path))
|
|
|
|
|
optimizer = paddle.optimizer.Adam(
|
|
|
|
|
learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400)
|
|
|
|
|
trainer = paddle.trainer.SGD(
|
|
|
|
|