Merge pull request #2548 from WongLaw/rhy

[Text]Add Rhythm Prediction Function
pull/2585/head
TianYuan 2 years ago committed by GitHub
commit 03950a0fef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,41 @@
# Rhythm Prediction with CSMSC and AiShell3
## Get Started
### Data Preprocessing
```bash
./run.sh --stage 0 --stop-stage 0
```
### Model Training
```bash
./run.sh --stage 1 --stop-stage 1
```
### Testing
```bash
./run.sh --stage 2 --stop-stage 2
```
### Punctuation Restoration
```bash
./run.sh --stage 3 --stop-stage 3
```
## Pretrained Model
The pretrained model can be downloaded here:
[ernie-1.0_aishellcsmsc_ckpt_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/ernie-1.0_aishellcsmsc_ckpt_1.3.0.zip)
Put it into the `exp/${YOUREXP}/checkpoints` folder.
## Rhythm mapping
Four punctuation marks are used to denote the rhythm marks respectively:
|rhy_token|csmsc|aishell3|
|:---: |:---: |:---: |
|%|#1|%|
|`|#2||
|~|#3||
|$|#4|$|
## Prediction Results
| | #1 | #2 | #3 | #4 |
|:-----:|:-----:|:-----:|:-----:|:-----:|
|Precision |0.90 |0.66 |0.91 |0.90|
|Recall |0.92 |0.62 |0.83 |0.85|
|F1 |0.91 |0.64 |0.87 |0.87|

@ -0,0 +1,44 @@
###########################################################
# DATA SETTING #
###########################################################
dataset_type: Ernie
train_path: data/train.txt
dev_path: data/dev.txt
test_path: data/test.txt
batch_size: 64
num_workers: 2
data_params:
pretrained_token: ernie-1.0
punc_path: data/rhy_token
seq_len: 100
###########################################################
# MODEL SETTING #
###########################################################
model_type: ErnieLinear
model:
pretrained_token: ernie-1.0
num_classes: 5
###########################################################
# OPTIMIZER SETTING #
###########################################################
optimizer_params:
weight_decay: 1.0e-6 # weight decay coefficient.
scheduler_params:
learning_rate: 1.0e-5 # learning rate.
gamma: 0.9999 # scheduler gamma must be between 0.0 and 1.0; values closer to 1.0 are better.
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 20
# num_snapshots is defined once, under OTHER SETTING below. A second
# definition here duplicated the key, which YAML mappings do not allow and
# which last-wins loaders silently override.
###########################################################
# OTHER SETTING #
###########################################################
num_snapshots: 10 # max number of snapshots to keep while training
seed: 42 # random seed for paddle, random, and np.random

@ -0,0 +1,26 @@
#!/bin/bash
# Fetch the rhythm-labelled transcripts (when not already present in the
# current directory) and run both preprocessing scripts on them.
#
# Usage: data.sh <aishell_data> <csmsc_data> <processed_path>

# Stop at the first failing command; otherwise the unconditional `exit 0`
# at the end would report success after a failed download or preprocessing.
set -e

if [ ! -f 000001-010000.txt ]; then
    wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/000001-010000.txt
fi

if [ ! -f label_train-set.txt ]; then
    wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/label_train-set.txt
fi

aishell_data=$1
csmsc_data=$2
processed_path=$3

# Quote every expansion so paths containing spaces reach Python intact.
# NOTE(review): the order of the two scripts presumably matters (one of them
# appears to truncate the output files and the other to append) - confirm
# before reordering.
python3 ./local/pre_for_sp_csmsc.py \
    --data="${csmsc_data}" \
    --processed_path="${processed_path}"

python3 ./local/pre_for_sp_aishell.py \
    --data="${aishell_data}" \
    --processed_path="${processed_path}"

echo "Finish data preparation."
exit 0

@ -0,0 +1,51 @@
import argparse
import os
import re
# Rhythm labels to predict, each mapped to the single punctuation character
# that stands in for it.
replace_ = {"#1": "%", "#2": "`", "#3": "~", "#4": "$"}


