Revised structure of rhythm prediction, test=tts

pull/2548/head
WongLaw 2 years ago
parent 868d9d933c
commit 72bbabbf79

@ -20,6 +20,22 @@
## Pretrained Model ## Pretrained Model
The pretrained model can be downloaded here: The pretrained model can be downloaded here:
[ernie-1.0_aishellcsmsc_ckpt_1.3.0](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/ernie-1.0_aishellcsmsc_ckpt_1.3.0.zip) [ernie-1.0_aishellcsmsc_ckpt_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/ernie-1.0_aishellcsmsc_ckpt_1.3.0.zip)
And you should put it into `exp/YOUREXP/checkpoints` folder. And you should put it into `exp/YOUREXP/checkpoints` folder.
## Rhythm mapping
Four punctuation marks are used as rhythm tokens. The table below shows the rhythm mark that each token corresponds to in each dataset:
|rhy_token|csmsc|aishell3|
|:---: |:---: |:---: |
|%|#1|%|
|`|#2||
|~|#3||
|$|#4|$|
## Prediction Results
| | #1 | #2 | #3 | #4 |
|:-----:|:-----:|:-----:|:-----:|:-----:|
|Precision |0.90 |0.66 |0.91 |0.90|
|Recall |0.92 |0.62 |0.83 |0.85|
|F1 |0.91 |0.64 |0.87 |0.87|

@ -2,14 +2,14 @@
# DATA SETTING # # DATA SETTING #
########################################################### ###########################################################
dataset_type: Ernie dataset_type: Ernie
train_path: data/rhy_predict/train.txt train_path: data/train.txt
dev_path: data/rhy_predict/dev.txt dev_path: data/dev.txt
test_path: data/rhy_predict/test.txt test_path: data/test.txt
batch_size: 64 batch_size: 64
num_workers: 2 num_workers: 2
data_params: data_params:
pretrained_token: ernie-1.0 pretrained_token: ernie-1.0
punc_path: data/rhy_predict/rhy_token punc_path: data/rhy_token
seq_len: 100 seq_len: 100

@ -1,29 +0,0 @@
import argparse
def process_sentence(line):
    """Return ``line`` with a single space inserted between its characters.

    Args:
        line (str): Text to split into space-separated characters.

    Returns:
        str: Every character of ``line`` (including a trailing newline, if
            present) joined by single spaces. An empty string yields an
            empty string.
    """
    # str.join builds the result in one pass; it also covers the empty and
    # single-character cases without special handling.
    return ' '.join(line)
if __name__ == "__main__":
    # Script entry point: read the input file line by line, write each line
    # with its characters separated by spaces to the output file, then report
    # how many lines were processed.
    parser = argparse.ArgumentParser(description="Input filename")
    # Both paths are needed by open() below, so ask argparse to reject a
    # missing flag with a usage message instead of failing later on None.
    parser.add_argument('-input_file', required=True)
    parser.add_argument('-output_file', required=True)
    args = parser.parse_args()

    sentence_cnt = 0
    # Use an explicit UTF-8 encoding so the result does not depend on the
    # platform's default locale encoding.
    with open(args.input_file, 'r', encoding='utf-8') as f, \
            open(args.output_file, 'w', encoding='utf-8') as write_f:
        # Iterating the file yields the same lines as repeated readline()
        # calls and stops at end of file.
        for line in f:
            sentence_cnt += 1
            write_f.write(process_sentence(line))

    print('preprocess over')
    print('total sentences number:', sentence_cnt)

@ -3,9 +3,11 @@
config_path=$1 config_path=$1
train_output_path=$2 train_output_path=$2
ckpt_name=$3 ckpt_name=$3
print_eval=$4
ckpt_prefix=${ckpt_name%.*} ckpt_prefix=${ckpt_name%.*}
python3 ${BIN_DIR}/test.py \ python3 ${BIN_DIR}/test.py \
--config=${config_path} \ --config=${config_path} \
--checkpoint=${train_output_path}/checkpoints/${ckpt_name} --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
--print_eval=${print_eval}

@ -2,18 +2,19 @@
set -e set -e
source path.sh source path.sh
gpus=1 gpus=0
stage=0 stage=0
stop_stage=100 stop_stage=100
aishell_data=label_train-set.txt aishell_data=label_train-set.txt
csmsc_data=000001-010000.txt csmsc_data=000001-010000.txt
processed_path=data/rhy_predict processed_path=data
conf_path=conf/default.yaml conf_path=conf/default.yaml
train_output_path=exp/rhy train_output_path=exp/rhy
ckpt_name=snapshot_iter_2600.pdz ckpt_name=snapshot_iter_2600.pdz
text=我们城市的复苏有赖于他强有力的政策。 text=我们城市的复苏有赖于他强有力的政策。
print_eval=false
# with the following command, you can choose the stage range you want to run # with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0` # such as `./run.sh --stage 0 --stop-stage 0`
@ -31,7 +32,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
fi fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${train_output_path} ${ckpt_name} ${print_eval} || exit -1
fi fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then

@ -26,6 +26,8 @@ from yacs.config import CfgNode
from paddlespeech.text.models.ernie_linear import ErnieLinear from paddlespeech.text.models.ernie_linear import ErnieLinear
from paddlespeech.text.models.ernie_linear import PuncDataset from paddlespeech.text.models.ernie_linear import PuncDataset
from paddlespeech.text.models.ernie_linear import PuncDatasetFromErnieTokenizer from paddlespeech.text.models.ernie_linear import PuncDatasetFromErnieTokenizer
from paddlespeech.t2s.utils import str2bool
DefinedClassifier = { DefinedClassifier = {
'ErnieLinear': ErnieLinear, 'ErnieLinear': ErnieLinear,
@ -91,9 +93,12 @@ def test(args):
t = classification_report( t = classification_report(
test_total_label, test_total_predict, target_names=punc_list) test_total_label, test_total_predict, target_names=punc_list)
print(t) print(t)
if args.print_eval:
t2 = evaluation(test_total_label, test_total_predict) t2 = evaluation(test_total_label, test_total_predict)
print('=========================================================') print('=========================================================')
print(t2) print(t2)
else:
pass
def main(): def main():
@ -101,6 +106,10 @@ def main():
parser = argparse.ArgumentParser(description="Test a ErnieLinear model.") parser = argparse.ArgumentParser(description="Test a ErnieLinear model.")
parser.add_argument("--config", type=str, help="ErnieLinear config file.") parser.add_argument("--config", type=str, help="ErnieLinear config file.")
parser.add_argument("--checkpoint", type=str, help="snapshot to load.") parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
parser.add_argument(
"--print_eval",
type=str2bool,
default=False)
parser.add_argument( parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu=0, use cpu.") "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")

Loading…
Cancel
Save