commit b031ee43c4
@@ -0,0 +1,44 @@
+###########################################################
+#                       DATA SETTING                      #
+###########################################################
+dataset_type: Ernie
+train_path: data/iwslt2012_zh/train.txt
+dev_path: data/iwslt2012_zh/dev.txt
+test_path: data/iwslt2012_zh/test.txt
+batch_size: 64
+num_workers: 2
+data_params:
+    pretrained_token: ernie-1.0
+    punc_path: data/iwslt2012_zh/punc_vocab
+    seq_len: 100
+
+
+###########################################################
+#                      MODEL SETTING                      #
+###########################################################
+model_type: ErnieLinear
+model:
+    pretrained_token: ernie-1.0
+    num_classes: 4
+
+###########################################################
+#                    OPTIMIZER SETTING                    #
+###########################################################
+optimizer_params:
+    weight_decay: 1.0e-6    # weight decay coefficient.
+
+scheduler_params:
+    learning_rate: 1.0e-5   # learning rate.
+    gamma: 1.0              # scheduler gamma.
+
+###########################################################
+#                    TRAINING SETTING                     #
+###########################################################
+max_epoch: 20
+num_snapshots: 5
+
+###########################################################
+#                      OTHER SETTING                      #
+###########################################################
+num_snapshots: 10   # max number of snapshots to keep while training
+seed: 42            # random seed for paddle, random, and np.random
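The new config is flat apart from the nested data_params and model mappings, and note that num_snapshots is defined twice (once under TRAINING, once under OTHER); PyYAML's safe_load keeps the last occurrence, so the effective value is 10. A minimal sketch (not part of this commit) of how the file is consumed, mirroring the CfgNode loading in the new punc_restore.py later in this diff; conf/default.yaml is the default path set in run.sh:

    import yaml
    from yacs.config import CfgNode

    with open('conf/default.yaml') as f:
        config = CfgNode(yaml.safe_load(f))

    # duplicate keys: safe_load keeps the last one, so this prints 10, not 5
    print(config['num_snapshots'])
    print(config['model_type'])              # ErnieLinear
    print(config['model']['num_classes'])    # 4
    print(config['data_params']['seq_len'])  # 100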
@@ -1,36 +0,0 @@
-data:
-  dataset_type: Ernie
-  train_path: data/iwslt2012_zh/train.txt
-  dev_path: data/iwslt2012_zh/dev.txt
-  test_path: data/iwslt2012_zh/test.txt
-  data_params:
-    pretrained_token: ernie-1.0
-    punc_path: data/iwslt2012_zh/punc_vocab
-    seq_len: 100
-  batch_size: 64
-  sortagrad: True
-  shuffle_method: batch_shuffle
-  num_workers: 0
-
-checkpoint:
-  kbest_n: 5
-  latest_n: 10
-  metric_type: F1
-
-model_type: ErnieLinear
-
-model_params:
-  pretrained_token: ernie-1.0
-  num_classes: 4
-
-training:
-  n_epoch: 20
-  lr: !!float 1e-5
-  lr_decay: 1.0
-  weight_decay: !!float 1e-06
-  global_grad_clip: 5.0
-  log_interval: 10
-  log_path: log/train_ernie_linear.log
-
-testing:
-  log_path: log/test_ernie_linear.log
@@ -1,23 +0,0 @@
-#! /usr/bin/env bash
-
-if [ $# != 2 ]; then
-    echo "usage: ${0} ckpt_dir avg_num"
-    exit -1
-fi
-
-ckpt_dir=${1}
-average_num=${2}
-decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams
-
-python3 -u ${BIN_DIR}/avg_model.py \
-    --dst_model ${decode_checkpoint} \
-    --ckpt_dir ${ckpt_dir} \
-    --num ${average_num} \
-    --val_best
-
-if [ $? -ne 0 ]; then
-    echo "Failed in avg ckpt!"
-    exit 1
-fi
-
-exit 0
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+text=$4
+ckpt_prefix=${ckpt_name%.*}
+
+python3 ${BIN_DIR}/punc_restore.py \
+    --config=${config_path} \
+    --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+    --text=${text}
@@ -1,26 +1,11 @@
 #!/bin/bash
 
-if [ $# != 2 ];then
-    echo "usage: ${0} config_path ckpt_path_prefix"
-    exit -1
-fi
-
-ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
-echo "using $ngpu gpus..."
-
 config_path=$1
-ckpt_prefix=$2
+train_output_path=$2
+ckpt_name=$3
+
+ckpt_prefix=${ckpt_name%.*}
 
-python3 -u ${BIN_DIR}/test.py \
-    --ngpu 1 \
-    --config ${config_path} \
-    --result_file ${ckpt_prefix}.rsl \
-    --checkpoint_path ${ckpt_prefix}
-
-if [ $? -ne 0 ]; then
-    echo "Failed in evaluation!"
-    exit 1
-fi
-
-exit 0
+python3 ${BIN_DIR}/test.py \
+    --config=${config_path} \
+    --checkpoint=${train_output_path}/checkpoints/${ckpt_name}
@@ -1,28 +1,9 @@
 #!/bin/bash
 
-if [ $# != 3 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name log_dir"
-    exit -1
-fi
-
-ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
-echo "using $ngpu gpus..."
-
 config_path=$1
-ckpt_name=$2
-log_dir=$3
-
-mkdir -p exp
-
-python3 -u ${BIN_DIR}/train.py \
-    --ngpu ${ngpu} \
-    --config ${config_path} \
-    --output_dir exp/${ckpt_name} \
-    --log_dir ${log_dir}
-
-if [ $? -ne 0 ]; then
-    echo "Failed in training!"
-    exit 1
-fi
-
-exit 0
+train_output_path=$2
+
+python3 ${BIN_DIR}/train.py \
+    --config=${config_path} \
+    --output-dir=${train_output_path} \
+    --ngpu=1
@@ -1,40 +1,35 @@
 #!/bin/bash
 set -e
+source path.sh
 
-if [ $# -ne 4 ]; then
-    echo "usage: bash ./run.sh stage gpu train_config avg_num"
-    echo "eg: bash ./run.sh 1 0 train_config 1"
-    exit -1
-fi
-
-stage=$1
+gpus=0,1
+stage=0
 stop_stage=100
-gpus=$2
-conf_path=$3
-avg_num=$4
-avg_ckpt=avg_${avg_num}
-ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
-log_dir=log
 
-source path.sh ${ckpt}
+conf_path=conf/default.yaml
+train_output_path=exp/default
+ckpt_name=snapshot_iter_12840.pdz
+text=今天的天气真不错啊你下午有空吗我想约你一起去吃饭
+
+# with the following command, you can choose the stage range you want to run
+# such as `./run.sh --stage 0 --stop-stage 0`
+# this can not be mixed use with `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
 
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     # prepare data
-    bash ./local/data.sh
+    ./local/data.sh
 fi
 
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-    # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${conf_path} ${ckpt} ${log_dir}
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train model, all `ckpt` under `train_output_path/checkpoints/` dir
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
 fi
 
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-    # avg n best model
-    bash ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
 fi
 
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-    # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=${gpus} bash ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/punc_restore.sh ${conf_path} ${train_output_path} ${ckpt_name} ${text} || exit -1
 fi
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -0,0 +1,110 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import re
+
+import paddle
+import yaml
+from paddlenlp.transformers import ErnieTokenizer
+from yacs.config import CfgNode
+
+from paddlespeech.text.models.ernie_linear import ErnieLinear
+
+DefinedClassifier = {
+    'ErnieLinear': ErnieLinear,
+}
+
+tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
+
+
+def _clean_text(text, punc_list):
+    text = text.lower()
+    text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
+    text = re.sub(f'[{"".join([p for p in punc_list][1:])}]', '', text)
+    return text
+
+
+def preprocess(text, punc_list):
+    clean_text = _clean_text(text, punc_list)
+    assert len(clean_text) > 0, f'Invalid input string: {text}'
+    tokenized_input = tokenizer(
+        list(clean_text), return_length=True, is_split_into_words=True)
+    _inputs = dict()
+    _inputs['input_ids'] = tokenized_input['input_ids']
+    _inputs['seg_ids'] = tokenized_input['token_type_ids']
+    _inputs['seq_len'] = tokenized_input['seq_len']
+    return _inputs
+
+
+def test(args):
+    with open(args.config) as f:
+        config = CfgNode(yaml.safe_load(f))
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+
+    punc_list = []
+    with open(config["data_params"]["punc_path"], 'r') as f:
+        for line in f:
+            punc_list.append(line.strip())
+
+    model = DefinedClassifier[config["model_type"]](**config["model"])
+    state_dict = paddle.load(args.checkpoint)
+    model.set_state_dict(state_dict["main_params"])
+    model.eval()
+    _inputs = preprocess(args.text, punc_list)
+    seq_len = _inputs['seq_len']
+    input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0)
+    seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0)
+    logits, _ = model(input_ids, seg_ids)
+    preds = paddle.argmax(logits, axis=-1).squeeze(0)
+    tokens = tokenizer.convert_ids_to_tokens(
+        _inputs['input_ids'][1:seq_len - 1])
+    labels = preds[1:seq_len - 1].tolist()
+    assert len(tokens) == len(labels)
+    # add 0 for non punc
+    punc_list = [0] + punc_list
+    text = ''
+    for t, l in zip(tokens, labels):
+        text += t
+        if l != 0:  # Non punc.
+            text += punc_list[l]
+    print("Punctuation Restoration Result:", text)
+    return text
+
+
+def main():
+    # parse args and config and redirect to train_sp
+    parser = argparse.ArgumentParser(description="Run Punctuation Restoration.")
+    parser.add_argument("--config", type=str, help="ErnieLinear config file.")
+    parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
+    parser.add_argument("--text", type=str, help="raw text to be restored.")
+    parser.add_argument(
+        "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
+
+    args = parser.parse_args()
+
+    if args.ngpu == 0:
+        paddle.set_device("cpu")
+    elif args.ngpu > 0:
+        paddle.set_device("gpu")
+    else:
+        print("ngpu should >= 0 !")
+
+    test(args)
+
+
+if __name__ == "__main__":
+    main()
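The restoration loop at the end of test() is the heart of this file: label 0 means "no punctuation after this token", and any non-zero label indexes into punc_list, which has a dummy 0 prepended so the model's label ids line up with the vocab entries. A minimal standalone sketch of that loop (not part of the commit; the three marks are hypothetical stand-ins for the punc_vocab contents):

    # toy reproduction of the restoration loop in test() above
    punc_list = ['，', '。', '？']  # as read from punc_vocab, one mark per line
    punc_list = [0] + punc_list     # index 0 = "no punctuation"

    tokens = ['你', '好', '吗']
    labels = [0, 0, 3]              # per-token model predictions

    text = ''
    for t, l in zip(tokens, labels):
        text += t
        if l != 0:                  # non-zero label appends the mark
            text += punc_list[l]

    print(text)  # 你好吗？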
@@ -0,0 +1,123 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import paddle
+from paddle import distributed as dist
+from paddle.io import DataLoader
+from paddle.nn import Layer
+from paddle.optimizer import Optimizer
+from paddle.optimizer.lr import LRScheduler
+from sklearn.metrics import f1_score
+
+from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
+from paddlespeech.t2s.training.reporter import report
+from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
+logging.basicConfig(
+    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
+    datefmt='[%Y-%m-%d %H:%M:%S]')
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class ErnieLinearUpdater(StandardUpdater):
+    def __init__(self,
+                 model: Layer,
+                 criterion: Layer,
+                 scheduler: LRScheduler,
+                 optimizer: Optimizer,
+                 dataloader: DataLoader,
+                 output_dir=None):
+        super().__init__(model, optimizer, dataloader, init_state=None)
+        self.model = model
+        self.dataloader = dataloader
+
+        self.criterion = criterion
+        self.scheduler = scheduler
+        self.optimizer = optimizer
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
+    def update_core(self, batch):
+        self.msg = "Rank: {}, ".format(dist.get_rank())
+        losses_dict = {}
+
+        input, label = batch
+        label = paddle.reshape(label, shape=[-1])
+        y, logit = self.model(input)
+        pred = paddle.argmax(logit, axis=1)
+
+        loss = self.criterion(y, label)
+
+        self.optimizer.clear_grad()
+        loss.backward()
+
+        self.optimizer.step()
+        self.scheduler.step()
+
+        F1_score = f1_score(
+            label.numpy().tolist(), pred.numpy().tolist(), average="macro")
+
+        report("train/loss", float(loss))
+        losses_dict["loss"] = float(loss)
+        report("train/F1_score", float(F1_score))
+        losses_dict["F1_score"] = float(F1_score)
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+
+
+class ErnieLinearEvaluator(StandardEvaluator):
+    def __init__(self,
+                 model: Layer,
+                 criterion: Layer,
+                 dataloader: DataLoader,
+                 output_dir=None):
+        super().__init__(model, dataloader)
+        self.model = model
+        self.criterion = criterion
+        self.dataloader = dataloader
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
+    def evaluate_core(self, batch):
+        self.msg = "Evaluate: "
+        losses_dict = {}
+
+        input, label = batch
+        label = paddle.reshape(label, shape=[-1])
+        y, logit = self.model(input)
+        pred = paddle.argmax(logit, axis=1)
+
+        loss = self.criterion(y, label)
+
+        F1_score = f1_score(
+            label.numpy().tolist(), pred.numpy().tolist(), average="macro")
+
+        report("eval/loss", float(loss))
+        losses_dict["loss"] = float(loss)
+        report("eval/F1_score", float(F1_score))
+        losses_dict["F1_score"] = float(F1_score)
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+        self.logger.info(self.msg)
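Both update_core and evaluate_core flatten the per-token labels and report macro-averaged F1 via scikit-learn. A small self-contained sketch (not part of the commit; toy labels) of what average="macro" does: it takes the unweighted mean of per-class F1, so the rare punctuation classes count as much as the dominant no-punctuation class 0:

    from sklearn.metrics import f1_score

    # toy per-token labels: 0 = no punc, 1-3 = punctuation classes,
    # mimicking the flattened label/pred lists in update_core
    label = [0, 0, 1, 2, 0, 3]
    pred = [0, 0, 1, 0, 0, 3]

    # unweighted mean of per-class F1 over classes {0, 1, 2, 3}
    print(f1_score(label, pred, average="macro"))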
@@ -1,524 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-import time
-from collections import defaultdict
-from pathlib import Path
-
-import numpy as np
-import paddle
-import paddle.nn as nn
-import pandas as pd
-from paddle import distributed as dist
-from paddle.io import DataLoader
-from sklearn.metrics import classification_report
-from sklearn.metrics import f1_score
-from sklearn.metrics import precision_recall_fscore_support
-
-from ...s2t.utils import layer_tools
-from ...s2t.utils import mp_tools
-from ...s2t.utils.checkpoint import Checkpoint
-from ...text.models import ErnieLinear
-from ...text.models.ernie_linear.dataset import PuncDataset
-from ...text.models.ernie_linear.dataset import PuncDatasetFromErnieTokenizer
-
-__all__ = ["Trainer", "Tester"]
-
-DefinedClassifier = {
-    'ErnieLinear': ErnieLinear,
-}
-
-DefinedLoss = {
-    "ce": nn.CrossEntropyLoss,
-}
-
-DefinedDataset = {
-    'Punc': PuncDataset,
-    'Ernie': PuncDatasetFromErnieTokenizer,
-}
-
-
-class Trainer():
-    def __init__(self, config, args):
-        self.config = config
-        self.args = args
-        self.optimizer = None
-        self.output_dir = None
-        self.log_dir = None
-        self.checkpoint_dir = None
-        self.iteration = 0
-        self.epoch = 0
-
-    def setup(self):
-        """Setup the experiment.
-        """
-        self.setup_log_dir()
-        self.setup_logger()
-        if self.args.ngpu > 0:
-            paddle.set_device('gpu')
-        else:
-            paddle.set_device('cpu')
-        if self.parallel:
-            self.init_parallel()
-
-        self.setup_output_dir()
-        self.dump_config()
-        self.setup_checkpointer()
-
-        self.setup_model()
-
-        self.setup_dataloader()
-
-        self.iteration = 0
-        self.epoch = 1
-
-    @property
-    def parallel(self):
-        """A flag indicating whether the experiment should run with
-        multiprocessing.
-        """
-        return self.args.ngpu > 1
-
-    def init_parallel(self):
-        """Init environment for multiprocess training.
-        """
-        dist.init_parallel_env()
-
-    @mp_tools.rank_zero_only
-    def save(self, tag=None, infos: dict=None):
-        """Save checkpoint (model parameters and optimizer states).
-
-        Args:
-            tag (int or str, optional): None for step, else using tag, e.g epoch. Defaults to None.
-            infos (dict, optional): meta data to save. Defaults to None.
-        """
-
-        infos = infos if infos else dict()
-        infos.update({
-            "step": self.iteration,
-            "epoch": self.epoch,
-            "lr": self.optimizer.get_lr()
-        })
-        self.checkpointer.save_parameters(self.checkpoint_dir, self.iteration
-                                          if tag is None else tag, self.model,
-                                          self.optimizer, infos)
-
-    def resume_or_scratch(self):
-        """Resume from latest checkpoint at checkpoints in the output
-        directory or load a specified checkpoint.
-
-        If ``args.checkpoint_path`` is not None, load the checkpoint, else
-        resume training.
-        """
-        scratch = None
-        infos = self.checkpointer.load_parameters(
-            self.model,
-            self.optimizer,
-            checkpoint_dir=self.checkpoint_dir,
-            checkpoint_path=self.args.checkpoint_path)
-        if infos:
-            # restore from ckpt
-            self.iteration = infos["step"]
-            self.epoch = infos["epoch"]
-            scratch = False
-        else:
-            self.iteration = 0
-            self.epoch = 0
-            scratch = True
-
-        return scratch
-
-    def new_epoch(self):
-        """Reset the train loader seed and increment `epoch`.
-        """
-        self.epoch += 1
-        if self.parallel:
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
-
-    def train(self):
-        """The training process control by epoch."""
-        from_scratch = self.resume_or_scratch()
-
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag="init")
-
-        self.lr_scheduler.step(self.iteration)
-        if self.parallel:
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
-
-        self.logger.info(
-            f"Train Total Examples: {len(self.train_loader.dataset)}")
-        self.punc_list = []
-        for i in range(len(self.train_loader.dataset.id2punc)):
-            self.punc_list.append(self.train_loader.dataset.id2punc[i])
-        while self.epoch < self.config["training"]["n_epoch"]:
-            self.model.train()
-            self.total_label_train = []
-            self.total_predict_train = []
-            try:
-                data_start_time = time.time()
-                for batch_index, batch in enumerate(self.train_loader):
-                    dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
-                    data_start_time = time.time()
-                # t = classification_report(
-                #     self.total_label_train,
-                #     self.total_predict_train,
-                #     target_names=self.punc_list)
-                # self.logger.info(t)
-            except Exception as e:
-                self.logger.error(e)
-                raise e
-
-            total_loss, F1_score = self.valid()
-            self.logger.info("Epoch {} Val info val_loss {}, F1_score {}".
-                             format(self.epoch, total_loss, F1_score))
-
-            self.save(
-                tag=self.epoch, infos={"val_loss": total_loss,
-                                       "F1": F1_score})
-            # step lr every epoch
-            self.lr_scheduler.step()
-            self.new_epoch()
-
-    def run(self):
-        """The routine of the experiment after setup. This method is intended
-        to be used by the user.
-        """
-        try:
-            self.train()
-        except KeyboardInterrupt:
-            self.logger.info("Training was aborted by keyboard interrupt.")
-            self.save()
-            exit(-1)
-        finally:
-            self.destory()
-            self.logger.info("Training Done.")
-
-    def setup_output_dir(self):
-        """Create a directory used for output.
-        """
-        # output dir
-        output_dir = Path(self.args.output_dir).expanduser()
-        output_dir.mkdir(parents=True, exist_ok=True)
-
-        self.output_dir = output_dir
-
-    def setup_log_dir(self):
-        """Create a directory used for logging.
-        """
-        # log dir
-        log_dir = Path(self.args.log_dir).expanduser()
-        log_dir.mkdir(parents=True, exist_ok=True)
-
-        self.log_dir = log_dir
-
-    def setup_checkpointer(self):
-        """Create a directory used to save checkpoints into.
-
-        It is "checkpoints" inside the output directory.
-        """
-        # checkpoint dir
-        self.checkpointer = Checkpoint(self.config["checkpoint"]["kbest_n"],
-                                       self.config["checkpoint"]["latest_n"])
-
-        checkpoint_dir = self.output_dir / "checkpoints"
-        checkpoint_dir.mkdir(exist_ok=True)
-
-        self.checkpoint_dir = checkpoint_dir
-
-    def setup_logger(self):
-        LOG_FORMAT = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
-        format_str = logging.Formatter(
-            '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
-        )
-        logging.basicConfig(
-            filename=self.config["training"]["log_path"],
-            level=logging.INFO,
-            format=LOG_FORMAT)
-        self.logger = logging.getLogger(__name__)
-
-        self.logger.setLevel(logging.INFO)
-        sh = logging.StreamHandler()
-        sh.setFormatter(format_str)
-        self.logger.addHandler(sh)
-
-        self.logger.info('info')
-
-    @mp_tools.rank_zero_only
-    def destory(self):
-        pass
-
-    @mp_tools.rank_zero_only
-    def dump_config(self):
-        """Save the configuration used for this experiment.
-
-        It is saved in to ``config.yaml`` in the output directory at the
-        beginning of the experiment.
-        """
-        with open(self.output_dir / "config.yaml", "wt") as f:
-            print(self.config, file=f)
-
-    def train_batch(self, batch_index, batch_data, msg):
-        start = time.time()
-
-        input, label = batch_data
-        label = paddle.reshape(label, shape=[-1])
-        y, logit = self.model(input)
-        pred = paddle.argmax(logit, axis=1)
-        self.total_label_train.extend(label.numpy().tolist())
-        self.total_predict_train.extend(pred.numpy().tolist())
-        loss = self.crit(y, label)
-
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
-        self.optimizer.step()
-        self.optimizer.clear_grad()
-        iteration_time = time.time() - start
-
-        losses_np = {
-            "train_loss": float(loss),
-        }
-        msg += "train time: {:>.3f}s, ".format(iteration_time)
-        msg += "batch size: {}, ".format(self.config["data"]["batch_size"])
-        msg += ", ".join("{}: {:>.6f}".format(k, v)
-                         for k, v in losses_np.items())
-        self.logger.info(msg)
-        self.iteration += 1
-
-    @paddle.no_grad()
-    def valid(self):
-        self.logger.info(
-            f"Valid Total Examples: {len(self.valid_loader.dataset)}")
-        self.model.eval()
-        valid_losses = defaultdict(list)
-        num_seen_utts = 1
-        total_loss = 0.0
-        valid_total_label = []
-        valid_total_predict = []
-        for i, batch in enumerate(self.valid_loader):
-            input, label = batch
-            label = paddle.reshape(label, shape=[-1])
-            y, logit = self.model(input)
-            pred = paddle.argmax(logit, axis=1)
-            valid_total_label.extend(label.numpy().tolist())
-            valid_total_predict.extend(pred.numpy().tolist())
-            loss = self.crit(y, label)
-
-            if paddle.isfinite(loss):
-                num_utts = batch[1].shape[0]
-                num_seen_utts += num_utts
-                total_loss += float(loss) * num_utts
-                valid_losses["val_loss"].append(float(loss))
-
-            if (i + 1) % self.config["training"]["log_interval"] == 0:
-                valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
-                valid_dump["val_history_loss"] = total_loss / num_seen_utts
-
-                # logging
-                msg = f"Valid: Rank: {dist.get_rank()}, "
-                msg += "epoch: {}, ".format(self.epoch)
-                msg += "step: {}, ".format(self.iteration)
-                msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
-                msg += ", ".join("{}: {:>.6f}".format(k, v)
-                                 for k, v in valid_dump.items())
-                self.logger.info(msg)
-
-        self.logger.info("Rank {} Val info val_loss {}".format(
-            dist.get_rank(), total_loss / num_seen_utts))
-        F1_score = f1_score(
-            valid_total_label, valid_total_predict, average="macro")
-        return total_loss / num_seen_utts, F1_score
-
-    def setup_model(self):
-        config = self.config
-
-        model = DefinedClassifier[self.config["model_type"]](
-            **self.config["model_params"])
-        self.crit = DefinedLoss[self.config["loss_type"]](**self.config[
-            "loss"]) if "loss_type" in self.config else DefinedLoss["ce"]()
-
-        if self.parallel:
-            model = paddle.DataParallel(model)
-
-        # self.logger.info(f"{model}")
-        # layer_tools.print_params(model, self.logger.info)
-
-        lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
-            learning_rate=config["training"]["lr"],
-            gamma=config["training"]["lr_decay"],
-            verbose=True)
-        optimizer = paddle.optimizer.Adam(
-            learning_rate=lr_scheduler,
-            parameters=model.parameters(),
-            weight_decay=paddle.regularizer.L2Decay(
-                config["training"]["weight_decay"]))
-
-        self.model = model
-        self.optimizer = optimizer
-        self.lr_scheduler = lr_scheduler
-        self.logger.info("Setup model/criterion/optimizer/lr_scheduler!")
-
-    def setup_dataloader(self):
-        config = self.config["data"].copy()
-        train_dataset = DefinedDataset[config["dataset_type"]](
-            train_path=config["train_path"], **config["data_params"])
-        dev_dataset = DefinedDataset[config["dataset_type"]](
-            train_path=config["dev_path"], **config["data_params"])
-
-        self.train_loader = DataLoader(
-            train_dataset,
-            num_workers=config["num_workers"],
-            batch_size=config["batch_size"])
-        self.valid_loader = DataLoader(
-            dev_dataset,
-            batch_size=config["batch_size"],
-            shuffle=False,
-            drop_last=False,
-            num_workers=config["num_workers"])
-        self.logger.info("Setup train/valid Dataloader!")
-
-
-class Tester(Trainer):
-    def __init__(self, config, args):
-        super().__init__(config, args)
-
-    @mp_tools.rank_zero_only
-    @paddle.no_grad()
-    def test(self):
-        self.logger.info(
-            f"Test Total Examples: {len(self.test_loader.dataset)}")
-        self.punc_list = []
-        for i in range(len(self.test_loader.dataset.id2punc)):
-            self.punc_list.append(self.test_loader.dataset.id2punc[i])
-        self.model.eval()
-        test_total_label = []
-        test_total_predict = []
-        with open(self.args.result_file, 'w') as fout:
-            for i, batch in enumerate(self.test_loader):
-                input, label = batch
-                label = paddle.reshape(label, shape=[-1])
-                y, logit = self.model(input)
-                pred = paddle.argmax(logit, axis=1)
-                test_total_label.extend(label.numpy().tolist())
-                test_total_predict.extend(pred.numpy().tolist())
-
-                # logging
-                msg = "Test: "
-                msg += "epoch: {}, ".format(self.epoch)
-                msg += "step: {}, ".format(self.iteration)
-                self.logger.info(msg)
-        t = classification_report(
-            test_total_label, test_total_predict, target_names=self.punc_list)
-        print(t)
-        t2 = self.evaluation(test_total_label, test_total_predict)
-        print(t2)
-
-    def evaluation(self, y_pred, y_test):
-        precision, recall, f1, _ = precision_recall_fscore_support(
-            y_test, y_pred, average=None, labels=[1, 2, 3])
-        overall = precision_recall_fscore_support(
-            y_test, y_pred, average='macro', labels=[1, 2, 3])
-        result = pd.DataFrame(
-            np.array([precision, recall, f1]),
-            columns=list(['O', 'COMMA', 'PERIOD', 'QUESTION'])[1:],
-            index=['Precision', 'Recall', 'F1'])
-        result['OVERALL'] = overall[:3]
-        return result
-
-    def run_test(self):
-        self.resume_or_scratch()
-        try:
-            self.test()
-        except KeyboardInterrupt:
-            self.logger.info("Testing was aborted by keyboard interrupt.")
-            exit(-1)
-
-    def setup(self):
-        """Setup the experiment.
-        """
-        if self.args.ngpu > 0:
-            paddle.set_device('gpu')
-        else:
-            paddle.set_device('cpu')
-        self.setup_logger()
-        self.setup_output_dir()
-        self.setup_checkpointer()
-
-        self.setup_dataloader()
-        self.setup_model()
-
-        self.iteration = 0
-        self.epoch = 0
-
-    def setup_model(self):
-        config = self.config
-        model = DefinedClassifier[self.config["model_type"]](
-            **self.config["model_params"])
-
-        self.model = model
-        self.logger.info("Setup model!")
-
-    def setup_dataloader(self):
-        config = self.config["data"].copy()
-
-        test_dataset = DefinedDataset[config["dataset_type"]](
-            train_path=config["test_path"], **config["data_params"])
-
-        self.test_loader = DataLoader(
-            test_dataset,
-            batch_size=config["batch_size"],
-            shuffle=False,
-            drop_last=False)
-        self.logger.info("Setup test Dataloader!")
-
-    def setup_output_dir(self):
-        """Create a directory used for output.
-        """
-        # output dir
-        if self.args.output_dir:
-            output_dir = Path(self.args.output_dir).expanduser()
-            output_dir.mkdir(parents=True, exist_ok=True)
-        else:
-            output_dir = Path(
-                self.args.checkpoint_path).expanduser().parent.parent
-            output_dir.mkdir(parents=True, exist_ok=True)
-
-        self.output_dir = output_dir
-
-    def setup_logger(self):
-        LOG_FORMAT = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
-        format_str = logging.Formatter(
-            '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
-        )
-        logging.basicConfig(
-            filename=self.config["testing"]["log_path"],
-            level=logging.INFO,
-            format=LOG_FORMAT)
-        self.logger = logging.getLogger(__name__)
-
-        self.logger.setLevel(logging.INFO)
-        sh = logging.StreamHandler()
-        sh.setFormatter(format_str)
-        self.logger.addHandler(sh)
-
-        self.logger.info('info')
@@ -1,73 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import argparse
-
-
-def default_argument_parser():
-    r"""A simple yet general argument parser for experiments with t2s.
-
-    This is used in examples with t2s. And it is intended to be used by
-    other experiments with t2s. It requires a minimal set of command line
-    arguments to start a training script.
-
-    The ``--config`` and ``--opts`` are used for overwrite the default
-    configuration.
-
-    The ``--data`` and ``--output`` specifies the data path and output path.
-    Resuming training from existing progress at the output directory is the
-    intended default behavior.
-
-    The ``--checkpoint_path`` specifies the checkpoint to load from.
-
-    The ``--ngpu`` specifies how to run the training.
-
-
-    See Also
-    --------
-    paddlespeech.t2s.training.experiment
-    Returns
-    -------
-    argparse.ArgumentParser
-        the parser
-    """
-    parser = argparse.ArgumentParser()
-
-    # yapf: disable
-    # data and output
-    parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
-    parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.")
-    # parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
-    parser.add_argument("--output_dir", metavar="OUTPUT_DIR", help="path to save checkpoint.")
-    parser.add_argument("--log_dir", metavar="LOG_DIR", help="path to save logs.")
-
-    # load from saved checkpoint
-    parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
-
-    # save jit model to
-    parser.add_argument("--export_path", type=str, help="path of the jit model to save")
-
-    # save asr result to
-    parser.add_argument("--result_file", type=str, help="path of save the asr result")
-
-    # running
-    parser.add_argument("--ngpu", type=int, default=1, help="number of parallel processes to use. if ngpu=0, using cpu.")
-
-    # overwrite extra config and default config
-    # parser.add_argument("--opts", nargs=argparse.REMAINDER,
-    #                     help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
-    parser.add_argument("--opts", type=str, default=[], nargs='+',
-                        help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
-    # yapf: enable
-
-    return parser