diff --git a/README.md b/README.md
index 3970f79b..50dac64c 100644
--- a/README.md
+++ b/README.md
@@ -424,6 +424,30 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
+**Punctuation Restoration**
+
+
+
+
+ Task |
+ Dataset |
+ Model Type |
+ Link |
+
+
+
+
+
+ Punctuation Restoration |
+ IWLST2012_zh |
+ Ernie Linear |
+
+ iwslt2012-punc0
+ |
+
+
+
+
## Documents
Normally, [Speech SoTA](https://paperswithcode.com/area/speech), [Audio SoTA](https://paperswithcode.com/area/audio) and [Music SoTA](https://paperswithcode.com/area/music) give you an overview of the hot academic topics in the related area. To focus on the tasks in PaddleSpeech, you will find the following guidelines are helpful to grasp the core ideas.
diff --git a/README_cn.md b/README_cn.md
index b47e9e61..14167864 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -415,6 +415,30 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
+**标点恢复**
+
+
+
+
+ 任务 |
+ 数据集 |
+ 模型种类 |
+ 链接 |
+
+
+
+
+
+ 标点恢复 |
+ IWLST2012_zh |
+ Ernie Linear |
+
+ iwslt2012-punc0
+ |
+
+
+
+
## 教程文档
对于 PaddleSpeech 的所关注的任务,以下指南有助于帮助开发者快速入门,了解语音相关核心思想。
diff --git a/docs/source/released_model.md b/docs/source/released_model.md
index 9db6a4c4..a10b2674 100644
--- a/docs/source/released_model.md
+++ b/docs/source/released_model.md
@@ -1,11 +1,10 @@
-
# Released Models
## Speech-to-Text Models
### Speech Recognition Model
Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech | Example Link
-:-------------:| :------------:| :-----: | -----: | :----------------- |:--------- | :---------- | :--------- | :-----------
+:-------------:| :------------:| :-----: | -----: | :-----: |:-----:| :-----: | :-----: | :-----:
[Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/aishell_ds2_online_cer8.00_release.tar.gz) | Aishell Dataset | Char-based | 345 MB | 2 Conv + 5 LSTM layers with only forward direction | 0.080 |-| 151 h | [D2 Online Aishell ASR0](../../examples/aishell/asr0)
[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/ds2.model.tar.gz)| Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers| 0.064 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0)
[Conformer Online Aishell ASR1 Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.chunk.release.tar.gz) | Aishell Dataset | Char-based | 283 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0594 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1)
@@ -17,22 +16,21 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
### Language Model based on NGram
Language Model | Training Data | Token-based | Size | Descriptions
-:-------------:| :------------:| :-----: | -----: | :-----------------
+:------------:| :------------:|:------------: | :------------: | :------------:
[English LM](https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm) | [CommonCrawl(en.00)](http://web-language-models.s3-website-us-east-1.amazonaws.com/ngrams/en/deduped/en.00.deduped.xz) | Word-based | 8.3 GB | Pruned with 0 1 1 1 1;
About 1.85 billion n-grams;
'trie' binary with '-a 22 -q 8 -b 8'
[Mandarin LM Small](https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm) | Baidu Internal Corpus | Char-based | 2.8 GB | Pruned with 0 1 2 4 4;
About 0.13 billion n-grams;
'probing' binary with default settings
[Mandarin LM Large](https://deepspeech.bj.bcebos.com/zh_lm/zhidao_giga.klm) | Baidu Internal Corpus | Char-based | 70.4 GB | No Pruning;
About 3.7 billion n-grams;
'probing' binary with default settings
### Speech Translation Models
-| Model | Training Data | Token-based | Size | Descriptions | BLEU | Example Link |
-| ------------------------------------------------------------ | ------------- | ----------- | ---- | ------------------------------------------------------------ | ----- | ------------------------------------------------------------ |
-| [Transformer FAT-ST MTL En-Zh](https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/fat_st_ted-en-zh.tar.gz) | Ted-En-Zh | Spm | | Encoder:Transformer, Decoder:Transformer,
Decoding method: Attention | 20.80 | [Transformer Ted-En-Zh ST1](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/examples/ted_en_zh/st1) |
-
+| Model | Training Data | Token-based | Size | Descriptions | BLEU | Example Link |
+| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: |
+| [Transformer FAT-ST MTL En-Zh](https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/fat_st_ted-en-zh.tar.gz) | Ted-En-Zh| Spm| | Encoder:Transformer, Decoder:Transformer,
Decoding method: Attention | 20.80 | [Transformer Ted-En-Zh ST1](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/examples/ted_en_zh/st1) |
## Text-to-Speech Models
### Acoustic Models
-Model Type | Dataset| Example Link | Pretrained Models|Static Models|Siize(static)
+Model Type | Dataset| Example Link | Pretrained Models|Static Models|Size (static)
:-------------:| :------------:| :-----: | :-----:| :-----:| :-----:
Tacotron2|LJSpeech|[tacotron2-vctk](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts0)|[tacotron2_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.3.zip)|||
TransformerTTS| LJSpeech| [transformer-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts1)|[transformer_tts_ljspeech_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/transformer_tts/transformer_tts_ljspeech_ckpt_0.4.zip)|||
@@ -44,8 +42,8 @@ FastSpeech2| LJSpeech |[fastspeech2-ljspeech](https://github.com/PaddlePaddle/Pa
FastSpeech2| VCTK |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/tts3)|[fastspeech2_nosil_vctk_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip)|||
### Vocoders
-Model Type | Dataset| Example Link | Pretrained Models| Static Models|Size(static)
-:-------------:| :------------:| :-----: | :-----:| :-----:| :-----:
+Model Type | Dataset| Example Link | Pretrained Models| Static Models|Size (static)
+:-----:| :-----:| :-----: | :-----:| :-----:| :-----:
WaveFlow| LJSpeech |[waveflow-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc0)|[waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/waveflow/waveflow_ljspeech_ckpt_0.3.zip)|||
Parallel WaveGAN| CSMSC |[PWGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc1)|[pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip)|[pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip)|5.1MB|
Parallel WaveGAN| LJSpeech |[PWGAN-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc1)|[pwg_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip)|||
@@ -69,10 +67,15 @@ Model Type | Dataset| Example Link | Pretrained Models
PANN | Audioset| [audioset_tagging_cnn](https://github.com/qiuqiangkong/audioset_tagging_cnn) | [panns_cnn6.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn6.pdparams), [panns_cnn10.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn10.pdparams), [panns_cnn14.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn14.pdparams)
PANN | ESC-50 |[pann-esc50]("./examples/esc50/cls0")|[esc50_cnn6.tar.gz](https://paddlespeech.bj.bcebos.com/cls/esc50/esc50_cnn6.tar.gz), [esc50_cnn10.tar.gz](https://paddlespeech.bj.bcebos.com/cls/esc50/esc50_cnn10.tar.gz), [esc50_cnn14.tar.gz](https://paddlespeech.bj.bcebos.com/cls/esc50/esc50_cnn14.tar.gz)
+## Punctuation Restoration Models
+Model Type | Dataset| Example Link | Pretrained Models
+:-------------:| :------------:| :-----: | :-----:
+Ernie Linear | IWLST2012_zh |[iwslt2012_punc0](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/iwslt2012/punc0)|[ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip)
+
## Speech Recognition Model from paddle 1.8
-| Acoustic Model |Training Data| Token-based | Size | Descriptions | CER | WER | Hours of speech |
-| :--------------: | :--------------: | :--------------: | :--------------: | :--------------: | :--------------: | :--------------: | :--------------: |
+| Acoustic Model |Training Data| Token-based | Size | Descriptions | CER | WER | Hours of speech |
+| :-----:| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: |
| [Ds2 Offline Aishell model](https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz) | Aishell Dataset | Char-based | 234 MB | 2 Conv + 3 bidirectional GRU layers | 0.0804 | — | 151 h |
-| [Ds2 Offline Librispeech model](https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz) | Librispeech Dataset | Word-based | 307 MB | 2 Conv + 3 bidirectional sharing weight RNN layers | — | 0.0685 | 960 h |
-| [Ds2 Offline Baidu en8k model](https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz) | Baidu Internal English Dataset | Word-based | 273 MB | 2 Conv + 3 bidirectional GRU layers |— | 0.0541 | 8628 h |
+| [Ds2 Offline Librispeech model](https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz) | Librispeech Dataset | Word-based | 307 MB | 2 Conv + 3 bidirectional sharing weight RNN layers | — | 0.0685 | 960 h |
+| [Ds2 Offline Baidu en8k model](https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz) | Baidu Internal English Dataset | Word-based | 273 MB | 2 Conv + 3 bidirectional GRU layers |— | 0.0541 | 8628 h|
diff --git a/examples/iwslt2012/punc0/README.md b/examples/iwslt2012/punc0/README.md
index 38ef36fe..15ccea85 100644
--- a/examples/iwslt2012/punc0/README.md
+++ b/examples/iwslt2012/punc0/README.md
@@ -1,17 +1,28 @@
-# 中文实验例程
-## 测试数据:
-- IWLST2012中文:test2012
+# Punctuation Restoration with IWLST2012
+## Get Started
+### Data Preprocessing
+```bash
+./run.sh --stage 0 --stop-stage 0
+```
+### Model Training
+```bash
+./run.sh --stage 1 --stop-stage 1
+```
+### Testing
+```bash
+./run.sh --stage 2 --stop-stage 2
+```
+### Punctuation Restoration
+```bash
+./run.sh --stage 3 --stop-stage 3
+```
+## Pretrained Model
+The pretrained model can be downloaded here [ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip).
-## 运行代码
-- 运行 `./run.sh 0 0 conf/ernie_linear.yaml 1`
-
-## 实验结果:
-- ErnieLinear
- - 实验配置:conf/ernie_linear.yaml
- - 测试结果
-
- | | COMMA | PERIOD | QUESTION | OVERALL |
- |-----------|-----------|-----------|-----------|--------- |
- |Precision | 0.471831 | 0.497679 | 0.830189 | 0.599899 |
- |Recall | 0.583172 | 0.641148 | 0.846154 | 0.690158 |
- |F1 | 0.521626 | 0.560376 | 0.838095 | 0.640033 |
+### Test Result
+- Ernie Linear
+ | |COMMA | PERIOD | QUESTION | OVERALL|
+ |:-----:|:-----:|:-----:|:-----:|:-----:|
+ |Precision |0.510955 |0.526462 |0.820755 |0.619391|
+ |Recall |0.517433 |0.564179 |0.861386 |0.647666|
+ |F1 |0.514173 |0.544669 |0.840580 |0.633141|
diff --git a/examples/iwslt2012/punc0/conf/default.yaml b/examples/iwslt2012/punc0/conf/default.yaml
new file mode 100644
index 00000000..74ced993
--- /dev/null
+++ b/examples/iwslt2012/punc0/conf/default.yaml
@@ -0,0 +1,44 @@
+###########################################################
+# DATA SETTING #
+###########################################################
+dataset_type: Ernie
+train_path: data/iwslt2012_zh/train.txt
+dev_path: data/iwslt2012_zh/dev.txt
+test_path: data/iwslt2012_zh/test.txt
+batch_size: 64
+num_workers: 2
+data_params:
+ pretrained_token: ernie-1.0
+ punc_path: data/iwslt2012_zh/punc_vocab
+ seq_len: 100
+
+
+###########################################################
+# MODEL SETTING #
+###########################################################
+model_type: ErnieLinear
+model:
+ pretrained_token: ernie-1.0
+ num_classes: 4
+
+###########################################################
+# OPTIMIZER SETTING #
+###########################################################
+optimizer_params:
+ weight_decay: 1.0e-6 # weight decay coefficient.
+
+scheduler_params:
+ learning_rate: 1.0e-5 # learning rate.
+ gamma: 1.0 # scheduler gamma.
+
+###########################################################
+# TRAINING SETTING #
+###########################################################
+max_epoch: 20
+num_snapshots: 5
+
+###########################################################
+# OTHER SETTING #
+###########################################################
+num_snapshots: 10 # max number of snapshots to keep while training
+seed: 42 # random seed for paddle, random, and np.random
diff --git a/examples/iwslt2012/punc0/conf/ernie_linear.yaml b/examples/iwslt2012/punc0/conf/ernie_linear.yaml
deleted file mode 100644
index bf892110..00000000
--- a/examples/iwslt2012/punc0/conf/ernie_linear.yaml
+++ /dev/null
@@ -1,36 +0,0 @@
-data:
- dataset_type: Ernie
- train_path: data/iwslt2012_zh/train.txt
- dev_path: data/iwslt2012_zh/dev.txt
- test_path: data/iwslt2012_zh/test.txt
- data_params:
- pretrained_token: ernie-1.0
- punc_path: data/iwslt2012_zh/punc_vocab
- seq_len: 100
- batch_size: 64
- sortagrad: True
- shuffle_method: batch_shuffle
- num_workers: 0
-
-checkpoint:
- kbest_n: 5
- latest_n: 10
- metric_type: F1
-
-model_type: ErnieLinear
-
-model_params:
- pretrained_token: ernie-1.0
- num_classes: 4
-
-training:
- n_epoch: 20
- lr: !!float 1e-5
- lr_decay: 1.0
- weight_decay: !!float 1e-06
- global_grad_clip: 5.0
- log_interval: 10
- log_path: log/train_ernie_linear.log
-
-testing:
- log_path: log/test_ernie_linear.log
diff --git a/examples/iwslt2012/punc0/local/avg.sh b/examples/iwslt2012/punc0/local/avg.sh
deleted file mode 100644
index b8c14c66..00000000
--- a/examples/iwslt2012/punc0/local/avg.sh
+++ /dev/null
@@ -1,23 +0,0 @@
-#! /usr/bin/env bash
-
-if [ $# != 2 ]; then
- echo "usage: ${0} ckpt_dir avg_num"
- exit -1
-fi
-
-ckpt_dir=${1}
-average_num=${2}
-decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams
-
-python3 -u ${BIN_DIR}/avg_model.py \
---dst_model ${decode_checkpoint} \
---ckpt_dir ${ckpt_dir} \
---num ${average_num} \
---val_best
-
-if [ $? -ne 0 ]; then
- echo "Failed in avg ckpt!"
- exit 1
-fi
-
-exit 0
\ No newline at end of file
diff --git a/examples/iwslt2012/punc0/local/data.sh b/examples/iwslt2012/punc0/local/data.sh
old mode 100644
new mode 100755
diff --git a/examples/iwslt2012/punc0/local/punc_restore.sh b/examples/iwslt2012/punc0/local/punc_restore.sh
new file mode 100755
index 00000000..30a4f12f
--- /dev/null
+++ b/examples/iwslt2012/punc0/local/punc_restore.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+text=$4
+ckpt_prefix=${ckpt_name%.*}
+
+python3 ${BIN_DIR}/punc_restore.py \
+ --config=${config_path} \
+ --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --text=${text}
diff --git a/examples/iwslt2012/punc0/local/test.sh b/examples/iwslt2012/punc0/local/test.sh
old mode 100644
new mode 100755
index ee022462..94e508b5
--- a/examples/iwslt2012/punc0/local/test.sh
+++ b/examples/iwslt2012/punc0/local/test.sh
@@ -1,26 +1,11 @@
-
#!/bin/bash
-if [ $# != 2 ];then
- echo "usage: ${0} config_path ckpt_path_prefix"
- exit -1
-fi
-
-ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
-echo "using $ngpu gpus..."
-
config_path=$1
-ckpt_prefix=$2
-
-python3 -u ${BIN_DIR}/test.py \
---ngpu 1 \
---config ${config_path} \
---result_file ${ckpt_prefix}.rsl \
---checkpoint_path ${ckpt_prefix}
+train_output_path=$2
+ckpt_name=$3
-if [ $? -ne 0 ]; then
- echo "Failed in evaluation!"
- exit 1
-fi
+ckpt_prefix=${ckpt_name%.*}
-exit 0
+python3 ${BIN_DIR}/test.py \
+ --config=${config_path} \
+ --checkpoint=${train_output_path}/checkpoints/${ckpt_name}
diff --git a/examples/iwslt2012/punc0/local/train.sh b/examples/iwslt2012/punc0/local/train.sh
old mode 100644
new mode 100755
index 9fabb8f7..85227eac
--- a/examples/iwslt2012/punc0/local/train.sh
+++ b/examples/iwslt2012/punc0/local/train.sh
@@ -1,28 +1,9 @@
#!/bin/bash
-if [ $# != 3 ];then
- echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name log_dir"
- exit -1
-fi
-
-ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
-echo "using $ngpu gpus..."
-
config_path=$1
-ckpt_name=$2
-log_dir=$3
-
-mkdir -p exp
-
-python3 -u ${BIN_DIR}/train.py \
---ngpu ${ngpu} \
---config ${config_path} \
---output_dir exp/${ckpt_name} \
---log_dir ${log_dir}
-
-if [ $? -ne 0 ]; then
- echo "Failed in training!"
- exit 1
-fi
+train_output_path=$2
-exit 0
+python3 ${BIN_DIR}/train.py \
+ --config=${config_path} \
+ --output-dir=${train_output_path} \
+ --ngpu=1
diff --git a/examples/iwslt2012/punc0/path.sh b/examples/iwslt2012/punc0/path.sh
old mode 100644
new mode 100755
index 8f67f9c9..da790261
--- a/examples/iwslt2012/punc0/path.sh
+++ b/examples/iwslt2012/punc0/path.sh
@@ -10,5 +10,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
-MODEL=$1
+MODEL=ernie_linear
export BIN_DIR=${MAIN_ROOT}/paddlespeech/text/exps/${MODEL}
diff --git a/examples/iwslt2012/punc0/run.sh b/examples/iwslt2012/punc0/run.sh
index 8d786a19..0c14eb7e 100755
--- a/examples/iwslt2012/punc0/run.sh
+++ b/examples/iwslt2012/punc0/run.sh
@@ -1,40 +1,35 @@
#!/bin/bash
set -e
+source path.sh
-if [ $# -ne 4 ]; then
- echo "usage: bash ./run.sh stage gpu train_config avg_num"
- echo "eg: bash ./run.sh 1 0 train_config 1"
- exit -1
-fi
-
-stage=$1
+gpus=0,1
+stage=0
stop_stage=100
-gpus=$2
-conf_path=$3
-avg_num=$4
-avg_ckpt=avg_${avg_num}
-ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
-log_dir=log
-source path.sh ${ckpt}
+conf_path=conf/default.yaml
+train_output_path=exp/default
+ckpt_name=snapshot_iter_12840.pdz
+text=今天的天气真不错啊你下午有空吗我想约你一起去吃饭
+# with the following command, you can choose the stage range you want to run
+# such as `./run.sh --stage 0 --stop-stage 0`
+# this can not be mixed use with `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
- bash ./local/data.sh
+ ./local/data.sh
fi
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
- # train model, all `ckpt` under `exp` dir
- CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${conf_path} ${ckpt} ${log_dir}
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ # train model, all `ckpt` under `train_output_path/checkpoints/` dir
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- # avg n best model
- bash ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- # test ckpt avg_n
- CUDA_VISIBLE_DEVICES=${gpus} bash ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
-fi
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/punc_restore.sh ${conf_path} ${train_output_path} ${ckpt_name} ${text}|| exit -1
+fi
\ No newline at end of file
diff --git a/paddlespeech/t2s/exps/fastspeech2/gen_gta_mel.py b/paddlespeech/t2s/exps/fastspeech2/gen_gta_mel.py
index 8a9ef370..fa46fd55 100644
--- a/paddlespeech/t2s/exps/fastspeech2/gen_gta_mel.py
+++ b/paddlespeech/t2s/exps/fastspeech2/gen_gta_mel.py
@@ -132,7 +132,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
def str2bool(str):
return True if str.lower() == 'true' else False
diff --git a/paddlespeech/t2s/exps/fastspeech2/train.py b/paddlespeech/t2s/exps/fastspeech2/train.py
index fafded6f..1dfa575a 100644
--- a/paddlespeech/t2s/exps/fastspeech2/train.py
+++ b/paddlespeech/t2s/exps/fastspeech2/train.py
@@ -174,7 +174,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
parser.add_argument(
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
diff --git a/paddlespeech/t2s/exps/gan_vocoder/hifigan/train.py b/paddlespeech/t2s/exps/gan_vocoder/hifigan/train.py
index 3bc11a60..9ac6cbd3 100644
--- a/paddlespeech/t2s/exps/gan_vocoder/hifigan/train.py
+++ b/paddlespeech/t2s/exps/gan_vocoder/hifigan/train.py
@@ -250,7 +250,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
diff --git a/paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/train.py b/paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/train.py
index a44d2d3c..3d0ff7d3 100644
--- a/paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/train.py
+++ b/paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/train.py
@@ -239,7 +239,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
diff --git a/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize_from_wav.py b/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize_from_wav.py
index ca2e3f55..f5affb50 100644
--- a/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize_from_wav.py
+++ b/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize_from_wav.py
@@ -93,7 +93,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
diff --git a/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/train.py b/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/train.py
index 98b0ed71..a7881d6b 100644
--- a/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/train.py
+++ b/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/train.py
@@ -216,7 +216,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
benchmark_group = parser.add_argument_group(
'benchmark', 'arguments related to benchmark.')
diff --git a/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py b/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py
index bc746467..36e4d645 100644
--- a/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py
+++ b/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py
@@ -232,7 +232,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
diff --git a/paddlespeech/t2s/exps/gan_vocoder/synthesize.py b/paddlespeech/t2s/exps/gan_vocoder/synthesize.py
index 6f4dc92d..c60b9add 100644
--- a/paddlespeech/t2s/exps/gan_vocoder/synthesize.py
+++ b/paddlespeech/t2s/exps/gan_vocoder/synthesize.py
@@ -42,7 +42,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
diff --git a/paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py b/paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py
index 2854d055..cb742c59 100644
--- a/paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py
+++ b/paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py
@@ -173,7 +173,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir")
parser.add_argument(
"--inference-dir", type=str, help="dir to save inference models")
- parser.add_argument("--verbose", type=int, default=1, help="verbose")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
diff --git a/paddlespeech/t2s/exps/transformer_tts/synthesize.py b/paddlespeech/t2s/exps/transformer_tts/synthesize.py
index 666c3b72..7b6b1873 100644
--- a/paddlespeech/t2s/exps/transformer_tts/synthesize.py
+++ b/paddlespeech/t2s/exps/transformer_tts/synthesize.py
@@ -118,7 +118,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
diff --git a/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py b/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py
index ba197f43..0cd7d224 100644
--- a/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py
+++ b/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py
@@ -137,7 +137,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
diff --git a/paddlespeech/t2s/exps/transformer_tts/train.py b/paddlespeech/t2s/exps/transformer_tts/train.py
index 163339f4..8695c06a 100644
--- a/paddlespeech/t2s/exps/transformer_tts/train.py
+++ b/paddlespeech/t2s/exps/transformer_tts/train.py
@@ -165,7 +165,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
parser.add_argument(
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
diff --git a/paddlespeech/text/utils/__init__.py b/paddlespeech/text/exps/__init__.py
similarity index 89%
rename from paddlespeech/text/utils/__init__.py
rename to paddlespeech/text/exps/__init__.py
index 185a92b8..abf198b9 100644
--- a/paddlespeech/text/utils/__init__.py
+++ b/paddlespeech/text/exps/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/paddlespeech/text/training/__init__.py b/paddlespeech/text/exps/ernie_linear/__init__.py
similarity index 89%
rename from paddlespeech/text/training/__init__.py
rename to paddlespeech/text/exps/ernie_linear/__init__.py
index 185a92b8..abf198b9 100644
--- a/paddlespeech/text/training/__init__.py
+++ b/paddlespeech/text/exps/ernie_linear/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/paddlespeech/text/exps/ernie_linear/punc_restore.py b/paddlespeech/text/exps/ernie_linear/punc_restore.py
new file mode 100644
index 00000000..2cb4d071
--- /dev/null
+++ b/paddlespeech/text/exps/ernie_linear/punc_restore.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import re
+
+import paddle
+import yaml
+from paddlenlp.transformers import ErnieTokenizer
+from yacs.config import CfgNode
+
+from paddlespeech.text.models.ernie_linear import ErnieLinear
+
+DefinedClassifier = {
+ 'ErnieLinear': ErnieLinear,
+}
+
+tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
+
+
+def _clean_text(text, punc_list):
+ text = text.lower()
+ text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
+ text = re.sub(f'[{"".join([p for p in punc_list][1:])}]', '', text)
+ return text
+
+
+def preprocess(text, punc_list):
+ clean_text = _clean_text(text, punc_list)
+ assert len(clean_text) > 0, f'Invalid input string: {text}'
+ tokenized_input = tokenizer(
+ list(clean_text), return_length=True, is_split_into_words=True)
+ _inputs = dict()
+ _inputs['input_ids'] = tokenized_input['input_ids']
+ _inputs['seg_ids'] = tokenized_input['token_type_ids']
+ _inputs['seq_len'] = tokenized_input['seq_len']
+ return _inputs
+
+
+def test(args):
+ with open(args.config) as f:
+ config = CfgNode(yaml.safe_load(f))
+ print("========Args========")
+ print(yaml.safe_dump(vars(args)))
+ print("========Config========")
+ print(config)
+
+ punc_list = []
+ with open(config["data_params"]["punc_path"], 'r') as f:
+ for line in f:
+ punc_list.append(line.strip())
+
+ model = DefinedClassifier[config["model_type"]](**config["model"])
+ state_dict = paddle.load(args.checkpoint)
+ model.set_state_dict(state_dict["main_params"])
+ model.eval()
+ _inputs = preprocess(args.text, punc_list)
+ seq_len = _inputs['seq_len']
+ input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0)
+ seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0)
+ logits, _ = model(input_ids, seg_ids)
+ preds = paddle.argmax(logits, axis=-1).squeeze(0)
+ tokens = tokenizer.convert_ids_to_tokens(
+ _inputs['input_ids'][1:seq_len - 1])
+ labels = preds[1:seq_len - 1].tolist()
+ assert len(tokens) == len(labels)
+ # add 0 for non punc
+ punc_list = [0] + punc_list
+ text = ''
+ for t, l in zip(tokens, labels):
+ text += t
+ if l != 0: # Non punc.
+ text += punc_list[l]
+ print("Punctuation Restoration Result:", text)
+ return text
+
+
+def main():
+ # parse args and config and redirect to train_sp
+ parser = argparse.ArgumentParser(description="Run Punctuation Restoration.")
+ parser.add_argument("--config", type=str, help="ErnieLinear config file.")
+ parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
+ parser.add_argument("--text", type=str, help="raw text to be restored.")
+ parser.add_argument(
+ "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
+
+ args = parser.parse_args()
+
+ if args.ngpu == 0:
+ paddle.set_device("cpu")
+ elif args.ngpu > 0:
+ paddle.set_device("gpu")
+ else:
+ print("ngpu should >= 0 !")
+
+ test(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/paddlespeech/text/exps/ernie_linear/test.py b/paddlespeech/text/exps/ernie_linear/test.py
index 3cd507fb..4302a1a3 100644
--- a/paddlespeech/text/exps/ernie_linear/test.py
+++ b/paddlespeech/text/exps/ernie_linear/test.py
@@ -11,36 +11,110 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Evaluation for model."""
+import argparse
+
+import numpy as np
+import paddle
+import pandas as pd
import yaml
+from paddle import nn
+from paddle.io import DataLoader
+from sklearn.metrics import classification_report
+from sklearn.metrics import precision_recall_fscore_support
+from yacs.config import CfgNode
-from paddlespeech.s2t.utils.utility import print_arguments
-from paddlespeech.text.training.trainer import Tester
-from paddlespeech.text.utils.default_parser import default_argument_parser
+from paddlespeech.text.models.ernie_linear import ErnieLinear
+from paddlespeech.text.models.ernie_linear import PuncDataset
+from paddlespeech.text.models.ernie_linear import PuncDatasetFromErnieTokenizer
+DefinedClassifier = {
+ 'ErnieLinear': ErnieLinear,
+}
-def main_sp(config, args):
- exp = Tester(config, args)
- exp.setup()
- exp.run_test()
+DefinedLoss = {
+ "ce": nn.CrossEntropyLoss,
+}
+DefinedDataset = {
+ 'Punc': PuncDataset,
+ 'Ernie': PuncDatasetFromErnieTokenizer,
+}
-def main(config, args):
- main_sp(config, args)
+def evaluation(y_pred, y_test):
+ precision, recall, f1, _ = precision_recall_fscore_support(
+ y_test, y_pred, average=None, labels=[1, 2, 3])
+ overall = precision_recall_fscore_support(
+ y_test, y_pred, average='macro', labels=[1, 2, 3])
+ result = pd.DataFrame(
+ np.array([precision, recall, f1]),
+ columns=list(['O', 'COMMA', 'PERIOD', 'QUESTION'])[1:],
+ index=['Precision', 'Recall', 'F1'])
+ result['OVERALL'] = overall[:3]
+ return result
+
+
+def test(args):
+ with open(args.config) as f:
+ config = CfgNode(yaml.safe_load(f))
+ print("========Args========")
+ print(yaml.safe_dump(vars(args)))
+ print("========Config========")
+ print(config)
+
+ test_dataset = DefinedDataset[config["dataset_type"]](
+ train_path=config["test_path"], **config["data_params"])
+ test_loader = DataLoader(
+ test_dataset,
+ batch_size=config.batch_size,
+ shuffle=False,
+ drop_last=False)
+ model = DefinedClassifier[config["model_type"]](**config["model"])
+ state_dict = paddle.load(args.checkpoint)
+ model.set_state_dict(state_dict["main_params"])
+ model.eval()
+
+ punc_list = []
+ for i in range(len(test_loader.dataset.id2punc)):
+ punc_list.append(test_loader.dataset.id2punc[i])
+
+ test_total_label = []
+ test_total_predict = []
+
+ for i, batch in enumerate(test_loader):
+ input, label = batch
+ label = paddle.reshape(label, shape=[-1])
+ y, logit = model(input)
+ pred = paddle.argmax(logit, axis=1)
+ test_total_label.extend(label.numpy().tolist())
+ test_total_predict.extend(pred.numpy().tolist())
+ t = classification_report(
+ test_total_label, test_total_predict, target_names=punc_list)
+ print(t)
+ t2 = evaluation(test_total_label, test_total_predict)
+ print('=========================================================')
+ print(t2)
+
+
+def main():
+ # parse args and config and redirect to train_sp
+ parser = argparse.ArgumentParser(description="Test a ErnieLinear model.")
+ parser.add_argument("--config", type=str, help="ErnieLinear config file.")
+ parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
+ parser.add_argument(
+ "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
-if __name__ == "__main__":
- parser = default_argument_parser()
args = parser.parse_args()
- print_arguments(args, globals())
- # https://yaml.org/type/float.html
- with open(args.config, "r") as f:
- config = yaml.load(f, Loader=yaml.FullLoader)
+ if args.ngpu == 0:
+ paddle.set_device("cpu")
+ elif args.ngpu > 0:
+ paddle.set_device("gpu")
+ else:
+ print("ngpu should >= 0 !")
+
+ test(args)
- print(config)
- if args.dump_config:
- with open(args.dump_config, 'w') as f:
- print(config, file=f)
- main(config, args)
+if __name__ == "__main__":
+ main()
diff --git a/paddlespeech/text/exps/ernie_linear/train.py b/paddlespeech/text/exps/ernie_linear/train.py
index 09071438..0d730d66 100644
--- a/paddlespeech/text/exps/ernie_linear/train.py
+++ b/paddlespeech/text/exps/ernie_linear/train.py
@@ -11,40 +11,163 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Trainer for punctuation_restoration task."""
+import argparse
+import logging
+import os
+import shutil
+from pathlib import Path
+
+import paddle
import yaml
+from paddle import DataParallel
from paddle import distributed as dist
+from paddle import nn
+from paddle.io import DataLoader
+from paddle.optimizer import Adam
+from paddle.optimizer.lr import ExponentialDecay
+from yacs.config import CfgNode
-from paddlespeech.s2t.utils.utility import print_arguments
-from paddlespeech.text.training.trainer import Trainer
-from paddlespeech.text.utils.default_parser import default_argument_parser
+from paddlespeech.t2s.training.extensions.snapshot import Snapshot
+from paddlespeech.t2s.training.extensions.visualizer import VisualDL
+from paddlespeech.t2s.training.seeding import seed_everything
+from paddlespeech.t2s.training.trainer import Trainer
+from paddlespeech.text.models.ernie_linear import ErnieLinear
+from paddlespeech.text.models.ernie_linear import ErnieLinearEvaluator
+from paddlespeech.text.models.ernie_linear import ErnieLinearUpdater
+from paddlespeech.text.models.ernie_linear import PuncDataset
+from paddlespeech.text.models.ernie_linear import PuncDatasetFromErnieTokenizer
+DefinedClassifier = {
+ 'ErnieLinear': ErnieLinear,
+}
-def main_sp(config, args):
- exp = Trainer(config, args)
- exp.setup()
- exp.run()
+DefinedLoss = {
+ "ce": nn.CrossEntropyLoss,
+}
+DefinedDataset = {
+ 'Punc': PuncDataset,
+ 'Ernie': PuncDatasetFromErnieTokenizer,
+}
-def main(config, args):
- if args.ngpu > 1:
- dist.spawn(main_sp, args=(config, args), nprocs=args.ngpu)
+
+def train_sp(args, config):
+ # decides device type and whether to run in parallel
+ # setup running environment correctly
+ if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
+ paddle.set_device("cpu")
else:
- main_sp(config, args)
+ paddle.set_device("gpu")
+ world_size = paddle.distributed.get_world_size()
+ if world_size > 1:
+ paddle.distributed.init_parallel_env()
+ # set the random seed, it is a must for multiprocess training
+ seed_everything(config.seed)
+
+ print(
+ f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
+ )
+ # dataloader has been too verbose
+ logging.getLogger("DataLoader").disabled = True
+ train_dataset = DefinedDataset[config["dataset_type"]](
+ train_path=config["train_path"], **config["data_params"])
+ dev_dataset = DefinedDataset[config["dataset_type"]](
+ train_path=config["dev_path"], **config["data_params"])
+ train_dataloader = DataLoader(
+ train_dataset,
+ shuffle=True,
+ num_workers=config.num_workers,
+ batch_size=config.batch_size)
+
+ dev_dataloader = DataLoader(
+ dev_dataset,
+ batch_size=config.batch_size,
+ shuffle=False,
+ drop_last=False,
+ num_workers=config.num_workers)
+
+ print("dataloaders done!")
+
+ model = DefinedClassifier[config["model_type"]](**config["model"])
+
+ if world_size > 1:
+ model = DataParallel(model)
+ print("model done!")
+
+ criterion = DefinedLoss[config["loss_type"]](
+ **config["loss"]) if "loss_type" in config else DefinedLoss["ce"]()
+
+ print("criterions done!")
+
+ lr_schedule = ExponentialDecay(**config["scheduler_params"])
+ optimizer = Adam(
+ learning_rate=lr_schedule,
+ parameters=model.parameters(),
+ weight_decay=paddle.regularizer.L2Decay(
+ config["optimizer_params"]["weight_decay"]))
+
+ print("optimizer done!")
+
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ if dist.get_rank() == 0:
+ config_name = args.config.split("/")[-1]
+ # copy conf to output_dir
+ shutil.copyfile(args.config, output_dir / config_name)
+
+ updater = ErnieLinearUpdater(
+ model=model,
+ criterion=criterion,
+ scheduler=lr_schedule,
+ optimizer=optimizer,
+ dataloader=train_dataloader,
+ output_dir=output_dir)
+
+ trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
+
+ evaluator = ErnieLinearEvaluator(
+ model=model,
+ criterion=criterion,
+ dataloader=dev_dataloader,
+ output_dir=output_dir)
+
+ if dist.get_rank() == 0:
+ trainer.extend(evaluator, trigger=(1, "epoch"))
+ trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
+ trainer.extend(
+ Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
+ # print(trainer.extensions)
+ trainer.run()
+
+
+def main():
+ # parse args and config and redirect to train_sp
+ parser = argparse.ArgumentParser(description="Train a ErnieLinear model.")
+ parser.add_argument("--config", type=str, help="ErnieLinear config file.")
+ parser.add_argument("--output-dir", type=str, help="output dir.")
+ parser.add_argument(
+ "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
-if __name__ == "__main__":
- parser = default_argument_parser()
args = parser.parse_args()
- print_arguments(args, globals())
- # https://yaml.org/type/float.html
- with open(args.config, "r") as f:
- config = yaml.load(f, Loader=yaml.FullLoader)
+ with open(args.config) as f:
+ config = CfgNode(yaml.safe_load(f))
+ print("========Args========")
+ print(yaml.safe_dump(vars(args)))
+ print("========Config========")
print(config)
- if args.dump_config:
- with open(args.dump_config, 'w') as f:
- print(config, file=f)
+ print(
+ f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
+ )
+
+ # dispatch
+ if args.ngpu > 1:
+ dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
+ else:
+ train_sp(args, config)
+
- main(config, args)
+if __name__ == "__main__":
+ main()
diff --git a/paddlespeech/text/models/ernie_linear/__init__.py b/paddlespeech/text/models/ernie_linear/__init__.py
index 93453ce7..0a10a6eb 100644
--- a/paddlespeech/text/models/ernie_linear/__init__.py
+++ b/paddlespeech/text/models/ernie_linear/__init__.py
@@ -11,4 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from .model import ErnieLinear
+from .dataset import *
+from .ernie_linear import *
+from .ernie_linear_updater import *
diff --git a/paddlespeech/text/models/ernie_linear/dataset.py b/paddlespeech/text/models/ernie_linear/dataset.py
index 086e91bb..64c8d0bd 100644
--- a/paddlespeech/text/models/ernie_linear/dataset.py
+++ b/paddlespeech/text/models/ernie_linear/dataset.py
@@ -99,10 +99,8 @@ class PuncDatasetFromErnieTokenizer(Dataset):
self.tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
self.paddingID = self.tokenizer.pad_token_id
self.seq_len = seq_len
-
self.punc2id = self.load_vocab(punc_path, extra_word_list=[" "])
self.id2punc = {k: v for (v, k) in self.punc2id.items()}
-
tmp_seqs = open(train_path, encoding='utf-8').readlines()
self.txt_seqs = [i for seq in tmp_seqs for i in seq.split()]
self.preprocess(self.txt_seqs)
@@ -125,6 +123,7 @@ class PuncDatasetFromErnieTokenizer(Dataset):
input_data = []
label = []
count = 0
+ print("Preprocessing in PuncDatasetFromErnieTokenizer...")
for i in range(len(txt_seqs) - 1):
word = txt_seqs[i]
punc = txt_seqs[i + 1]
diff --git a/paddlespeech/text/models/ernie_linear/model.py b/paddlespeech/text/models/ernie_linear/ernie_linear.py
similarity index 100%
rename from paddlespeech/text/models/ernie_linear/model.py
rename to paddlespeech/text/models/ernie_linear/ernie_linear.py
diff --git a/paddlespeech/text/models/ernie_linear/ernie_linear_updater.py b/paddlespeech/text/models/ernie_linear/ernie_linear_updater.py
new file mode 100644
index 00000000..8b3d7410
--- /dev/null
+++ b/paddlespeech/text/models/ernie_linear/ernie_linear_updater.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import paddle
+from paddle import distributed as dist
+from paddle.io import DataLoader
+from paddle.nn import Layer
+from paddle.optimizer import Optimizer
+from paddle.optimizer.lr import LRScheduler
+from sklearn.metrics import f1_score
+
+from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
+from paddlespeech.t2s.training.reporter import report
+from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
+logging.basicConfig(
+ format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
+ datefmt='[%Y-%m-%d %H:%M:%S]')
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class ErnieLinearUpdater(StandardUpdater):
+ def __init__(self,
+ model: Layer,
+ criterion: Layer,
+ scheduler: LRScheduler,
+ optimizer: Optimizer,
+ dataloader: DataLoader,
+ output_dir=None):
+ super().__init__(model, optimizer, dataloader, init_state=None)
+ self.model = model
+ self.dataloader = dataloader
+
+ self.criterion = criterion
+ self.scheduler = scheduler
+ self.optimizer = optimizer
+
+ log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+ self.filehandler = logging.FileHandler(str(log_file))
+ logger.addHandler(self.filehandler)
+ self.logger = logger
+ self.msg = ""
+
+ def update_core(self, batch):
+ self.msg = "Rank: {}, ".format(dist.get_rank())
+ losses_dict = {}
+
+ input, label = batch
+ label = paddle.reshape(label, shape=[-1])
+ y, logit = self.model(input)
+ pred = paddle.argmax(logit, axis=1)
+
+ loss = self.criterion(y, label)
+
+ self.optimizer.clear_grad()
+ loss.backward()
+
+ self.optimizer.step()
+ self.scheduler.step()
+
+ F1_score = f1_score(
+ label.numpy().tolist(), pred.numpy().tolist(), average="macro")
+
+ report("train/loss", float(loss))
+ losses_dict["loss"] = float(loss)
+ report("train/F1_score", float(F1_score))
+ losses_dict["F1_score"] = float(F1_score)
+
+ self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+ for k, v in losses_dict.items())
+
+
+class ErnieLinearEvaluator(StandardEvaluator):
+ def __init__(self,
+ model: Layer,
+ criterion: Layer,
+ dataloader: DataLoader,
+ output_dir=None):
+ super().__init__(model, dataloader)
+ self.model = model
+ self.criterion = criterion
+ self.dataloader = dataloader
+
+ log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+ self.filehandler = logging.FileHandler(str(log_file))
+ logger.addHandler(self.filehandler)
+ self.logger = logger
+ self.msg = ""
+
+ def evaluate_core(self, batch):
+ self.msg = "Evaluate: "
+ losses_dict = {}
+
+ input, label = batch
+ label = paddle.reshape(label, shape=[-1])
+ y, logit = self.model(input)
+ pred = paddle.argmax(logit, axis=1)
+
+ loss = self.criterion(y, label)
+
+ F1_score = f1_score(
+ label.numpy().tolist(), pred.numpy().tolist(), average="macro")
+
+ report("eval/loss", float(loss))
+ losses_dict["loss"] = float(loss)
+ report("eval/F1_score", float(F1_score))
+ losses_dict["F1_score"] = float(F1_score)
+
+ self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+ for k, v in losses_dict.items())
+ self.logger.info(self.msg)
diff --git a/paddlespeech/text/training/trainer.py b/paddlespeech/text/training/trainer.py
deleted file mode 100644
index b5e6a563..00000000
--- a/paddlespeech/text/training/trainer.py
+++ /dev/null
@@ -1,524 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-import time
-from collections import defaultdict
-from pathlib import Path
-
-import numpy as np
-import paddle
-import paddle.nn as nn
-import pandas as pd
-from paddle import distributed as dist
-from paddle.io import DataLoader
-from sklearn.metrics import classification_report
-from sklearn.metrics import f1_score
-from sklearn.metrics import precision_recall_fscore_support
-
-from ...s2t.utils import layer_tools
-from ...s2t.utils import mp_tools
-from ...s2t.utils.checkpoint import Checkpoint
-from ...text.models import ErnieLinear
-from ...text.models.ernie_linear.dataset import PuncDataset
-from ...text.models.ernie_linear.dataset import PuncDatasetFromErnieTokenizer
-
-__all__ = ["Trainer", "Tester"]
-
-DefinedClassifier = {
- 'ErnieLinear': ErnieLinear,
-}
-
-DefinedLoss = {
- "ce": nn.CrossEntropyLoss,
-}
-
-DefinedDataset = {
- 'Punc': PuncDataset,
- 'Ernie': PuncDatasetFromErnieTokenizer,
-}
-
-
-class Trainer():
- def __init__(self, config, args):
- self.config = config
- self.args = args
- self.optimizer = None
- self.output_dir = None
- self.log_dir = None
- self.checkpoint_dir = None
- self.iteration = 0
- self.epoch = 0
-
- def setup(self):
- """Setup the experiment.
- """
- self.setup_log_dir()
- self.setup_logger()
- if self.args.ngpu > 0:
- paddle.set_device('gpu')
- else:
- paddle.set_device('cpu')
- if self.parallel:
- self.init_parallel()
-
- self.setup_output_dir()
- self.dump_config()
- self.setup_checkpointer()
-
- self.setup_model()
-
- self.setup_dataloader()
-
- self.iteration = 0
- self.epoch = 1
-
- @property
- def parallel(self):
- """A flag indicating whether the experiment should run with
- multiprocessing.
- """
- return self.args.ngpu > 1
-
- def init_parallel(self):
- """Init environment for multiprocess training.
- """
- dist.init_parallel_env()
-
- @mp_tools.rank_zero_only
- def save(self, tag=None, infos: dict=None):
- """Save checkpoint (model parameters and optimizer states).
-
- Args:
- tag (int or str, optional): None for step, else using tag, e.g epoch. Defaults to None.
- infos (dict, optional): meta data to save. Defaults to None.
- """
-
- infos = infos if infos else dict()
- infos.update({
- "step": self.iteration,
- "epoch": self.epoch,
- "lr": self.optimizer.get_lr()
- })
- self.checkpointer.save_parameters(self.checkpoint_dir, self.iteration
- if tag is None else tag, self.model,
- self.optimizer, infos)
-
- def resume_or_scratch(self):
- """Resume from latest checkpoint at checkpoints in the output
- directory or load a specified checkpoint.
-
- If ``args.checkpoint_path`` is not None, load the checkpoint, else
- resume training.
- """
- scratch = None
- infos = self.checkpointer.load_parameters(
- self.model,
- self.optimizer,
- checkpoint_dir=self.checkpoint_dir,
- checkpoint_path=self.args.checkpoint_path)
- if infos:
- # restore from ckpt
- self.iteration = infos["step"]
- self.epoch = infos["epoch"]
- scratch = False
- else:
- self.iteration = 0
- self.epoch = 0
- scratch = True
-
- return scratch
-
- def new_epoch(self):
- """Reset the train loader seed and increment `epoch`.
- """
- self.epoch += 1
- if self.parallel:
- self.train_loader.batch_sampler.set_epoch(self.epoch)
-
- def train(self):
- """The training process control by epoch."""
- from_scratch = self.resume_or_scratch()
-
- if from_scratch:
- # save init model, i.e. 0 epoch
- self.save(tag="init")
-
- self.lr_scheduler.step(self.iteration)
- if self.parallel:
- self.train_loader.batch_sampler.set_epoch(self.epoch)
-
- self.logger.info(
- f"Train Total Examples: {len(self.train_loader.dataset)}")
- self.punc_list = []
- for i in range(len(self.train_loader.dataset.id2punc)):
- self.punc_list.append(self.train_loader.dataset.id2punc[i])
- while self.epoch < self.config["training"]["n_epoch"]:
- self.model.train()
- self.total_label_train = []
- self.total_predict_train = []
- try:
- data_start_time = time.time()
- for batch_index, batch in enumerate(self.train_loader):
- dataload_time = time.time() - data_start_time
- msg = "Train: Rank: {}, ".format(dist.get_rank())
- msg += "epoch: {}, ".format(self.epoch)
- msg += "step: {}, ".format(self.iteration)
- msg += "batch : {}/{}, ".format(batch_index + 1,
- len(self.train_loader))
- msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
- msg += "data time: {:>.3f}s, ".format(dataload_time)
- self.train_batch(batch_index, batch, msg)
- data_start_time = time.time()
- # t = classification_report(
- # self.total_label_train,
- # self.total_predict_train,
- # target_names=self.punc_list)
- # self.logger.info(t)
- except Exception as e:
- self.logger.error(e)
- raise e
-
- total_loss, F1_score = self.valid()
- self.logger.info("Epoch {} Val info val_loss {}, F1_score {}".
- format(self.epoch, total_loss, F1_score))
-
- self.save(
- tag=self.epoch, infos={"val_loss": total_loss,
- "F1": F1_score})
- # step lr every epoch
- self.lr_scheduler.step()
- self.new_epoch()
-
- def run(self):
- """The routine of the experiment after setup. This method is intended
- to be used by the user.
- """
- try:
- self.train()
- except KeyboardInterrupt:
- self.logger.info("Training was aborted by keybord interrupt.")
- self.save()
- exit(-1)
- finally:
- self.destory()
- self.logger.info("Training Done.")
-
- def setup_output_dir(self):
- """Create a directory used for output.
- """
- # output dir
- output_dir = Path(self.args.output_dir).expanduser()
- output_dir.mkdir(parents=True, exist_ok=True)
-
- self.output_dir = output_dir
-
- def setup_log_dir(self):
- """Create a directory used for logging.
- """
- # log dir
- log_dir = Path(self.args.log_dir).expanduser()
- log_dir.mkdir(parents=True, exist_ok=True)
-
- self.log_dir = log_dir
-
- def setup_checkpointer(self):
- """Create a directory used to save checkpoints into.
-
- It is "checkpoints" inside the output directory.
- """
- # checkpoint dir
- self.checkpointer = Checkpoint(self.config["checkpoint"]["kbest_n"],
- self.config["checkpoint"]["latest_n"])
-
- checkpoint_dir = self.output_dir / "checkpoints"
- checkpoint_dir.mkdir(exist_ok=True)
-
- self.checkpoint_dir = checkpoint_dir
-
- def setup_logger(self):
- LOG_FORMAT = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
- format_str = logging.Formatter(
- '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
- )
- logging.basicConfig(
- filename=self.config["training"]["log_path"],
- level=logging.INFO,
- format=LOG_FORMAT)
- self.logger = logging.getLogger(__name__)
-
- self.logger.setLevel(logging.INFO)
- sh = logging.StreamHandler()
- sh.setFormatter(format_str)
- self.logger.addHandler(sh)
-
- self.logger.info('info')
-
- @mp_tools.rank_zero_only
- def destory(self):
- pass
-
- @mp_tools.rank_zero_only
- def dump_config(self):
- """Save the configuration used for this experiment.
-
- It is saved in to ``config.yaml`` in the output directory at the
- beginning of the experiment.
- """
- with open(self.output_dir / "config.yaml", "wt") as f:
- print(self.config, file=f)
-
- def train_batch(self, batch_index, batch_data, msg):
- start = time.time()
-
- input, label = batch_data
- label = paddle.reshape(label, shape=[-1])
- y, logit = self.model(input)
- pred = paddle.argmax(logit, axis=1)
- self.total_label_train.extend(label.numpy().tolist())
- self.total_predict_train.extend(pred.numpy().tolist())
- loss = self.crit(y, label)
-
- loss.backward()
- layer_tools.print_grads(self.model, print_func=None)
- self.optimizer.step()
- self.optimizer.clear_grad()
- iteration_time = time.time() - start
-
- losses_np = {
- "train_loss": float(loss),
- }
- msg += "train time: {:>.3f}s, ".format(iteration_time)
- msg += "batch size: {}, ".format(self.config["data"]["batch_size"])
- msg += ", ".join("{}: {:>.6f}".format(k, v)
- for k, v in losses_np.items())
- self.logger.info(msg)
- self.iteration += 1
-
- @paddle.no_grad()
- def valid(self):
- self.logger.info(
- f"Valid Total Examples: {len(self.valid_loader.dataset)}")
- self.model.eval()
- valid_losses = defaultdict(list)
- num_seen_utts = 1
- total_loss = 0.0
- valid_total_label = []
- valid_total_predict = []
- for i, batch in enumerate(self.valid_loader):
- input, label = batch
- label = paddle.reshape(label, shape=[-1])
- y, logit = self.model(input)
- pred = paddle.argmax(logit, axis=1)
- valid_total_label.extend(label.numpy().tolist())
- valid_total_predict.extend(pred.numpy().tolist())
- loss = self.crit(y, label)
-
- if paddle.isfinite(loss):
- num_utts = batch[1].shape[0]
- num_seen_utts += num_utts
- total_loss += float(loss) * num_utts
- valid_losses["val_loss"].append(float(loss))
-
- if (i + 1) % self.config["training"]["log_interval"] == 0:
- valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
- valid_dump["val_history_loss"] = total_loss / num_seen_utts
-
- # logging
- msg = f"Valid: Rank: {dist.get_rank()}, "
- msg += "epoch: {}, ".format(self.epoch)
- msg += "step: {}, ".format(self.iteration)
- msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
- msg += ", ".join("{}: {:>.6f}".format(k, v)
- for k, v in valid_dump.items())
- self.logger.info(msg)
-
- self.logger.info("Rank {} Val info val_loss {}".format(
- dist.get_rank(), total_loss / num_seen_utts))
- F1_score = f1_score(
- valid_total_label, valid_total_predict, average="macro")
- return total_loss / num_seen_utts, F1_score
-
- def setup_model(self):
- config = self.config
-
- model = DefinedClassifier[self.config["model_type"]](
- **self.config["model_params"])
- self.crit = DefinedLoss[self.config["loss_type"]](**self.config[
- "loss"]) if "loss_type" in self.config else DefinedLoss["ce"]()
-
- if self.parallel:
- model = paddle.DataParallel(model)
-
- # self.logger.info(f"{model}")
- # layer_tools.print_params(model, self.logger.info)
-
- lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
- learning_rate=config["training"]["lr"],
- gamma=config["training"]["lr_decay"],
- verbose=True)
- optimizer = paddle.optimizer.Adam(
- learning_rate=lr_scheduler,
- parameters=model.parameters(),
- weight_decay=paddle.regularizer.L2Decay(
- config["training"]["weight_decay"]))
-
- self.model = model
- self.optimizer = optimizer
- self.lr_scheduler = lr_scheduler
- self.logger.info("Setup model/criterion/optimizer/lr_scheduler!")
-
- def setup_dataloader(self):
- config = self.config["data"].copy()
- train_dataset = DefinedDataset[config["dataset_type"]](
- train_path=config["train_path"], **config["data_params"])
- dev_dataset = DefinedDataset[config["dataset_type"]](
- train_path=config["dev_path"], **config["data_params"])
-
- self.train_loader = DataLoader(
- train_dataset,
- num_workers=config["num_workers"],
- batch_size=config["batch_size"])
- self.valid_loader = DataLoader(
- dev_dataset,
- batch_size=config["batch_size"],
- shuffle=False,
- drop_last=False,
- num_workers=config["num_workers"])
- self.logger.info("Setup train/valid Dataloader!")
-
-
-class Tester(Trainer):
- def __init__(self, config, args):
- super().__init__(config, args)
-
- @mp_tools.rank_zero_only
- @paddle.no_grad()
- def test(self):
- self.logger.info(
- f"Test Total Examples: {len(self.test_loader.dataset)}")
- self.punc_list = []
- for i in range(len(self.test_loader.dataset.id2punc)):
- self.punc_list.append(self.test_loader.dataset.id2punc[i])
- self.model.eval()
- test_total_label = []
- test_total_predict = []
- with open(self.args.result_file, 'w') as fout:
- for i, batch in enumerate(self.test_loader):
- input, label = batch
- label = paddle.reshape(label, shape=[-1])
- y, logit = self.model(input)
- pred = paddle.argmax(logit, axis=1)
- test_total_label.extend(label.numpy().tolist())
- test_total_predict.extend(pred.numpy().tolist())
-
- # logging
- msg = "Test: "
- msg += "epoch: {}, ".format(self.epoch)
- msg += "step: {}, ".format(self.iteration)
- self.logger.info(msg)
- t = classification_report(
- test_total_label, test_total_predict, target_names=self.punc_list)
- print(t)
- t2 = self.evaluation(test_total_label, test_total_predict)
- print(t2)
-
- def evaluation(self, y_pred, y_test):
- precision, recall, f1, _ = precision_recall_fscore_support(
- y_test, y_pred, average=None, labels=[1, 2, 3])
- overall = precision_recall_fscore_support(
- y_test, y_pred, average='macro', labels=[1, 2, 3])
- result = pd.DataFrame(
- np.array([precision, recall, f1]),
- columns=list(['O', 'COMMA', 'PERIOD', 'QUESTION'])[1:],
- index=['Precision', 'Recall', 'F1'])
- result['OVERALL'] = overall[:3]
- return result
-
- def run_test(self):
- self.resume_or_scratch()
- try:
- self.test()
- except KeyboardInterrupt:
- self.logger.info("Testing was aborted by keybord interrupt.")
- exit(-1)
-
- def setup(self):
- """Setup the experiment.
- """
- if self.args.ngpu > 0:
- paddle.set_device('gpu')
- else:
- paddle.set_device('cpu')
- self.setup_logger()
- self.setup_output_dir()
- self.setup_checkpointer()
-
- self.setup_dataloader()
- self.setup_model()
-
- self.iteration = 0
- self.epoch = 0
-
- def setup_model(self):
- config = self.config
- model = DefinedClassifier[self.config["model_type"]](
- **self.config["model_params"])
-
- self.model = model
- self.logger.info("Setup model!")
-
- def setup_dataloader(self):
- config = self.config["data"].copy()
-
- test_dataset = DefinedDataset[config["dataset_type"]](
- train_path=config["test_path"], **config["data_params"])
-
- self.test_loader = DataLoader(
- test_dataset,
- batch_size=config["batch_size"],
- shuffle=False,
- drop_last=False)
- self.logger.info("Setup test Dataloader!")
-
- def setup_output_dir(self):
- """Create a directory used for output.
- """
- # output dir
- if self.args.output_dir:
- output_dir = Path(self.args.output_dir).expanduser()
- output_dir.mkdir(parents=True, exist_ok=True)
- else:
- output_dir = Path(
- self.args.checkpoint_path).expanduser().parent.parent
- output_dir.mkdir(parents=True, exist_ok=True)
-
- self.output_dir = output_dir
-
- def setup_logger(self):
- LOG_FORMAT = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
- format_str = logging.Formatter(
- '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
- )
- logging.basicConfig(
- filename=self.config["testing"]["log_path"],
- level=logging.INFO,
- format=LOG_FORMAT)
- self.logger = logging.getLogger(__name__)
-
- self.logger.setLevel(logging.INFO)
- sh = logging.StreamHandler()
- sh.setFormatter(format_str)
- self.logger.addHandler(sh)
-
- self.logger.info('info')
diff --git a/paddlespeech/text/utils/default_parser.py b/paddlespeech/text/utils/default_parser.py
deleted file mode 100644
index 469157a6..00000000
--- a/paddlespeech/text/utils/default_parser.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import argparse
-
-
-def default_argument_parser():
- r"""A simple yet genral argument parser for experiments with t2s.
-
- This is used in examples with t2s. And it is intended to be used by
- other experiments with t2s. It requires a minimal set of command line
- arguments to start a training script.
-
- The ``--config`` and ``--opts`` are used for overwrite the deault
- configuration.
-
- The ``--data`` and ``--output`` specifies the data path and output path.
- Resuming training from existing progress at the output directory is the
- intended default behavior.
-
- The ``--checkpoint_path`` specifies the checkpoint to load from.
-
- The ``--ngpu`` specifies how to run the training.
-
-
- See Also
- --------
- paddlespeech.t2s.training.experiment
- Returns
- -------
- argparse.ArgumentParser
- the parser
- """
- parser = argparse.ArgumentParser()
-
- # yapf: disable
- # data and output
- parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
- parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.")
- # parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
- parser.add_argument("--output_dir", metavar="OUTPUT_DIR", help="path to save checkpoint.")
- parser.add_argument("--log_dir", metavar="LOG_DIR", help="path to save logs.")
-
- # load from saved checkpoint
- parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
-
- # save jit model to
- parser.add_argument("--export_path", type=str, help="path of the jit model to save")
-
- # save asr result to
- parser.add_argument("--result_file", type=str, help="path of save the asr result")
-
- # running
- parser.add_argument("--ngpu", type=int, default=1, help="number of parallel processes to use. if ngpu=0, using cpu.")
-
- # overwrite extra config and default config
- # parser.add_argument("--opts", nargs=argparse.REMAINDER,
- # help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
- parser.add_argument("--opts", type=str, default=[], nargs='+',
- help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
- # yapd: enable
-
- return parser