updatte batch_fn train.py, test=doc

pull/2117/head
TianYuan 3 years ago
parent 9d4161ce5f
commit 1bf78fa5c7

@ -68,7 +68,7 @@ def erniesat_batch_fn(examples,
mean_phn_span: int=8,
seg_emb: bool=False,
text_masking: bool=False):
# fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]
# fields = ["text", "text_lengths", "speech", "speech_lengths", "align_start", "align_end"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]
speech = [np.array(item["speech"], dtype=np.float32) for item in examples]

@ -116,13 +116,6 @@ def train_sp(args, config):
odim = config.n_mels
model = ErnieSAT(idim=vocab_size, odim=odim, **config["model"])
# model_path = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/ernie_sat/pretrained_model/paddle_checkpoint_en/model.pdparams"
# state_dict = paddle.load(model_path)
# new_state_dict = {}
# for key, value in state_dict.items():
# new_key = "model." + key
# new_state_dict[new_key] = value
# model.set_state_dict(new_state_dict)
if world_size > 1:
model = DataParallel(model)

Loading…
Cancel
Save