From 72bbabbf792d597d3716f01f00339d1da8cf2846 Mon Sep 17 00:00:00 2001 From: WongLaw Date: Tue, 25 Oct 2022 04:24:38 +0000 Subject: [PATCH] Revised structure of rhythm prediction, test=tts --- examples/other/rhy/README.md | 18 +++++++++++- examples/other/rhy/conf/default.yaml | 8 ++--- .../rhy/data/{rhy_predict => }/rhy_token | 0 examples/other/rhy/local/preprocess.py | 29 ------------------- examples/other/rhy/local/test.sh | 4 ++- examples/other/rhy/run.sh | 7 +++-- paddlespeech/text/exps/ernie_linear/test.py | 15 ++++++++-- 7 files changed, 40 insertions(+), 41 deletions(-) rename examples/other/rhy/data/{rhy_predict => }/rhy_token (100%) delete mode 100644 examples/other/rhy/local/preprocess.py diff --git a/examples/other/rhy/README.md b/examples/other/rhy/README.md index 10bbdefb..a08a7c76 100644 --- a/examples/other/rhy/README.md +++ b/examples/other/rhy/README.md @@ -20,6 +20,22 @@ ## Pretrained Model The pretrained model can be downloaded here: -[ernie-1.0_aishellcsmsc_ckpt_1.3.0](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/ernie-1.0_aishellcsmsc_ckpt_1.3.0.zip) +[ernie-1.0_aishellcsmsc_ckpt_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/ernie-1.0_aishellcsmsc_ckpt_1.3.0.zip) And you should put it into `exp/YOUREXP/checkpoints` folder. 
+ +## Rhythm mapping +Four punctuation marks are used to denote the rhythm marks respectively: +|rhy_token|csmsc|aishell3| +|:---: |:---: |:---: | +|%|#1|%| +|`|#2|| +|~|#3|| +|$|#4|$| + +## Prediction Results +| | #1 | #2 | #3 | #4 | +|:-----:|:-----:|:-----:|:-----:|:-----:| +|Precision |0.90 |0.66 |0.91 |0.90| +|Recall |0.92 |0.62 |0.83 |0.85| +|F1 |0.91 |0.64 |0.87 |0.87| diff --git a/examples/other/rhy/conf/default.yaml b/examples/other/rhy/conf/default.yaml index 566aeb97..1eb90f11 100644 --- a/examples/other/rhy/conf/default.yaml +++ b/examples/other/rhy/conf/default.yaml @@ -2,14 +2,14 @@ # DATA SETTING # ########################################################### dataset_type: Ernie -train_path: data/rhy_predict/train.txt -dev_path: data/rhy_predict/dev.txt -test_path: data/rhy_predict/test.txt +train_path: data/train.txt +dev_path: data/dev.txt +test_path: data/test.txt batch_size: 64 num_workers: 2 data_params: pretrained_token: ernie-1.0 - punc_path: data/rhy_predict/rhy_token + punc_path: data/rhy_token seq_len: 100 diff --git a/examples/other/rhy/data/rhy_predict/rhy_token b/examples/other/rhy/data/rhy_token similarity index 100% rename from examples/other/rhy/data/rhy_predict/rhy_token rename to examples/other/rhy/data/rhy_token diff --git a/examples/other/rhy/local/preprocess.py b/examples/other/rhy/local/preprocess.py deleted file mode 100644 index 3df07c72..00000000 --- a/examples/other/rhy/local/preprocess.py +++ /dev/null @@ -1,29 +0,0 @@ -import argparse - - -def process_sentence(line): - if line == '': - return '' - res = line[0] - for i in range(1, len(line)): - res += (' ' + line[i]) - return res - - -if __name__ == "__main__": - paser = argparse.ArgumentParser(description="Input filename") - paser.add_argument('-input_file') - paser.add_argument('-output_file') - sentence_cnt = 0 - args = paser.parse_args() - with open(args.input_file, 'r') as f: - with open(args.output_file, 'w') as write_f: - while True: - line = f.readline() - if line: - 
sentence_cnt += 1 - write_f.write(process_sentence(line)) - else: - break - print('preprocess over') - print('total sentences number:', sentence_cnt) diff --git a/examples/other/rhy/local/test.sh b/examples/other/rhy/local/test.sh index 94e508b5..bd490b5b 100755 --- a/examples/other/rhy/local/test.sh +++ b/examples/other/rhy/local/test.sh @@ -3,9 +3,11 @@ config_path=$1 train_output_path=$2 ckpt_name=$3 +print_eval=$4 ckpt_prefix=${ckpt_name%.*} python3 ${BIN_DIR}/test.py \ --config=${config_path} \ - --checkpoint=${train_output_path}/checkpoints/${ckpt_name} + --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --print_eval=${print_eval} \ No newline at end of file diff --git a/examples/other/rhy/run.sh b/examples/other/rhy/run.sh index d4484c8a..7e0108e4 100755 --- a/examples/other/rhy/run.sh +++ b/examples/other/rhy/run.sh @@ -2,18 +2,19 @@ set -e source path.sh -gpus=1 +gpus=0 stage=0 stop_stage=100 aishell_data=label_train-set.txt csmsc_data=000001-010000.txt -processed_path=data/rhy_predict +processed_path=data conf_path=conf/default.yaml train_output_path=exp/rhy ckpt_name=snapshot_iter_2600.pdz text=我们城市的复苏有赖于他强有力的政策。 +print_eval=false # with the following command, you can choose the stage range you want to run # such as `./run.sh --stage 0 --stop-stage 0` @@ -31,7 +32,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 + CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${train_output_path} ${ckpt_name} ${print_eval} || exit -1 fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/paddlespeech/text/exps/ernie_linear/test.py b/paddlespeech/text/exps/ernie_linear/test.py index 4302a1a3..8abf77f4 100644 --- a/paddlespeech/text/exps/ernie_linear/test.py +++ b/paddlespeech/text/exps/ernie_linear/test.py @@ -26,6 +26,8 @@ from yacs.config import CfgNode from 
paddlespeech.text.models.ernie_linear import ErnieLinear from paddlespeech.text.models.ernie_linear import PuncDataset from paddlespeech.text.models.ernie_linear import PuncDatasetFromErnieTokenizer +from paddlespeech.t2s.utils import str2bool + DefinedClassifier = { 'ErnieLinear': ErnieLinear, @@ -91,9 +93,12 @@ def test(args): t = classification_report( test_total_label, test_total_predict, target_names=punc_list) print(t) - t2 = evaluation(test_total_label, test_total_predict) - print('=========================================================') - print(t2) + if args.print_eval: + t2 = evaluation(test_total_label, test_total_predict) + print('=========================================================') + print(t2) + else: + pass def main(): @@ -101,6 +106,10 @@ def main(): parser = argparse.ArgumentParser(description="Test a ErnieLinear model.") parser.add_argument("--config", type=str, help="ErnieLinear config file.") parser.add_argument("--checkpoint", type=str, help="snapshot to load.") + parser.add_argument( + "--print_eval", + type=str2bool, + default=False) parser.add_argument( "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")