Update KWS example.

pull/1783/head
KP 3 years ago
parent 2b44f374c1
commit abb15ac6e8

@ -1,39 +1,49 @@
data: # https://yaml.org/type/float.html
data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter' ###########################################
dataset: 'paddleaudio.datasets:HeySnips' # Data #
###########################################
dataset: 'paddleaudio.datasets:HeySnips'
data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter'
model: ############################################
num_keywords: 1 # Network Architecture #
backbone: 'paddlespeech.kws.models:MDTC' ############################################
config: backbone: 'paddlespeech.kws.models:MDTC'
stack_num: 3 num_keywords: 1
stack_size: 4 stack_num: 3
in_channels: 80 stack_size: 4
res_channels: 32 in_channels: 80
kernel_size: 5 res_channels: 32
kernel_size: 5
feature: ###########################################
feat_type: 'kaldi_fbank' # Feature #
sample_rate: 16000 ###########################################
frame_shift: 10 feat_type: 'kaldi_fbank'
frame_length: 25 sample_rate: 16000
n_mels: 80 frame_shift: 10
frame_length: 25
n_mels: 80
training: ###########################################
epochs: 100 # Training #
num_workers: 16 ###########################################
batch_size: 100 epochs: 100
checkpoint_dir: './checkpoint' num_workers: 16
save_freq: 10 batch_size: 100
log_freq: 10 checkpoint_dir: './checkpoint'
learning_rate: 0.001 save_freq: 10
weight_decay: 0.00005 log_freq: 10
grad_clip: 5.0 learning_rate: 0.001
weight_decay: 0.00005
grad_clip: 5.0
scoring: ###########################################
batch_size: 100 # Scoring #
num_workers: 16 ###########################################
checkpoint: './checkpoint/epoch_100/model.pdparams' batch_size: 100
score_file: './scores.txt' num_workers: 16
stats_file: './stats.0.txt' checkpoint: './checkpoint/epoch_100/model.pdparams'
img_file: './det.png' score_file: './scores.txt'
stats_file: './stats.0.txt'
img_file: './det.png'

@ -1,2 +1,25 @@
#!/bin/bash #!/bin/bash
python3 ${BIN_DIR}/plot_det_curve.py --cfg_path=$1 --keyword HeySnips # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ $# != 3 ];then
echo "usage: ${0} config_path checkpoint output_file"
exit -1
fi
keyword=$1
stats_file=$2
img_file=$3
python3 ${BIN_DIR}/plot_det_curve.py --keyword_label ${keyword} --stats_file ${stats_file} --img_file ${img_file}