def replace_rhy_with_punc(line):
    """Return ``line`` with rhythm labels swapped for their stand-in marks.

    Before the swap, one character from the punctuation set below is removed
    when it sits at the end of the string (or just before a trailing
    newline), together with any ``%`` characters that follow it.

    Args:
        line: text that may contain ``#1`` .. ``#4`` rhythm labels.

    Returns:
        The text with every label replaced according to ``replace_``.
    """
    cleaned = re.sub(r'[:、,;。?!,.:;"?!’《》【】<=>{}()#&@“”^_|…\\]%*$', '', line)
    # str.replace leaves the string untouched when the label is absent, so no
    # membership check is needed first.
    for label, mark in replace_.items():
        cleaned = cleaned.replace(label, mark)
    return cleaned
def pre_and_write(data, file):
    """Append the text field of each record to ``file``, one character per token.

    Each record is split on ``|`` and its third field (index 2) is stripped,
    its characters are joined with single spaces, and the result is written
    followed by a space and a newline. The text is written as-is;
    ``replace_rhy_with_punc`` is not applied here.

    Args:
        data: iterable of ``|``-separated record strings.
        file: path of the output file. It is opened in append mode, so any
            existing content is kept.

    Raises:
        IndexError: if a record has fewer than three ``|``-separated fields.
    """
    # Explicit UTF-8: without it the encoding follows the process locale,
    # which is not guaranteed to handle the Chinese text being written.
    with open(file, 'a', encoding='utf-8') as wf:
        for record in data:
            text = record.split('|')[2].strip()
            wf.write(' '.join(text) + ' \n')
def main():
    """Split a label file into train/test/dev text files.

    Reads ``--data``, drops its first five lines, and hands consecutive
    90% / 5% / 5% slices of the remaining lines to ``pre_and_write`` for
    ``train.txt``, ``test.txt`` and ``dev.txt`` under ``--processed_path``
    (created if missing).
    """
    parser = argparse.ArgumentParser(
        description="Preprocess labelled text for Rhy prediction.")
    parser.add_argument("--data", type=str, default="label_train-set.txt")
    parser.add_argument(
        "--processed_path", type=str, default="../data/rhy_predict")
    args = parser.parse_args()

    os.makedirs(args.processed_path, exist_ok=True)

    # Explicit UTF-8 so reading the Chinese text does not depend on locale.
    with open(args.data, encoding='utf-8') as rf:
        # NOTE(review): the first five lines are skipped, presumably a
        # header in the label file - confirm against the data.
        text = rf.readlines()[5:]

    len_ = len(text)
    # int() truncates, so a few trailing lines may end up in no split.
    lens = [int(len_ * 0.9), int(len_ * 0.05), int(len_ * 0.05)]
    files = ['train.txt', 'test.txt', 'dev.txt']

    start = 0
    for size, name in zip(lens, files):
        out_path = os.path.join(args.processed_path, name)
        pre_and_write(text[start:start + size], out_path)
        start += size


if __name__ == "__main__":
    main()

@ -0,0 +1,51 @@
import argparse
import os
import re
# Rhythm labels mapped to the single punctuation character that stands in
# for each of them.
replace_ = {"#1": "%", "#2": "`", "#3": "~", "#4": "$"}


def replace_rhy_with_punc(line):
    """Return ``line`` with rhythm labels swapped for their stand-in marks.

    Args:
        line: text that may contain ``#1`` .. ``#4`` rhythm labels.

    Returns:
        The text with every label replaced according to ``replace_``.
    """
    # NOTE(review): this pattern anchors start and end of the string before
    # requiring a literal "*%", so it does not look like it can ever match.
    # Possibly a character class was intended - confirm before changing.
    cleaned = re.sub(r'^$\*%', '', line)
    # str.replace is a no-op when the label is absent, so no membership
    # check is needed first.
    for label, mark in replace_.items():
        cleaned = cleaned.replace(label, mark)
    return cleaned
