|
|
|
@ -292,7 +292,8 @@ class U2Trainer(Trainer):
|
|
|
|
|
batch_size=config.collator.batch_size,
|
|
|
|
|
shuffle=False,
|
|
|
|
|
drop_last=False,
|
|
|
|
|
collate_fn=collate_fn_dev)
|
|
|
|
|
collate_fn=collate_fn_dev,
|
|
|
|
|
num_workers=config.collator.num_workers, )
|
|
|
|
|
|
|
|
|
|
# test dataset, return raw text
|
|
|
|
|
config.data.manifest = config.data.test_manifest
|
|
|
|
@ -314,7 +315,8 @@ class U2Trainer(Trainer):
|
|
|
|
|
batch_size=config.decoding.batch_size,
|
|
|
|
|
shuffle=False,
|
|
|
|
|
drop_last=False,
|
|
|
|
|
collate_fn=SpeechCollator.from_config(config))
|
|
|
|
|
collate_fn=SpeechCollator.from_config(config),
|
|
|
|
|
num_workers=config.collator.num_workers, )
|
|
|
|
|
# return text token id
|
|
|
|
|
config.collator.keep_transcription_text = False
|
|
|
|
|
self.align_loader = DataLoader(
|
|
|
|
@ -322,7 +324,8 @@ class U2Trainer(Trainer):
|
|
|
|
|
batch_size=config.decoding.batch_size,
|
|
|
|
|
shuffle=False,
|
|
|
|
|
drop_last=False,
|
|
|
|
|
collate_fn=SpeechCollator.from_config(config))
|
|
|
|
|
collate_fn=SpeechCollator.from_config(config),
|
|
|
|
|
num_workers=config.collator.num_workers, )
|
|
|
|
|
logger.info("Setup train/valid/test/align Dataloader!")
|
|
|
|
|
|
|
|
|
|
def setup_model(self):
|
|
|
|
|