@ -1,5 +1,27 @@
#!/bin/bash #!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
python3 ${BIN_DIR}/score.py --cfg_path=$1 if [ $# != 4 ];then
echo "usage: ${0} checkpoint score_file stats_file"
exit -1
fi
python3 ${BIN_DIR}/compute_det.py --cfg_path=$1 cfg_path=$1
ckpt=$2
score_file=$3
stats_file=$4
python3 ${BIN_DIR}/score.py --config ${cfg_path} --ckpt ${ckpt} --score_file ${score_file} || exit -1
python3 ${BIN_DIR}/compute_det.py --config ${cfg_path} --score_file ${score_file} --stats_file ${stats_file} || exit -1

@ -1,13 +1,31 @@
#!/bin/bash #!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ $# != 2 ];then
echo "usage: ${0} num_gpus config_path"
exit -1
fi
ngpu=$1 ngpu=$1
cfg_path=$2 cfg_path=$2
if [ ${ngpu} -gt 0 ]; then if [ ${ngpu} -gt 0 ]; then
python3 -m paddle.distributed.launch --gpus $CUDA_VISIBLE_DEVICES ${BIN_DIR}/train.py \ python3 -m paddle.distributed.launch --gpus $CUDA_VISIBLE_DEVICES ${BIN_DIR}/train.py \
--cfg_path ${cfg_path} --config ${cfg_path}
else else
echo "set CUDA_VISIBLE_DEVICES to enable multi-gpus trainning." echo "set CUDA_VISIBLE_DEVICES to enable multi-gpus trainning."
python3 ${BIN_DIR}/train.py \ python3 ${BIN_DIR}/train.py \
--cfg_path ${cfg_path} --config ${cfg_path}
fi fi

@ -32,10 +32,16 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
./local/train.sh ${ngpu} ${cfg_path} || exit -1 ./local/train.sh ${ngpu} ${cfg_path} || exit -1
fi fi
ckpt=./checkpoint/epoch_100/model.pdparams
score_file=./scores.txt
stats_file=./stats.0.txt
img_file=./det.png
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
./local/score.sh ${cfg_path} || exit -1 ./local/score.sh ${cfg_path} ${ckpt} ${score_file} ${stats_file} || exit -1
fi fi
keyword=HeySnips
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
./local/plot.sh ${cfg_path} || exit -1 ./local/plot.sh ${keyword} ${stats_file} ${img_file} || exit -1
fi fi

@ -12,24 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws) # Modified from wekws(https://github.com/wenet-e2e/wekws)
import argparse
import os import os
import paddle import paddle
import yaml
from tqdm import tqdm from tqdm import tqdm
from yacs.config import CfgNode
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.dynamic_import import dynamic_import
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
parser.add_argument('--keyword_index', type=int, default=0, help='keyword index')
parser.add_argument('--step', type=float, default=0.01, help='threshold step of trigger score')
parser.add_argument('--window_shift', type=int, default=50, help='window_shift is used to skip the frames after triggered')
args = parser.parse_args()
# yapf: enable
def load_label_and_score(keyword_index: int, def load_label_and_score(keyword_index: int,
ds: paddle.io.Dataset, ds: paddle.io.Dataset,
@ -61,26 +52,52 @@ def load_label_and_score(keyword_index: int,
if __name__ == '__main__': if __name__ == '__main__':
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) parser = default_argument_parser()
with open(args.cfg_path, 'r') as f: parser.add_argument(
config = yaml.safe_load(f) '--keyword_index', type=int, default=0, help='keyword index')
parser.add_argument(
'--step',
type=float,
default=0.01,
help='threshold step of trigger score')
parser.add_argument(
'--window_shift',
type=int,
default=50,
help='window_shift is used to skip the frames after triggered')
parser.add_argument(
"--score_file",
type=str,
required=True,
help='output file of trigger scores')
parser.add_argument(
'--stats_file',
type=str,
default='./stats.0.txt',
help='output file of detection error tradeoff')
args = parser.parse_args()
data_conf = config['data'] # https://yaml.org/type/float.html
feat_conf = config['feature'] config = CfgNode(new_allowed=True)
scoring_conf = config['scoring'] if args.config:
config.merge_from_file(args.config)
# Dataset # Dataset
ds_class = dynamic_import(data_conf['dataset']) ds_class = dynamic_import(config['dataset'])
test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf) test_ds = ds_class(
data_dir=config['data_dir'],
score_file = os.path.abspath(scoring_conf['score_file']) mode='test',
stats_file = os.path.abspath(scoring_conf['stats_file']) feat_type=config['feat_type'],
sample_rate=config['sample_rate'],
frame_shift=config['frame_shift'],
frame_length=config['frame_length'],
n_mels=config['n_mels'], )
keyword_table, filler_table, filler_duration = load_label_and_score( keyword_table, filler_table, filler_duration = load_label_and_score(
args.keyword, test_ds, score_file) args.keyword_index, test_ds, args.score_file)
print('Filler total duration Hours: {}'.format(filler_duration / 3600.0)) print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
pbar = tqdm(total=int(1.0 / args.step)) pbar = tqdm(total=int(1.0 / args.step))
with open(stats_file, 'w', encoding='utf8') as fout: with open(args.stats_file, 'w', encoding='utf8') as fout:
keyword_index = args.keyword_index keyword_index = args.keyword_index
threshold = 0.0 threshold = 0.0
while threshold <= 1.0: while threshold <= 1.0:
@ -113,4 +130,4 @@ if __name__ == '__main__':
pbar.update(1) pbar.update(1)
pbar.close() pbar.close()
print('DET saved to: {}'.format(stats_file)) print('DET saved to: {}'.format(args.stats_file))

@ -17,12 +17,12 @@ import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import yaml
# yapf: disable # yapf: disable
parser = argparse.ArgumentParser(__doc__) parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True) parser.add_argument('--keyword_label', type=str, required=True, help='keyword string shown on image')
parser.add_argument("--keyword", type=str, required=True) parser.add_argument('--stats_file', type=str, required=True, help='output file of detection error tradeoff')
parser.add_argument('--img_file', type=str, default='./det.png', help='output det image')
args = parser.parse_args() args = parser.parse_args()
# yapf: enable # yapf: enable
@ -61,14 +61,8 @@ def plot_det_curve(keywords, stats_file, figure_file, xlim, x_step, ylim,
if __name__ == '__main__': if __name__ == '__main__':
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) img_file = os.path.abspath(args.img_file)
with open(args.cfg_path, 'r') as f: stats_file = os.path.abspath(args.stats_file)
config = yaml.safe_load(f) plot_det_curve([args.keyword_label], stats_file, img_file, 10, 2, 10, 2)
scoring_conf = config['scoring']
img_file = os.path.abspath(scoring_conf['img_file'])
stats_file = os.path.abspath(scoring_conf['stats_file'])
keywords = [args.keyword]
plot_det_curve(keywords, stats_file, img_file, 10, 2, 10, 2)
print('DET curve image saved to: {}'.format(img_file)) print('DET curve image saved to: {}'.format(img_file))

