fix mlm_prob, test=tts

pull/2117/head
TianYuan 3 years ago
parent c1395e3a05
commit 97965f4c37

@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \ --dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \ --config=${config_path} \
--output-dir=${train_output_path} \ --output-dir=${train_output_path} \
--ngpu=1 \ --ngpu=2 \
--phones-dict=dump/phone_id_map.txt --phones-dict=dump/phone_id_map.txt

@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \ --dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \ --config=${config_path} \
--output-dir=${train_output_path} \ --output-dir=${train_output_path} \
--ngpu=1 \ --ngpu=2 \
--phones-dict=dump/phone_id_map.txt --phones-dict=dump/phone_id_map.txt

@@ -79,7 +79,7 @@ grad_clip: 1.0
########################################################### ###########################################################
# TRAINING SETTING # # TRAINING SETTING #
########################################################### ###########################################################
max_epoch: 200 max_epoch: 600
num_snapshots: 5 num_snapshots: 5
########################################################### ###########################################################
@@ -160,4 +160,4 @@ token_list:
- UH0 - UH0
- AW0 - AW0
- OY0 - OY0
- <sos/eos> - <sos/eos>

@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \ --dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \ --config=${config_path} \
--output-dir=${train_output_path} \ --output-dir=${train_output_path} \
--ngpu=1 \ --ngpu=2 \
--phones-dict=dump/phone_id_map.txt --phones-dict=dump/phone_id_map.txt

@@ -29,20 +29,13 @@ from paddlespeech.t2s.modules.nets_utils import phones_text_masking
# 因为要传参数,所以需要额外构建 # 因为要传参数,所以需要额外构建
def build_erniesat_collate_fn( def build_erniesat_collate_fn(mlm_prob: float=0.8,
mlm_prob: float=0.8, mean_phn_span: int=8,
mean_phn_span: int=8, seg_emb: bool=False,
seg_emb: bool=False, text_masking: bool=False):
text_masking: bool=False,
epoch: int=-1, ):
if epoch == -1:
mlm_prob_factor = 1
else:
mlm_prob_factor = 0.8
return ErnieSATCollateFn( return ErnieSATCollateFn(
mlm_prob=mlm_prob * mlm_prob_factor, mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span, mean_phn_span=mean_phn_span,
seg_emb=seg_emb, seg_emb=seg_emb,
text_masking=text_masking) text_masking=text_masking)

@@ -73,8 +73,7 @@ def evaluate(args):
mlm_prob=erniesat_config.mlm_prob, mlm_prob=erniesat_config.mlm_prob,
mean_phn_span=erniesat_config.mean_phn_span, mean_phn_span=erniesat_config.mean_phn_span,
seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm', seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
text_masking=False, text_masking=False)
epoch=-1)
gen_raw = True gen_raw = True
erniesat_mu, erniesat_std = np.load(args.erniesat_stat) erniesat_mu, erniesat_std = np.load(args.erniesat_stat)

@@ -84,8 +84,7 @@ def train_sp(args, config):
mlm_prob=config.mlm_prob, mlm_prob=config.mlm_prob,
mean_phn_span=config.mean_phn_span, mean_phn_span=config.mean_phn_span,
seg_emb=config.model['enc_input_layer'] == 'sega_mlm', seg_emb=config.model['enc_input_layer'] == 'sega_mlm',
text_masking=config["model"]["text_masking"], text_masking=config["model"]["text_masking"])
epoch=config["max_epoch"])
train_sampler = DistributedBatchSampler( train_sampler = DistributedBatchSampler(
train_dataset, train_dataset,

Loading…
Cancel
Save