From 1632af770687459781e9ec09e52b6e438bc334df Mon Sep 17 00:00:00 2001
From: KP <109694228@qq.com>
Date: Thu, 23 Dec 2021 12:28:58 +0800
Subject: [PATCH] Update examples/esc50. (#1203)

---
 docs/source/cls/custom_dataset.md      |  16 +++-
 examples/esc50/README.md               |  60 ++++++++-------
 examples/esc50/cls0/conf/panns.yaml    |  36 +++++++++
 examples/esc50/cls0/local/export.sh    |   4 +-
 examples/esc50/cls0/local/infer.sh     |   9 +--
 examples/esc50/cls0/local/train.sh     |  19 +----
 examples/esc50/cls0/run.sh             |  20 ++---
 paddleaudio/datasets/dataset.py        |   8 +-
 paddlespeech/cls/exps/panns/predict.py |  59 ++++++++-------
 paddlespeech/cls/exps/panns/train.py   | 101 +++++++++++++------------
 10 files changed, 188 insertions(+), 144 deletions(-)
 create mode 100644 examples/esc50/cls0/conf/panns.yaml

diff --git a/docs/source/cls/custom_dataset.md b/docs/source/cls/custom_dataset.md
index 0b9ed726..56432cca 100644
--- a/docs/source/cls/custom_dataset.md
+++ b/docs/source/cls/custom_dataset.md
@@ -17,8 +17,6 @@ Here is an example to build your custom dataset in `custom_dataset.py`:
 from paddleaudio.datasets.dataset import AudioClassificationDataset
 
 class CustomDataset(AudioClassificationDataset):
-    # All *.wav file with same sample rate 16k/24k/32k/44k.
-    sample_rate = 16000
     meta_file = '/PATH/TO/META_FILE.txt'
     # List all the class labels
     label_list = [
@@ -54,8 +52,20 @@ from paddleaudio.features import LogMelSpectrogram
 
 from custom_dataset import CustomDataset
 
+# The feature config should be aligned with the pretrained model.
+feat_conf = {
+    'sr': 32000,
+    'n_fft': 1024,
+    'hop_length': 320,
+    'window': 'hann',
+    'win_length': 1024,
+    'f_min': 50.0,
+    'f_max': 14000.0,
+    'n_mels': 64,
+}
+
 train_ds = CustomDataset()
-feature_extractor = LogMelSpectrogram(sr=train_ds.sample_rate)
+feature_extractor = LogMelSpectrogram(**feat_conf)
 
 train_sampler = paddle.io.DistributedBatchSampler(
     train_ds, batch_size=4, shuffle=True, drop_last=False)
diff --git a/examples/esc50/README.md b/examples/esc50/README.md
index 66409754..20fecf0c 100644
--- a/examples/esc50/README.md
+++ b/examples/esc50/README.md
@@ -17,21 +17,30 @@ PaddleAudio provides pretrained CNN14, CNN10 and CNN6 models of PANNs, which
 - CNN6: This model mainly consists of 4 convolutional layers and 2 fully connected layers, with 4.5M parameters and an embedding dimension of 512.
 
+## Dataset
+
+[ESC-50: Dataset for Environmental Sound Classification](https://github.com/karolpiczak/ESC-50) is a dataset of 2000 labeled environmental sound samples. Each sample is a single-channel audio file with a sample rate of 44,100 Hz, and the samples are divided by label into 50 classes with 40 samples per class.
+
+## Metrics
+
+Following the fold information provided by `ESC-50`, the model is fine-tuned and evaluated with 5-fold cross validation on the dataset. The average accuracy is as follows:
+
+|Model|Acc|
+|--|--|
+|CNN14|0.950|
+
 ## Quick Start
 
 ### Model Training
 
-Taking the environmental sound classification dataset `ESC50` as an example, run the command below to fine-tune the model on the training set. Both single-GPU and multi-GPU training on a single machine are supported.
+Run the command below to fine-tune the model on the training set. Both single-GPU and multi-GPU training on a single machine are supported.
 
 Start training:
 ```shell
-$ CUDA_VISIBLE_DEVICES=0 ./run.sh 1
+$ CUDA_VISIBLE_DEVICES=0 ./run.sh 1 conf/panns.yaml
 ```
 
-Configurable arguments of the `paddlespeech/cls/exps/panns/train.py` script:
-
-- `device`: Specifies the device used by the model.
-- `feat_backend`: The backend used to extract features from audio files, either `'numpy'` or `'paddle'`; defaults to `'numpy'`.
+Training parameters can be configured in the `training` section of `conf/panns.yaml`, where:
 - `epochs`: Number of training epochs; defaults to 50.
 - `learning_rate`: Learning rate for fine-tuning; defaults to 5e-5.
 - `batch_size`: Batch size; adjust it according to available GPU memory and lower it if you run out of memory; defaults to 16.
@@ -40,36 +49,31 @@ $ CUDA_VISIBLE_DEVICES=0 ./run.sh 1
 - `save_freq`: How often (in epochs) model checkpoints are saved during training; defaults to 10.
 - `log_freq`: How often (in steps) training information is logged; defaults to 10.
 
-The pretrained model used in the example code is `CNN14`. To switch to another pretrained model, do the following:
-```python
-from paddleaudio.datasets import ESC50
-from paddlespeech.cls.models import SoundClassifier
-from paddlespeech.cls.models import cnn14, cnn10, cnn6
-
+The pretrained model used in the example code is `CNN14`. To switch to another pretrained model, modify the `model` section of `conf/panns.yaml`:
+```yaml
 # CNN14
-backbone = cnn14(pretrained=True, extract_embedding=True)
-model = SoundClassifier(backbone, num_class=len(ESC50.label_list))
-
+model:
+  backbone: 'paddlespeech.cls.models:cnn14'
+```
+```yaml
 # CNN10
-backbone = cnn10(pretrained=True, extract_embedding=True)
-model = SoundClassifier(backbone, num_class=len(ESC50.label_list))
-
+model:
+  backbone: 'paddlespeech.cls.models:cnn10'
+```
+```yaml
 # CNN6
-backbone = cnn6(pretrained=True, extract_embedding=True)
-model = SoundClassifier(backbone, num_class=len(ESC50.label_list))
+model:
+  backbone: 'paddlespeech.cls.models:cnn6'
 ```
 
 ### Model Prediction
 
 ```shell
-$ CUDA_VISIBLE_DEVICES=0 ./run.sh 2
+$ CUDA_VISIBLE_DEVICES=0 ./run.sh 2 conf/panns.yaml
 ```
 
-Configurable arguments of the `paddlespeech/cls/exps/panns/predict.py` script:
-
-- `device`: Specifies the device used by the model.
-- `wav`: The audio file to run prediction on.
-- `feat_backend`: The backend used to extract features from audio files, either `'numpy'` or `'paddle'`; defaults to `'numpy'`.
+Prediction parameters can be configured in the `predicting` section of `conf/panns.yaml`, where:
+- `audio_file`: The audio file to run prediction on.
 - `top_k`: Show the scores of the top k predicted labels; defaults to 1.
 - `checkpoint`: Checkpoint file of the model parameters.
@@ -88,7 +92,7 @@ Cat: 6.579841738130199e-06
 After training, the saved dynamic-graph parameters can be exported as a static-graph model and parameters for static-graph deployment.
 
 ```shell
-$ CUDA_VISIBLE_DEVICES=0 ./run.sh 3
+$ CUDA_VISIBLE_DEVICES=0 ./run.sh 3 ./checkpoint/epoch_50/model.pdparams ./export
 ```
 
 Configurable arguments of the `paddlespeech/cls/exps/panns/export_model.py` script:
@@ -109,7 +113,7 @@ export
 The `paddlespeech/cls/exps/panns/deploy/predict.py` script uses APIs from the `paddle.inference` module and provides an example of deployment in Python:
 
 ```shell
-$ CUDA_VISIBLE_DEVICES=0 ./run.sh 4
+$ CUDA_VISIBLE_DEVICES=0 ./run.sh 4 cpu ./export /audio/dog.wav
 ```
 
 Main configurable arguments of the `paddlespeech/cls/exps/panns/deploy/predict.py` script:
diff --git a/examples/esc50/cls0/conf/panns.yaml b/examples/esc50/cls0/conf/panns.yaml
new file mode 100644
index 00000000..3a9d42aa
--- /dev/null
+++ b/examples/esc50/cls0/conf/panns.yaml
@@ -0,0 +1,36 @@
+data:
+  dataset: 'paddleaudio.datasets:ESC50'
+  num_classes: 50
+  train:
+    mode: 'train'
+    split: 1
+  dev:
+    mode: 'dev'
+    split: 1
+
+model:
+  backbone: 'paddlespeech.cls.models:cnn14'
+
+feature:
+  sr: 32000
+  n_fft: 1024
+  hop_length: 320
+  window: 'hann'
+  win_length: 1024
+  f_min: 50.0
+  f_max: 14000.0
+  n_mels: 64
+
+training:
+  epochs: 50
+  learning_rate: 0.00005
+  num_workers: 2
+  batch_size: 16
+  checkpoint_dir: './checkpoint'
+  save_freq: 10
+  log_freq: 10
+
+predicting:
+  audio_file: '/audio/dog.wav'
+  top_k: 10
+  checkpoint: './checkpoint/epoch_50/model.pdparams'
\ No newline at end of file
diff --git a/examples/esc50/cls0/local/export.sh b/examples/esc50/cls0/local/export.sh
index 160dc743..9c854a19 100755
--- a/examples/esc50/cls0/local/export.sh
+++ b/examples/esc50/cls0/local/export.sh
@@ -1,8 +1,8 @@
 #!/bin/bash
 
-ckpt_dir=$1
+ckpt=$1
 output_dir=$2
 
 python3 ${BIN_DIR}/export_model.py \
---checkpoint ${ckpt_dir}/model.pdparams \
+--checkpoint ${ckpt} \
 --output_dir ${output_dir}
diff --git a/examples/esc50/cls0/local/infer.sh b/examples/esc50/cls0/local/infer.sh
index bc03d681..25d595be 100755
--- a/examples/esc50/cls0/local/infer.sh
+++ b/examples/esc50/cls0/local/infer.sh
@@ -1,11 +1,4 @@
 #!/bin/bash
 
-audio_file=$1
-ckpt_dir=$2
-feat_backend=$3
-
 python3 ${BIN_DIR}/predict.py \
---wav ${audio_file} \
---feat_backend ${feat_backend} \
---top_k 10 \
---checkpoint ${ckpt_dir}/model.pdparams
+--cfg_path=$1
diff --git a/examples/esc50/cls0/local/train.sh b/examples/esc50/cls0/local/train.sh
index 0f0f3d09..cab547b8 100755
--- a/examples/esc50/cls0/local/train.sh
+++ b/examples/esc50/cls0/local/train.sh
@@ -1,25 +1,12 @@
 #!/bin/bash
 
 ngpu=$1
-feat_backend=$2
-
-num_epochs=50
-batch_size=16
-ckpt_dir=./checkpoint
-save_freq=10
+cfg_path=$2
 
 if [ ${ngpu} -gt 0 ]; then
     python3 -m paddle.distributed.launch --gpus $CUDA_VISIBLE_DEVICES ${BIN_DIR}/train.py \
-    --epochs ${num_epochs} \
-    --feat_backend ${feat_backend} \
-    --batch_size ${batch_size} \
-    --checkpoint_dir ${ckpt_dir} \
-    --save_freq ${save_freq}
+    --cfg_path ${cfg_path}
 else
     python3 ${BIN_DIR}/train.py \
-    --epochs ${num_epochs} \
-    --feat_backend ${feat_backend} \
-    --batch_size ${batch_size} \
-    --checkpoint_dir ${ckpt_dir} \
-    --save_freq ${save_freq}
+    --cfg_path ${cfg_path}
 fi
diff --git a/examples/esc50/cls0/run.sh b/examples/esc50/cls0/run.sh
index 7283aa8d..0e407b40 100755
--- a/examples/esc50/cls0/run.sh
+++ b/examples/esc50/cls0/run.sh
@@ -6,28 +6,30 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 stage=$1
 stop_stage=100
 
-feat_backend=numpy
-audio_file=~/cat.wav
-ckpt_dir=./checkpoint/epoch_50
-output_dir=./export
-infer_device=cpu
 
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    ./local/train.sh ${ngpu} ${feat_backend} || exit -1
+    cfg_path=$2
+    ./local/train.sh ${ngpu} ${cfg_path} || exit -1
     exit 0
 fi
 
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-    ./local/infer.sh ${audio_file} ${ckpt_dir} ${feat_backend} || exit -1
+    cfg_path=$2
+    ./local/infer.sh ${cfg_path} || exit -1
     exit 0
fi
 
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-    ./local/export.sh ${ckpt_dir} ${output_dir} || exit -1
+    ckpt=$2
+    output_dir=$3
+    ./local/export.sh ${ckpt} ${output_dir} || exit -1
     exit 0
 fi
 
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-    ./local/static_model_infer.sh ${infer_device} ${output_dir} ${audio_file} || exit -1
+    infer_device=$2
+    graph_dir=$3
+    audio_file=$4
+    ./local/static_model_infer.sh ${infer_device} ${graph_dir} ${audio_file} || exit -1
     exit 0
 fi
diff --git a/paddleaudio/datasets/dataset.py b/paddleaudio/datasets/dataset.py
index fb521bea..7a57fd6c 100644
--- a/paddleaudio/datasets/dataset.py
+++ b/paddleaudio/datasets/dataset.py
@@ -36,6 +36,7 @@ class AudioClassificationDataset(paddle.io.Dataset):
                  files: List[str],
                  labels: List[int],
                  feat_type: str='raw',
+                 sample_rate: int=None,
                  **kwargs):
         """
         Args:
@@ -55,6 +56,7 @@ class AudioClassificationDataset(paddle.io.Dataset):
 
         self.labels = labels
         self.feat_type = feat_type
+        self.sample_rate = sample_rate
         self.feat_config = kwargs  # Pass keyword arguments to customize feature config
 
     def _get_data(self, input_file: str):
@@ -63,7 +65,11 @@ class AudioClassificationDataset(paddle.io.Dataset):
 
     def _convert_to_record(self, idx):
         file, label = self.files[idx], self.labels[idx]
-        waveform, sample_rate = load_audio(file)
+        if self.sample_rate is None:
+            waveform, sample_rate = load_audio(file)
+        else:
+            waveform, sample_rate = load_audio(file, sr=self.sample_rate)
+
         feat_func = feat_funcs[self.feat_type]
 
         record = {}
diff --git a/paddlespeech/cls/exps/panns/predict.py b/paddlespeech/cls/exps/panns/predict.py
index 9cfd8b6c..ffe42d39 100644
--- a/paddlespeech/cls/exps/panns/predict.py
+++ b/paddlespeech/cls/exps/panns/predict.py
@@ -12,58 +12,61 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import os
-
-import numpy as np
 import paddle
 import paddle.nn.functional as F
+import yaml
 
 from paddleaudio.backends import load as load_audio
-from paddleaudio.datasets import ESC50
 from paddleaudio.features import LogMelSpectrogram
-from paddleaudio.features import melspectrogram
-from paddlespeech.cls.models import cnn14
+from paddleaudio.utils import logger
 from paddlespeech.cls.models import SoundClassifier
+from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 
 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--wav", type=str, required=True, help="Audio file to infer.")
-parser.add_argument("--feat_backend", type=str, choices=['numpy', 'paddle'], default='numpy', help="Choose backend to extract features from audio files.")
-parser.add_argument("--top_k", type=int, default=1, help="Show top k predicted results")
-parser.add_argument("--checkpoint", type=str, required=True, help="Checkpoint of model.")
+parser.add_argument("--cfg_path", type=str, required=True)
 args = parser.parse_args()
 # yapf: enable
 
 
-def extract_features(file: str, feat_backend: str='numpy',
-                     **kwargs) -> paddle.Tensor:
-    waveform, sr = load_audio(file, sr=None)
-
-    if args.feat_backend == 'numpy':
-        feat = melspectrogram(waveform, sr, **kwargs).transpose()
-        feat = np.expand_dims(feat, 0)
-        feat = paddle.to_tensor(feat)
-    else:
-        feature_extractor = LogMelSpectrogram(sr=sr, **kwargs)
-        feat = feature_extractor(paddle.to_tensor(waveform).unsqueeze(0))
-        feat = paddle.transpose(feat, [0, 2, 1])
+def extract_features(file: str, **feat_conf) -> paddle.Tensor:
+    file = os.path.abspath(os.path.expanduser(file))
+    waveform, _ = load_audio(file, sr=feat_conf['sr'])
+    feature_extractor = LogMelSpectrogram(**feat_conf)
+    feat = feature_extractor(paddle.to_tensor(waveform).unsqueeze(0))
+    feat = paddle.transpose(feat, [0, 2, 1])
 
     return feat
 
 
 if __name__ == '__main__':
+    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
+    with open(args.cfg_path, 'r') as f:
+        config = yaml.safe_load(f)
+
+    model_conf = config['model']
+    data_conf = config['data']
+    feat_conf = config['feature']
+    predicting_conf = config['predicting']
+
+    ds_class = dynamic_import(data_conf['dataset'])
+    backbone_class = dynamic_import(model_conf['backbone'])
+
     model = SoundClassifier(
-        backbone=cnn14(pretrained=False, extract_embedding=True),
-        num_class=len(ESC50.label_list))
-    model.set_state_dict(paddle.load(args.checkpoint))
+        backbone=backbone_class(pretrained=False, extract_embedding=True),
+        num_class=len(ds_class.label_list))
+    model.set_state_dict(paddle.load(predicting_conf['checkpoint']))
     model.eval()
 
-    feat = extract_features(args.wav, args.feat_backend)
+    feat = extract_features(predicting_conf['audio_file'], **feat_conf)
     logits = model(feat)
     probs = F.softmax(logits, axis=1).numpy()
 
     sorted_indices = (-probs[0]).argsort()
 
-    msg = f'[{args.wav}]\n'
-    for idx in sorted_indices[:args.top_k]:
-        msg += f'{ESC50.label_list[idx]}: {probs[0][idx]}\n'
-    print(msg)
+    msg = f"[{predicting_conf['audio_file']}]\n"
+    for idx in sorted_indices[:predicting_conf['top_k']]:
+        msg += f'{ds_class.label_list[idx]}: {probs[0][idx]}\n'
+    logger.info(msg)
diff --git a/paddlespeech/cls/exps/panns/train.py b/paddlespeech/cls/exps/panns/train.py
index 12130978..7e292214 100644
--- a/paddlespeech/cls/exps/panns/train.py
+++ b/paddlespeech/cls/exps/panns/train.py
@@ -15,24 +15,17 @@ import argparse
 import os
 
 import paddle
+import yaml
 
-from paddleaudio.datasets import ESC50
 from paddleaudio.features import LogMelSpectrogram
 from paddleaudio.utils import logger
 from paddleaudio.utils import Timer
-from paddlespeech.cls.models import cnn14
 from paddlespeech.cls.models import SoundClassifier
+from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 
 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--epochs", type=int, default=50, help="Number of epoches for fine-tuning.")
-parser.add_argument("--feat_backend", type=str, choices=['numpy', 'paddle'], default='numpy', help="Choose backend to extract features from audio files.")
-parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
-parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.")
-parser.add_argument("--num_workers", type=int, default=0, help="Number of workers in dataloader.")
-parser.add_argument("--checkpoint_dir", type=str, default='./checkpoint', help="Directory to save model checkpoints.")
-parser.add_argument("--save_freq", type=int, default=10, help="Save checkpoint every n epoch.")
-parser.add_argument("--log_freq", type=int, default=10, help="Log the training infomation every n steps.")
+parser.add_argument("--cfg_path", type=str, required=True)
 args = parser.parse_args()
 # yapf: enable
 
@@ -42,50 +35,60 @@ if __name__ == "__main__":
         paddle.distributed.init_parallel_env()
     local_rank = paddle.distributed.get_rank()
 
-    backbone = cnn14(pretrained=True, extract_embedding=True)
-    model = SoundClassifier(backbone, num_class=len(ESC50.label_list))
-    model = paddle.DataParallel(model)
-    optimizer = paddle.optimizer.Adam(
-        learning_rate=args.learning_rate, parameters=model.parameters())
-    criterion = paddle.nn.loss.CrossEntropyLoss()
+    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
+    with open(args.cfg_path, 'r') as f:
+        config = yaml.safe_load(f)
 
-    if args.feat_backend == 'numpy':
-        train_ds = ESC50(mode='train', feat_type='melspectrogram')
-        dev_ds = ESC50(mode='dev', feat_type='melspectrogram')
-    else:
-        train_ds = ESC50(mode='train')
-        dev_ds = ESC50(mode='dev')
-        feature_extractor = LogMelSpectrogram(sr=16000)
+    model_conf = config['model']
+    data_conf = config['data']
+    feat_conf = config['feature']
+    training_conf = config['training']
 
+    # Dataset
+    ds_class = dynamic_import(data_conf['dataset'])
+    train_ds = ds_class(**data_conf['train'])
+    dev_ds = ds_class(**data_conf['dev'])
     train_sampler = paddle.io.DistributedBatchSampler(
-        train_ds, batch_size=args.batch_size, shuffle=True, drop_last=False)
+        train_ds,
+        batch_size=training_conf['batch_size'],
+        shuffle=True,
+        drop_last=False)
     train_loader = paddle.io.DataLoader(
         train_ds,
         batch_sampler=train_sampler,
-        num_workers=args.num_workers,
+        num_workers=training_conf['num_workers'],
         return_list=True,
         use_buffer_reader=True, )
 
+    # Feature
+    feature_extractor = LogMelSpectrogram(**feat_conf)
+
+    # Model
+    backbone_class = dynamic_import(model_conf['backbone'])
+    backbone = backbone_class(pretrained=True, extract_embedding=True)
+    model = SoundClassifier(backbone, num_class=data_conf['num_classes'])
+    model = paddle.DataParallel(model)
+    optimizer = paddle.optimizer.Adam(
+        learning_rate=training_conf['learning_rate'],
+        parameters=model.parameters())
+    criterion = paddle.nn.loss.CrossEntropyLoss()
+
     steps_per_epoch = len(train_sampler)
-    timer = Timer(steps_per_epoch * args.epochs)
+    timer = Timer(steps_per_epoch * training_conf['epochs'])
     timer.start()
 
-    for epoch in range(1, args.epochs + 1):
+    for epoch in range(1, training_conf['epochs'] + 1):
         model.train()
 
         avg_loss = 0
         num_corrects = 0
         num_samples = 0
         for batch_idx, batch in enumerate(train_loader):
-            if args.feat_backend == 'numpy':
-                feats, labels = batch
-            else:
-                waveforms, labels = batch
-                feats = feature_extractor(
-                    waveforms
-                )  # Need a padding when lengths of waveforms differ in a batch.
-                feats = paddle.transpose(feats,
-                                         [0, 2, 1])  # To [N, length, n_mels]
+            waveforms, labels = batch
+            feats = feature_extractor(
+                waveforms
+            )  # Padding is needed when waveform lengths differ within a batch.
+            feats = paddle.transpose(feats, [0, 2, 1])  # To [N, length, n_mels]
 
             logits = model(feats)
 
@@ -107,13 +110,15 @@ if __name__ == "__main__":
 
             timer.count()
 
-            if (batch_idx + 1) % args.log_freq == 0 and local_rank == 0:
+            if (batch_idx + 1
+                ) % training_conf['log_freq'] == 0 and local_rank == 0:
                 lr = optimizer.get_lr()
-                avg_loss /= args.log_freq
+                avg_loss /= training_conf['log_freq']
                 avg_acc = num_corrects / num_samples
 
                 print_msg = 'Epoch={}/{}, Step={}/{}'.format(
-                    epoch, args.epochs, batch_idx + 1, steps_per_epoch)
+                    epoch, training_conf['epochs'], batch_idx + 1,
+                    steps_per_epoch)
                 print_msg += ' loss={:.4f}'.format(avg_loss)
                 print_msg += ' acc={:.4f}'.format(avg_acc)
                 print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(
@@ -124,16 +129,17 @@ if __name__ == "__main__":
                 num_corrects = 0
                 num_samples = 0
 
-        if epoch % args.save_freq == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
+        if epoch % training_conf[
+                'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
             dev_sampler = paddle.io.BatchSampler(
                 dev_ds,
-                batch_size=args.batch_size,
+                batch_size=training_conf['batch_size'],
                 shuffle=False,
                 drop_last=False)
             dev_loader = paddle.io.DataLoader(
                 dev_ds,
                 batch_sampler=dev_sampler,
-                num_workers=args.num_workers,
+                num_workers=training_conf['num_workers'],
                 return_list=True, )
 
             model.eval()
@@ -141,12 +147,9 @@ if __name__ == "__main__":
             num_samples = 0
             with logger.processing('Evaluation on validation dataset'):
                 for batch_idx, batch in enumerate(dev_loader):
-                    if args.feat_backend == 'numpy':
-                        feats, labels = batch
-                    else:
-                        waveforms, labels = batch
-                        feats = feature_extractor(waveforms)
-                        feats = paddle.transpose(feats, [0, 2, 1])
+                    waveforms, labels = batch
+                    feats = feature_extractor(waveforms)
+                    feats = paddle.transpose(feats, [0, 2, 1])
 
                     logits = model(feats)
 
@@ -160,7 +163,7 @@ if __name__ == "__main__":
             logger.eval(print_msg)
 
             # Save model
-            save_dir = os.path.join(args.checkpoint_dir,
+            save_dir = os.path.join(training_conf['checkpoint_dir'],
                                     'epoch_{}'.format(epoch))
             logger.info('Saving model checkpoint to {}'.format(save_dir))
             paddle.save(model.state_dict(),
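
For reference, the sketch below shows how the pieces this patch introduces fit together: the YAML config is loaded, the `'module:object'` strings are resolved with `dynamic_import`, and the classifier is assembled, the same way `train.py` and `predict.py` above do it. This is a minimal illustration, not part of the patch; it assumes it is run from `examples/esc50/cls0` so that `conf/panns.yaml` resolves.

```python
# Minimal sketch (not part of the patch), assuming the working directory
# is examples/esc50/cls0 so that conf/panns.yaml exists.
import yaml

from paddlespeech.cls.models import SoundClassifier
from paddlespeech.s2t.utils.dynamic_import import dynamic_import

with open('conf/panns.yaml') as f:
    config = yaml.safe_load(f)

# 'paddleaudio.datasets:ESC50' resolves to the dataset class;
# 'paddlespeech.cls.models:cnn14' resolves to the backbone factory.
ds_class = dynamic_import(config['data']['dataset'])
backbone_class = dynamic_import(config['model']['backbone'])

# Build the training split and the classifier from the config alone.
train_ds = ds_class(**config['data']['train'])
model = SoundClassifier(
    backbone=backbone_class(pretrained=True, extract_embedding=True),
    num_class=config['data']['num_classes'])
```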