@@ -84,8 +84,7 @@ def train_sp(args, config):
         mlm_prob=config.mlm_prob,
         mean_phn_span=config.mean_phn_span,
         seg_emb=config.model['enc_input_layer'] == 'sega_mlm',
-        text_masking=config["model"]["text_masking"],
-        epoch=config["max_epoch"])
+        text_masking=config["model"]["text_masking"])
 
     train_sampler = DistributedBatchSampler(
         train_dataset,
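
For reference, a minimal runnable sketch of the sampler wiring that appears in this hunk's trailing context, assuming Paddle's `paddle.io.DistributedBatchSampler`; the toy dataset and the batch-size/shuffle values below are assumptions for illustration, not part of the patch.

```python
# Minimal sketch, assuming paddle is installed. The names train_dataset and
# train_sampler mirror the diff context; ToyDataset and the keyword values
# below are assumptions for illustration only.
import paddle
from paddle.io import Dataset, DistributedBatchSampler


class ToyDataset(Dataset):
    """Stand-in dataset so the sketch runs on its own."""

    def __init__(self, n=8):
        self.samples = [paddle.to_tensor([float(i)]) for i in range(n)]

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)


train_dataset = ToyDataset()

# DistributedBatchSampler shards batches across ranks; on a single process it
# degenerates to an ordinary batch sampler over the dataset indices.
train_sampler = DistributedBatchSampler(
    train_dataset,
    batch_size=4,  # the real code presumably reads this from config
    shuffle=True,
    drop_last=True)

for batch_indices in train_sampler:
    print(batch_indices)  # lists of sample indices, one list per batch
```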