|
|
|
@ -29,20 +29,13 @@ from paddlespeech.t2s.modules.nets_utils import phones_text_masking
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 因为要传参数,所以需要额外构建
|
|
|
|
|
def build_erniesat_collate_fn(
|
|
|
|
|
mlm_prob: float=0.8,
|
|
|
|
|
def build_erniesat_collate_fn(mlm_prob: float=0.8,
|
|
|
|
|
mean_phn_span: int=8,
|
|
|
|
|
seg_emb: bool=False,
|
|
|
|
|
text_masking: bool=False,
|
|
|
|
|
epoch: int=-1, ):
|
|
|
|
|
|
|
|
|
|
if epoch == -1:
|
|
|
|
|
mlm_prob_factor = 1
|
|
|
|
|
else:
|
|
|
|
|
mlm_prob_factor = 0.8
|
|
|
|
|
text_masking: bool=False):
|
|
|
|
|
|
|
|
|
|
return ErnieSATCollateFn(
|
|
|
|
|
mlm_prob=mlm_prob * mlm_prob_factor,
|
|
|
|
|
mlm_prob=mlm_prob,
|
|
|
|
|
mean_phn_span=mean_phn_span,
|
|
|
|
|
seg_emb=seg_emb,
|
|
|
|
|
text_masking=text_masking)
|
|
|
|
|