diff --git a/examples/other/rhy/README.md b/examples/other/rhy/README.md new file mode 100644 index 00000000..5a70c1d2 --- /dev/null +++ b/examples/other/rhy/README.md @@ -0,0 +1,41 @@ +# Rhythm Prediction with CSMSC and AiShell3 + +## Get Started +### Data Preprocessing +```bash +./run.sh --stage 0 --stop-stage 0 +``` +### Model Training +```bash +./run.sh --stage 1 --stop-stage 1 +``` +### Testing +```bash +./run.sh --stage 2 --stop-stage 2 +``` +### Punctuation Restoration +```bash +./run.sh --stage 3 --stop-stage 3 +``` +## Pretrained Model +The pretrained model can be downloaded here: + +[ernie-1.0_aishellcsmsc_ckpt_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/ernie-1.0_aishellcsmsc_ckpt_1.3.0.zip) + +And you should put it into `exp/${YOUREXP}/checkpoints` folder. + +## Rhythm mapping +Four punctuation marks are used to denote the rhythm marks respectively: +|ryh_token|csmsc|aishll3| +|:---: |:---: |:---: | +|%|#1|%| +|`|#2|| +|~|#3|| +|$|#4|$| + +## Prediction Results +| | #1 | #2 | #3 | #4 | +|:-----:|:-----:|:-----:|:-----:|:-----:| +|Precision |0.90 |0.66 |0.91 |0.90| +|Recall |0.92 |0.62 |0.83 |0.85| +|F1 |0.91 |0.64 |0.87 |0.87| diff --git a/examples/other/rhy/conf/default.yaml b/examples/other/rhy/conf/default.yaml new file mode 100644 index 00000000..1eb90f11 --- /dev/null +++ b/examples/other/rhy/conf/default.yaml @@ -0,0 +1,44 @@ +########################################################### +# DATA SETTING # +########################################################### +dataset_type: Ernie +train_path: data/train.txt +dev_path: data/dev.txt +test_path: data/test.txt +batch_size: 64 +num_workers: 2 +data_params: + pretrained_token: ernie-1.0 + punc_path: data/rhy_token + seq_len: 100 + + +########################################################### +# MODEL SETTING # +########################################################### +model_type: ErnieLinear +model: + pretrained_token: ernie-1.0 + num_classes: 5 + +########################################################### +# OPTIMIZER SETTING # +########################################################### +optimizer_params: + weight_decay: 1.0e-6 # weight decay coefficient. + +scheduler_params: + learning_rate: 1.0e-5 # learning rate. + gamma: 0.9999 # scheduler gamma must between(0.0, 1.0) and closer to 1.0 is better. + +########################################################### +# TRAINING SETTING # +########################################################### +max_epoch: 20 +num_snapshots: 5 + +########################################################### +# OTHER SETTING # +########################################################### +num_snapshots: 10 # max number of snapshots to keep while training +seed: 42 # random seed for paddle, random, and np.random diff --git a/examples/other/rhy/data/rhy_token b/examples/other/rhy/data/rhy_token new file mode 100644 index 00000000..bf1fe253 --- /dev/null +++ b/examples/other/rhy/data/rhy_token @@ -0,0 +1,4 @@ +% +` +~ +$ \ No newline at end of file diff --git a/examples/other/rhy/local/data.sh b/examples/other/rhy/local/data.sh new file mode 100755 index 00000000..93b13487 --- /dev/null +++ b/examples/other/rhy/local/data.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +if [ ! -f 000001-010000.txt ]; then + wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/000001-010000.txt +fi + +if [ ! -f label_train-set.txt ]; then + wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/label_train-set.txt +fi + + +aishell_data=$1 +csmsc_data=$2 +processed_path=$3 + +python3 ./local/pre_for_sp_csmsc.py \ + --data=${csmsc_data} \ + --processed_path=${processed_path} + +python3 ./local/pre_for_sp_aishell.py \ + --data=${aishell_data} \ + --processed_path=${processed_path} + + +echo "Finish data preparation." +exit 0 diff --git a/examples/other/rhy/local/pre_for_sp_aishell.py b/examples/other/rhy/local/pre_for_sp_aishell.py new file mode 100644 index 00000000..a2a71668 --- /dev/null +++ b/examples/other/rhy/local/pre_for_sp_aishell.py @@ -0,0 +1,51 @@ +import argparse +import os +import re + +# This is the replacement for rhythm labels to predict. +# 韵律标签的代替 +replace_ = {"#1": "%", "#2": "`", "#3": "~", "#4": "$"} + + +def replace_rhy_with_punc(line): + # r'[:、,;。?!,.:;"?!”’《》【】<=>{}()()#&@“”^_|…\\]%*$', '', line) #参考checkcheck_oov.py, + line = re.sub(r'[:、,;。?!,.:;"?!’《》【】<=>{}()()#&@“”^_|…\\]%*$', '', line) + for r in replace_.keys(): + if r in line: + line = line.replace(r, replace_[r]) + return line + + +def pre_and_write(data, file): + with open(file, 'a') as rf: + for d in data: + d = d.split('|')[2].strip() + # d = replace_rhy_with_punc(d) + d = ' '.join(d) + ' \n' + rf.write(d) + + +def main(): + parser = argparse.ArgumentParser( + description="Train a Rhy prediction model.") + parser.add_argument("--data", type=str, default="label_train-set.txt") + parser.add_argument( + "--processed_path", type=str, default="../data/rhy_predict") + args = parser.parse_args() + os.makedirs(args.processed_path, exist_ok=True) + + with open(args.data) as rf: + text = rf.readlines()[5:] + len_ = len(text) + lens = [int(len_ * 0.9), int(len_ * 0.05), int(len_ * 0.05)] + files = ['train.txt', 'test.txt', 'dev.txt'] + + i = 0 + for l_, file in zip(lens, files): + file = os.path.join(args.processed_path, file) + pre_and_write(text[i:i + l_], file) + i = i + l_ + + +if __name__ == "__main__": + main() diff --git a/examples/other/rhy/local/pre_for_sp_csmsc.py b/examples/other/rhy/local/pre_for_sp_csmsc.py new file mode 100644 index 00000000..0a96092c --- /dev/null +++ b/examples/other/rhy/local/pre_for_sp_csmsc.py @@ -0,0 +1,51 @@ +import argparse +import os +import re + +replace_ = {"#1": "%", "#2": "`", "#3": "~", "#4": "$"} + + +def replace_rhy_with_punc(line): + # r'[:、,;。?!,.:;"?!”’《》【】<=>{}()()#&@“”^_|…\\]%*$', '', line) #参考checkcheck_oov.py, + line = re.sub(r'^$\*%', '', line) + for r in replace_.keys(): + if r in line: + line = line.replace(r, replace_[r]) + return line + + +def pre_and_write(data, file): + with open(file, 'w') as rf: + for d in data: + d = d.split('\t')[1].strip() + d = replace_rhy_with_punc(d) + d = ' '.join(d) + ' \n' + rf.write(d) + + +def main(): + parser = argparse.ArgumentParser( + description="Train a Rhy prediction model.") + parser.add_argument("--data", type=str, default="label_train-set.txt") + parser.add_argument( + "--processed_path", type=str, default="../data/rhy_predict") + args = parser.parse_args() + print(args.data, args.processed_path) + os.makedirs(args.processed_path, exist_ok=True) + + with open(args.data) as rf: + rf = rf.readlines() + text = rf[0::2] + len_ = len(text) + lens = [int(len_ * 0.9), int(len_ * 0.05), int(len_ * 0.05)] + files = ['train.txt', 'test.txt', 'dev.txt'] + + i = 0 + for l_, file in zip(lens, files): + file = os.path.join(args.processed_path, file) + pre_and_write(text[i:i + l_], file) + i = i + l_ + + +if __name__ == "__main__": + main() diff --git a/examples/other/rhy/local/rhy_predict.sh b/examples/other/rhy/local/rhy_predict.sh new file mode 100755 index 00000000..30a4f12f --- /dev/null +++ b/examples/other/rhy/local/rhy_predict.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 +text=$4 +ckpt_prefix=${ckpt_name%.*} + +python3 ${BIN_DIR}/punc_restore.py \ + --config=${config_path} \ + --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --text=${text} diff --git a/examples/other/rhy/local/test.sh b/examples/other/rhy/local/test.sh new file mode 100755 index 00000000..bd490b5b --- /dev/null +++ b/examples/other/rhy/local/test.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 +print_eval=$4 + +ckpt_prefix=${ckpt_name%.*} + +python3 ${BIN_DIR}/test.py \ + --config=${config_path} \ + --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --print_eval=${print_eval} \ No newline at end of file diff --git a/examples/other/rhy/local/train.sh b/examples/other/rhy/local/train.sh new file mode 100755 index 00000000..85227eac --- /dev/null +++ b/examples/other/rhy/local/train.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +python3 ${BIN_DIR}/train.py \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --ngpu=1 diff --git a/examples/other/rhy/path.sh b/examples/other/rhy/path.sh new file mode 100755 index 00000000..da790261 --- /dev/null +++ b/examples/other/rhy/path.sh @@ -0,0 +1,14 @@ +#!/bin/bash +export MAIN_ROOT=${PWD}/../../../ + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ + +MODEL=ernie_linear +export BIN_DIR=${MAIN_ROOT}/paddlespeech/text/exps/${MODEL} diff --git a/examples/other/rhy/run.sh b/examples/other/rhy/run.sh new file mode 100755 index 00000000..aed58152 --- /dev/null +++ b/examples/other/rhy/run.sh @@ -0,0 +1,40 @@ +#!/bin/bash +set -e +source path.sh + +gpus=0 +stage=0 +stop_stage=100 + +aishell_data=label_train-set.txt +csmsc_data=000001-010000.txt +processed_path=data + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_2600.pdz +text=我们城市的复苏有赖于他强有力的政策。 +print_eval=false + +# with the following command, you can choose the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this can not be mixed use with `$1`, `$2` ... +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + ./local/data.sh ${aishell_data} ${csmsc_data} ${processed_path} +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `train_output_path/checkpoints/` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${train_output_path} ${ckpt_name} ${print_eval} || exit -1 +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + CUDA_VISIBLE_DEVICES=${gpus} ./local/rhy_predict.sh ${conf_path} ${train_output_path} ${ckpt_name} ${text}|| exit -1 +fi \ No newline at end of file diff --git a/paddlespeech/text/exps/ernie_linear/test.py b/paddlespeech/text/exps/ernie_linear/test.py index 4302a1a3..aa172cc6 100644 --- a/paddlespeech/text/exps/ernie_linear/test.py +++ b/paddlespeech/text/exps/ernie_linear/test.py @@ -23,6 +23,7 @@ from sklearn.metrics import classification_report from sklearn.metrics import precision_recall_fscore_support from yacs.config import CfgNode +from paddlespeech.t2s.utils import str2bool from paddlespeech.text.models.ernie_linear import ErnieLinear from paddlespeech.text.models.ernie_linear import PuncDataset from paddlespeech.text.models.ernie_linear import PuncDatasetFromErnieTokenizer @@ -91,9 +92,10 @@ def test(args): t = classification_report( test_total_label, test_total_predict, target_names=punc_list) print(t) - t2 = evaluation(test_total_label, test_total_predict) - print('=========================================================') - print(t2) + if args.print_eval: + t2 = evaluation(test_total_label, test_total_predict) + print('=========================================================') + print(t2) def main(): @@ -101,6 +103,7 @@ def main(): parser = argparse.ArgumentParser(description="Test a ErnieLinear model.") parser.add_argument("--config", type=str, help="ErnieLinear config file.") parser.add_argument("--checkpoint", type=str, help="snapshot to load.") + parser.add_argument("--print_eval", type=str2bool, default=True) parser.add_argument( "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")