@ -12,55 +12,67 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws) # Modified from wekws(https://github.com/wenet-e2e/wekws)
import argparse
import os
import paddle import paddle
import yaml
from tqdm import tqdm from tqdm import tqdm
from yacs.config import CfgNode
from paddlespeech.kws.exps.mdtc.collate import collate_features from paddlespeech.kws.exps.mdtc.collate import collate_features
from paddlespeech.kws.models.mdtc import KWSModel from paddlespeech.kws.models.mdtc import KWSModel
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.dynamic_import import dynamic_import
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
args = parser.parse_args()
# yapf: enable
if __name__ == '__main__': if __name__ == '__main__':
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path)) parser = default_argument_parser()
with open(args.cfg_path, 'r') as f: parser.add_argument(
config = yaml.safe_load(f) "--ckpt",
type=str,
required=True,
help='model checkpoint for evaluation.')
parser.add_argument(
"--score_file",
type=str,
default='./scores.txt',
help='output file of trigger scores')
args = parser.parse_args()
model_conf = config['model'] # https://yaml.org/type/float.html
data_conf = config['data'] config = CfgNode(new_allowed=True)
feat_conf = config['feature'] if args.config:
scoring_conf = config['scoring'] config.merge_from_file(args.config)
# Dataset # Dataset
ds_class = dynamic_import(data_conf['dataset']) ds_class = dynamic_import(config['dataset'])
test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf) test_ds = ds_class(
data_dir=config['data_dir'],
mode='test',
feat_type=config['feat_type'],
sample_rate=config['sample_rate'],
frame_shift=config['frame_shift'],
frame_length=config['frame_length'],
n_mels=config['n_mels'], )
test_sampler = paddle.io.BatchSampler( test_sampler = paddle.io.BatchSampler(
test_ds, batch_size=scoring_conf['batch_size'], drop_last=False) test_ds, batch_size=config['batch_size'], drop_last=False)
test_loader = paddle.io.DataLoader( test_loader = paddle.io.DataLoader(
test_ds, test_ds,
batch_sampler=test_sampler, batch_sampler=test_sampler,
num_workers=scoring_conf['num_workers'], num_workers=config['num_workers'],
return_list=True, return_list=True,
use_buffer_reader=True, use_buffer_reader=True,
collate_fn=collate_features, ) collate_fn=collate_features, )
# Model # Model
backbone_class = dynamic_import(model_conf['backbone']) backbone_class = dynamic_import(config['backbone'])
backbone = backbone_class(**model_conf['config']) backbone = backbone_class(
model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords']) stack_num=config['stack_num'],
model.set_state_dict(paddle.load(scoring_conf['checkpoint'])) stack_size=config['stack_size'],
in_channels=config['in_channels'],
res_channels=config['res_channels'],
kernel_size=config['kernel_size'], )
model = KWSModel(backbone=backbone, num_keywords=config['num_keywords'])
model.set_state_dict(paddle.load(args.ckpt))
model.eval() model.eval()
with paddle.no_grad(), open( with paddle.no_grad(), open(args.score_file, 'w', encoding='utf8') as f:
scoring_conf['score_file'], 'w', encoding='utf8') as fout:
for batch_idx, batch in enumerate( for batch_idx, batch in enumerate(
tqdm(test_loader, total=len(test_loader))): tqdm(test_loader, total=len(test_loader))):
keys, feats, labels, lengths = batch keys, feats, labels, lengths = batch
@ -73,7 +85,6 @@ if __name__ == '__main__':
keyword_scores = score[:, keyword_i] keyword_scores = score[:, keyword_i]
score_frames = ' '.join( score_frames = ' '.join(
['{:.6f}'.format(x) for x in keyword_scores.tolist()]) ['{:.6f}'.format(x) for x in keyword_scores.tolist()])
fout.write( f.write('{} {} {}\n'.format(key, keyword_i, score_frames))
'{} {} {}\n'.format(key, keyword_i, score_frames))
print('Result saved to: {}'.format(scoring_conf['score_file'])) print('Result saved to: {}'.format(args.score_file))

