fix mlm_prob, test=tts

pull/2117/head
TianYuan 3 years ago
parent c1395e3a05
commit 97965f4c37

@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=1 \
--ngpu=2 \
--phones-dict=dump/phone_id_map.txt

@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=1 \
--ngpu=2 \
--phones-dict=dump/phone_id_map.txt

@ -79,7 +79,7 @@ grad_clip: 1.0
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 200
max_epoch: 600
num_snapshots: 5
###########################################################
@ -160,4 +160,4 @@ token_list:
- UH0
- AW0
- OY0
- <sos/eos>
- <sos/eos>

@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=1 \
--ngpu=2 \
--phones-dict=dump/phone_id_map.txt

@ -29,20 +29,13 @@ from paddlespeech.t2s.modules.nets_utils import phones_text_masking
# 因为要传参数,所以需要额外构建
def build_erniesat_collate_fn(
mlm_prob: float=0.8,
mean_phn_span: int=8,
seg_emb: bool=False,
text_masking: bool=False,
epoch: int=-1, ):
if epoch == -1:
mlm_prob_factor = 1
else:
mlm_prob_factor = 0.8
def build_erniesat_collate_fn(mlm_prob: float=0.8,
mean_phn_span: int=8,
seg_emb: bool=False,
text_masking: bool=False):
return ErnieSATCollateFn(
mlm_prob=mlm_prob * mlm_prob_factor,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span,
seg_emb=seg_emb,
text_masking=text_masking)

@ -73,8 +73,7 @@ def evaluate(args):
mlm_prob=erniesat_config.mlm_prob,
mean_phn_span=erniesat_config.mean_phn_span,
seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
text_masking=False,
epoch=-1)
text_masking=False)
gen_raw = True
erniesat_mu, erniesat_std = np.load(args.erniesat_stat)

@ -84,8 +84,7 @@ def train_sp(args, config):
mlm_prob=config.mlm_prob,
mean_phn_span=config.mean_phn_span,
seg_emb=config.model['enc_input_layer'] == 'sega_mlm',
text_masking=config["model"]["text_masking"],
epoch=config["max_epoch"])
text_masking=config["model"]["text_masking"])
train_sampler = DistributedBatchSampler(
train_dataset,

Loading…
Cancel
Save