def pre_and_write(data, file):
    """Write the text field of each record to ``file``, one character per token.

    Each record is split on tabs and its second field (index 1) is stripped
    and passed through ``replace_rhy_with_punc``. Its characters are then
    joined with single spaces and written followed by a space and a newline.

    Args:
        data: iterable of tab-separated record strings.
        file: path of the output file. It is opened in write mode, so any
            existing content is replaced.

    Raises:
        IndexError: if a record has no tab-separated second field.
    """
    # Explicit UTF-8: without it the encoding follows the process locale,
    # which is not guaranteed to handle the Chinese text being written.
    with open(file, 'w', encoding='utf-8') as wf:
        for record in data:
            text = replace_rhy_with_punc(record.split('\t')[1].strip())
            wf.write(' '.join(text) + ' \n')
def main():
    """Split a label file into train/test/dev text files.

    Reads ``--data``, keeps every other line starting with the first, and
    hands consecutive 90% / 5% / 5% slices of those lines to
    ``pre_and_write`` for ``train.txt``, ``test.txt`` and ``dev.txt`` under
    ``--processed_path`` (created if missing).
    """
    parser = argparse.ArgumentParser(
        description="Preprocess labelled text for Rhy prediction.")
    parser.add_argument("--data", type=str, default="label_train-set.txt")
    parser.add_argument(
        "--processed_path", type=str, default="../data/rhy_predict")
    args = parser.parse_args()
    print(args.data, args.processed_path)

    os.makedirs(args.processed_path, exist_ok=True)

    # Explicit UTF-8 so reading the Chinese text does not depend on locale.
    with open(args.data, encoding='utf-8') as rf:
        lines = rf.readlines()

    # NOTE(review): only even-indexed lines are used, presumably because
    # text lines alternate with another kind of line - confirm against data.
    text = lines[0::2]

    len_ = len(text)
    # int() truncates, so a few trailing lines may end up in no split.
    lens = [int(len_ * 0.9), int(len_ * 0.05), int(len_ * 0.05)]
    files = ['train.txt', 'test.txt', 'dev.txt']

    start = 0
    for size, name in zip(lens, files):
        out_path = os.path.join(args.processed_path, name)
        pre_and_write(text[start:start + size], out_path)
        start += size


if __name__ == "__main__":
    main()

@ -0,0 +1,12 @@
#!/bin/bash
# Run punc_restore.py on one piece of text with a saved checkpoint.
#
# Usage: rhy_predict.sh <config_path> <train_output_path> <ckpt_name> <text>
# BIN_DIR is not set here; it must already be in the environment.

config_path=$1
train_output_path=$2
ckpt_name=$3
text=$4

# Quote every expansion: unquoted, the text would be word-split and
# glob-expanded before reaching Python.
python3 "${BIN_DIR}/punc_restore.py" \
    --config="${config_path}" \
    --checkpoint="${train_output_path}/checkpoints/${ckpt_name}" \
    --text="${text}"

@ -0,0 +1,13 @@
#!/bin/bash
# Run test.py against a saved checkpoint.
#
# Usage: test.sh <config_path> <train_output_path> <ckpt_name> <print_eval>
# BIN_DIR is not set here; it must already be in the environment.

config_path=$1
train_output_path=$2
ckpt_name=$3
print_eval=$4

# Quote every expansion so paths containing spaces stay a single argument.
python3 "${BIN_DIR}/test.py" \
    --config="${config_path}" \
    --checkpoint="${train_output_path}/checkpoints/${ckpt_name}" \
    --print_eval="${print_eval}"

@ -0,0 +1,9 @@
#!/bin/bash
# Run train.py with the given config, writing results to the output dir.
#
# Usage: train.sh <config_path> <train_output_path>
# BIN_DIR is not set here; it must already be in the environment.

config_path=$1
train_output_path=$2