@ -11,77 +11,88 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
import os import os
import paddle import paddle
import yaml from yacs.config import CfgNode
from paddleaudio.utils import logger from paddleaudio.utils import logger
from paddleaudio.utils import Timer from paddleaudio.utils import Timer
from paddlespeech.kws.exps.mdtc.collate import collate_features from paddlespeech.kws.exps.mdtc.collate import collate_features
from paddlespeech.kws.models.loss import max_pooling_loss from paddlespeech.kws.models.loss import max_pooling_loss
from paddlespeech.kws.models.mdtc import KWSModel from paddlespeech.kws.models.mdtc import KWSModel
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.dynamic_import import dynamic_import
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--cfg_path", type=str, required=True)
args = parser.parse_args()
# yapf: enable
if __name__ == '__main__': if __name__ == '__main__':
parser = default_argument_parser()
args = parser.parse_args()
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
nranks = paddle.distributed.get_world_size() nranks = paddle.distributed.get_world_size()
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env() paddle.distributed.init_parallel_env()
local_rank = paddle.distributed.get_rank() local_rank = paddle.distributed.get_rank()
args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
with open(args.cfg_path, 'r') as f:
config = yaml.safe_load(f)
model_conf = config['model']
data_conf = config['data']
feat_conf = config['feature']
training_conf = config['training']
# Dataset # Dataset
ds_class = dynamic_import(data_conf['dataset']) ds_class = dynamic_import(config['dataset'])
train_ds = ds_class( train_ds = ds_class(
data_dir=data_conf['data_dir'], mode='train', **feat_conf) data_dir=config['data_dir'],
dev_ds = ds_class(data_dir=data_conf['data_dir'], mode='dev', **feat_conf) mode='train',
feat_type=config['feat_type'],
sample_rate=config['sample_rate'],
frame_shift=config['frame_shift'],
frame_length=config['frame_length'],
n_mels=config['n_mels'], )
dev_ds = ds_class(
data_dir=config['data_dir'],
mode='dev',
feat_type=config['feat_type'],
sample_rate=config['sample_rate'],
frame_shift=config['frame_shift'],
frame_length=config['frame_length'],
n_mels=config['n_mels'], )
train_sampler = paddle.io.DistributedBatchSampler( train_sampler = paddle.io.DistributedBatchSampler(
train_ds, train_ds,
batch_size=training_conf['batch_size'], batch_size=config['batch_size'],
shuffle=True, shuffle=True,
drop_last=False) drop_last=False)
train_loader = paddle.io.DataLoader( train_loader = paddle.io.DataLoader(
train_ds, train_ds,
batch_sampler=train_sampler, batch_sampler=train_sampler,
num_workers=training_conf['num_workers'], num_workers=config['num_workers'],
return_list=True, return_list=True,
use_buffer_reader=True, use_buffer_reader=True,
collate_fn=collate_features, ) collate_fn=collate_features, )
# Model # Model
backbone_class = dynamic_import(model_conf['backbone']) backbone_class = dynamic_import(config['backbone'])
backbone = backbone_class(**model_conf['config']) backbone = backbone_class(
model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords']) stack_num=config['stack_num'],
stack_size=config['stack_size'],
in_channels=config['in_channels'],
res_channels=config['res_channels'],
kernel_size=config['kernel_size'], )
model = KWSModel(backbone=backbone, num_keywords=config['num_keywords'])
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
clip = paddle.nn.ClipGradByGlobalNorm(training_conf['grad_clip']) clip = paddle.nn.ClipGradByGlobalNorm(config['grad_clip'])
optimizer = paddle.optimizer.Adam( optimizer = paddle.optimizer.Adam(
learning_rate=training_conf['learning_rate'], learning_rate=config['learning_rate'],
weight_decay=training_conf['weight_decay'], weight_decay=config['weight_decay'],
parameters=model.parameters(), parameters=model.parameters(),
grad_clip=clip) grad_clip=clip)
criterion = max_pooling_loss criterion = max_pooling_loss
steps_per_epoch = len(train_sampler) steps_per_epoch = len(train_sampler)
timer = Timer(steps_per_epoch * training_conf['epochs']) timer = Timer(steps_per_epoch * config['epochs'])
timer.start() timer.start()
for epoch in range(1, training_conf['epochs'] + 1): for epoch in range(1, config['epochs'] + 1):
model.train() model.train()
avg_loss = 0 avg_loss = 0
@ -107,15 +118,13 @@ if __name__ == '__main__':
timer.count() timer.count()
if (batch_idx + 1 if (batch_idx + 1) % config['log_freq'] == 0 and local_rank == 0:
) % training_conf['log_freq'] == 0 and local_rank == 0:
lr = optimizer.get_lr() lr = optimizer.get_lr()
avg_loss /= training_conf['log_freq'] avg_loss /= config['log_freq']
avg_acc = num_corrects / num_samples avg_acc = num_corrects / num_samples
print_msg = 'Epoch={}/{}, Step={}/{}'.format( print_msg = 'Epoch={}/{}, Step={}/{}'.format(
epoch, training_conf['epochs'], batch_idx + 1, epoch, config['epochs'], batch_idx + 1, steps_per_epoch)
steps_per_epoch)
print_msg += ' loss={:.4f}'.format(avg_loss) print_msg += ' loss={:.4f}'.format(avg_loss)
print_msg += ' acc={:.4f}'.format(avg_acc) print_msg += ' acc={:.4f}'.format(avg_acc)
print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format( print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(
@ -126,17 +135,17 @@ if __name__ == '__main__':
num_corrects = 0 num_corrects = 0
num_samples = 0 num_samples = 0
if epoch % training_conf[ if epoch % config[
'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0: 'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
dev_sampler = paddle.io.BatchSampler( dev_sampler = paddle.io.BatchSampler(
dev_ds, dev_ds,
batch_size=training_conf['batch_size'], batch_size=config['batch_size'],
shuffle=False, shuffle=False,
drop_last=False) drop_last=False)
dev_loader = paddle.io.DataLoader( dev_loader = paddle.io.DataLoader(
dev_ds, dev_ds,
batch_sampler=dev_sampler, batch_sampler=dev_sampler,
num_workers=training_conf['num_workers'], num_workers=config['num_workers'],
return_list=True, return_list=True,
use_buffer_reader=True, use_buffer_reader=True,
collate_fn=collate_features, ) collate_fn=collate_features, )
@ -159,7 +168,7 @@ if __name__ == '__main__':
logger.eval(print_msg) logger.eval(print_msg)
# Save model # Save model
save_dir = os.path.join(training_conf['checkpoint_dir'], save_dir = os.path.join(config['checkpoint_dir'],
'epoch_{}'.format(epoch)) 'epoch_{}'.format(epoch))
logger.info('Saving model checkpoint to {}'.format(save_dir)) logger.info('Saving model checkpoint to {}'.format(save_dir))
paddle.save(model.state_dict(), paddle.save(model.state_dict(),

Loading…
Cancel
Save