From 97965f4c37057ae716b7ee20de4e070ac6610604 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Fri, 15 Jul 2022 02:48:58 +0000 Subject: [PATCH] fix mlm_prob, test=tts --- examples/aishell3/ernie_sat/local/train.sh | 2 +- examples/aishell3_vctk/ernie_sat/local/train.sh | 2 +- examples/vctk/ernie_sat/conf/default.yaml | 4 ++-- examples/vctk/ernie_sat/local/train.sh | 2 +- paddlespeech/t2s/datasets/am_batch_fn.py | 17 +++++------------ paddlespeech/t2s/exps/ernie_sat/synthesize.py | 3 +-- paddlespeech/t2s/exps/ernie_sat/train.py | 3 +-- 7 files changed, 12 insertions(+), 21 deletions(-) diff --git a/examples/aishell3/ernie_sat/local/train.sh b/examples/aishell3/ernie_sat/local/train.sh index f90db9150..30720e8f5 100755 --- a/examples/aishell3/ernie_sat/local/train.sh +++ b/examples/aishell3/ernie_sat/local/train.sh @@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \ --dev-metadata=dump/dev/norm/metadata.jsonl \ --config=${config_path} \ --output-dir=${train_output_path} \ - --ngpu=1 \ + --ngpu=2 \ --phones-dict=dump/phone_id_map.txt \ No newline at end of file diff --git a/examples/aishell3_vctk/ernie_sat/local/train.sh b/examples/aishell3_vctk/ernie_sat/local/train.sh index f90db9150..30720e8f5 100755 --- a/examples/aishell3_vctk/ernie_sat/local/train.sh +++ b/examples/aishell3_vctk/ernie_sat/local/train.sh @@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \ --dev-metadata=dump/dev/norm/metadata.jsonl \ --config=${config_path} \ --output-dir=${train_output_path} \ - --ngpu=1 \ + --ngpu=2 \ --phones-dict=dump/phone_id_map.txt \ No newline at end of file diff --git a/examples/vctk/ernie_sat/conf/default.yaml b/examples/vctk/ernie_sat/conf/default.yaml index 74c847a5f..b61c81703 100644 --- a/examples/vctk/ernie_sat/conf/default.yaml +++ b/examples/vctk/ernie_sat/conf/default.yaml @@ -79,7 +79,7 @@ grad_clip: 1.0 ########################################################### # TRAINING SETTING # ########################################################### -max_epoch: 200 +max_epoch: 600 num_snapshots: 5 ########################################################### @@ -160,4 +160,4 @@ token_list: - UH0 - AW0 - OY0 -- \ No newline at end of file +- diff --git a/examples/vctk/ernie_sat/local/train.sh b/examples/vctk/ernie_sat/local/train.sh index f90db9150..30720e8f5 100755 --- a/examples/vctk/ernie_sat/local/train.sh +++ b/examples/vctk/ernie_sat/local/train.sh @@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \ --dev-metadata=dump/dev/norm/metadata.jsonl \ --config=${config_path} \ --output-dir=${train_output_path} \ - --ngpu=1 \ + --ngpu=2 \ --phones-dict=dump/phone_id_map.txt \ No newline at end of file diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py index 9c964d8e9..05471167f 100644 --- a/paddlespeech/t2s/datasets/am_batch_fn.py +++ b/paddlespeech/t2s/datasets/am_batch_fn.py @@ -29,20 +29,13 @@ from paddlespeech.t2s.modules.nets_utils import phones_text_masking # 因为要传参数,所以需要额外构建 -def build_erniesat_collate_fn( - mlm_prob: float=0.8, - mean_phn_span: int=8, - seg_emb: bool=False, - text_masking: bool=False, - epoch: int=-1, ): - - if epoch == -1: - mlm_prob_factor = 1 - else: - mlm_prob_factor = 0.8 +def build_erniesat_collate_fn(mlm_prob: float=0.8, + mean_phn_span: int=8, + seg_emb: bool=False, + text_masking: bool=False): return ErnieSATCollateFn( - mlm_prob=mlm_prob * mlm_prob_factor, + mlm_prob=mlm_prob, mean_phn_span=mean_phn_span, seg_emb=seg_emb, text_masking=text_masking) diff --git a/paddlespeech/t2s/exps/ernie_sat/synthesize.py b/paddlespeech/t2s/exps/ernie_sat/synthesize.py index 56f26a8bb..2e3582948 100644 --- a/paddlespeech/t2s/exps/ernie_sat/synthesize.py +++ b/paddlespeech/t2s/exps/ernie_sat/synthesize.py @@ -73,8 +73,7 @@ def evaluate(args): mlm_prob=erniesat_config.mlm_prob, mean_phn_span=erniesat_config.mean_phn_span, seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm', - text_masking=False, - epoch=-1) + text_masking=False) gen_raw = True erniesat_mu, erniesat_std = np.load(args.erniesat_stat) diff --git a/paddlespeech/t2s/exps/ernie_sat/train.py b/paddlespeech/t2s/exps/ernie_sat/train.py index 020b0d0fa..5d8eadb68 100644 --- a/paddlespeech/t2s/exps/ernie_sat/train.py +++ b/paddlespeech/t2s/exps/ernie_sat/train.py @@ -84,8 +84,7 @@ def train_sp(args, config): mlm_prob=config.mlm_prob, mean_phn_span=config.mean_phn_span, seg_emb=config.model['enc_input_layer'] == 'sega_mlm', - text_masking=config["model"]["text_masking"], - epoch=config["max_epoch"]) + text_masking=config["model"]["text_masking"]) train_sampler = DistributedBatchSampler( train_dataset,