# Quote every expansion so paths containing spaces stay a single argument.
python3 "${BIN_DIR}/train.py" \
    --config="${config_path}" \
    --output-dir="${train_output_path}" \
    --ngpu=1

@ -0,0 +1,14 @@
#!/bin/bash
# Environment setup for this example: sets the repository root, search
# paths, locale and the directory holding the experiment entry scripts.

# Repository root, three levels above the current working directory.
MAIN_ROOT="${PWD}/../../../"

# MODEL only selects the BIN_DIR sub-directory and is not exported itself.
MODEL=ernie_linear
BIN_DIR="${MAIN_ROOT}/paddlespeech/text/exps/${MODEL}"

PATH="${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}"
PYTHONPATH="${MAIN_ROOT}:${PYTHONPATH}"
LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/lib/"

LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
PYTHONIOENCODING=UTF-8

export MAIN_ROOT BIN_DIR PATH PYTHONPATH LD_LIBRARY_PATH LC_ALL PYTHONIOENCODING

@ -0,0 +1,40 @@
#!/bin/bash
# Entry point for the example: runs data preparation (stage 0), training
# (stage 1), testing (stage 2) and prediction on one text (stage 3).

set -e
source path.sh

gpus=0
stage=0
stop_stage=100

aishell_data=label_train-set.txt
csmsc_data=000001-010000.txt
processed_path=data

conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_2600.pdz
text=我们城市的复苏有赖于他强有力的政策。
print_eval=false

# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this can not be mixed use with `$1`, `$2` ...
source "${MAIN_ROOT}/utils/parse_options.sh" || exit 1

# All arguments below are quoted so that values containing spaces (notably
# ${text}) reach the stage scripts as a single positional argument.

if [ "${stage}" -le 0 ] && [ "${stop_stage}" -ge 0 ]; then
    # prepare data
    ./local/data.sh "${aishell_data}" "${csmsc_data}" "${processed_path}"
fi

if [ "${stage}" -le 1 ] && [ "${stop_stage}" -ge 1 ]; then
    # train model, all `ckpt` under `train_output_path/checkpoints/` dir
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh "${conf_path}" "${train_output_path}" || exit -1
fi

if [ "${stage}" -le 2 ] && [ "${stop_stage}" -ge 2 ]; then
    # evaluate the checkpoint selected by ckpt_name
    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh "${conf_path}" "${train_output_path}" "${ckpt_name}" "${print_eval}" || exit -1
fi

if [ "${stage}" -le 3 ] && [ "${stop_stage}" -ge 3 ]; then
    # run prediction on ${text} with the selected checkpoint
    CUDA_VISIBLE_DEVICES=${gpus} ./local/rhy_predict.sh "${conf_path}" "${train_output_path}" "${ckpt_name}" "${text}" || exit -1
fi

@ -23,6 +23,7 @@ from sklearn.metrics import classification_report
from sklearn.metrics import precision_recall_fscore_support
from yacs.config import CfgNode
from paddlespeech.t2s.utils import str2bool
from paddlespeech.text.models.ernie_linear import ErnieLinear
from paddlespeech.text.models.ernie_linear import PuncDataset
from paddlespeech.text.models.ernie_linear import PuncDatasetFromErnieTokenizer
@ -91,9 +92,10 @@ def test(args):
t = classification_report(
test_total_label, test_total_predict, target_names=punc_list)
print(t)
t2 = evaluation(test_total_label, test_total_predict)
print('=========================================================')
print(t2)
if args.print_eval:
t2 = evaluation(test_total_label, test_total_predict)
print('=========================================================')
print(t2)
def main():
@ -101,6 +103,7 @@ def main():
parser = argparse.ArgumentParser(description="Test a ErnieLinear model.")
parser.add_argument("--config", type=str, help="ErnieLinear config file.")
parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
parser.add_argument("--print_eval", type=str2bool, default=True)
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")

Loading…
Cancel
Save