|
|
@ -167,6 +167,11 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"{model}")
|
|
|
|
logger.info(f"{model}")
|
|
|
|
layer_tools.print_params(model, logger.info)
|
|
|
|
layer_tools.print_params(model, logger.info)
|
|
|
|
|
|
|
|
self.model = model
|
|
|
|
|
|
|
|
logger.info("Setup model!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not self.train:
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
grad_clip = ClipGradByGlobalNormWithLog(
|
|
|
|
grad_clip = ClipGradByGlobalNormWithLog(
|
|
|
|
config.training.global_grad_clip)
|
|
|
|
config.training.global_grad_clip)
|
|
|
@ -180,26 +185,18 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
weight_decay=paddle.regularizer.L2Decay(
|
|
|
|
weight_decay=paddle.regularizer.L2Decay(
|
|
|
|
config.training.weight_decay),
|
|
|
|
config.training.weight_decay),
|
|
|
|
grad_clip=grad_clip)
|
|
|
|
grad_clip=grad_clip)
|
|
|
|
|
|
|
|
|
|
|
|
self.model = model
|
|
|
|
|
|
|
|
self.optimizer = optimizer
|
|
|
|
self.optimizer = optimizer
|
|
|
|
self.lr_scheduler = lr_scheduler
|
|
|
|
self.lr_scheduler = lr_scheduler
|
|
|
|
logger.info("Setup model/optimizer/lr_scheduler!")
|
|
|
|
logger.info("Setup optimizer/lr_scheduler!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_dataloader(self):
|
|
|
|
def setup_dataloader(self):
|
|
|
|
config = self.config.clone()
|
|
|
|
config = self.config.clone()
|
|
|
|
config.defrost()
|
|
|
|
config.defrost()
|
|
|
|
config.collator.keep_transcription_text = False
|
|
|
|
if self.train:
|
|
|
|
|
|
|
|
# train
|
|
|
|
config.data.manifest = config.data.train_manifest
|
|
|
|
config.data.manifest = config.data.train_manifest
|
|
|
|
train_dataset = ManifestDataset.from_config(config)
|
|
|
|
train_dataset = ManifestDataset.from_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
config.data.manifest = config.data.dev_manifest
|
|
|
|
|
|
|
|
dev_dataset = ManifestDataset.from_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config.data.manifest = config.data.test_manifest
|
|
|
|
|
|
|
|
test_dataset = ManifestDataset.from_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.parallel:
|
|
|
|
if self.parallel:
|
|
|
|
batch_sampler = SortagradDistributedBatchSampler(
|
|
|
|
batch_sampler = SortagradDistributedBatchSampler(
|
|
|
|
train_dataset,
|
|
|
|
train_dataset,
|
|
|
@ -219,20 +216,21 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
sortagrad=config.collator.sortagrad,
|
|
|
|
sortagrad=config.collator.sortagrad,
|
|
|
|
shuffle_method=config.collator.shuffle_method)
|
|
|
|
shuffle_method=config.collator.shuffle_method)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config.collator.keep_transcription_text = False
|
|
|
|
collate_fn_train = SpeechCollator.from_config(config)
|
|
|
|
collate_fn_train = SpeechCollator.from_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
config.collator.augmentation_config = ""
|
|
|
|
|
|
|
|
collate_fn_dev = SpeechCollator.from_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config.collator.keep_transcription_text = True
|
|
|
|
|
|
|
|
config.collator.augmentation_config = ""
|
|
|
|
|
|
|
|
collate_fn_test = SpeechCollator.from_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.train_loader = DataLoader(
|
|
|
|
self.train_loader = DataLoader(
|
|
|
|
train_dataset,
|
|
|
|
train_dataset,
|
|
|
|
batch_sampler=batch_sampler,
|
|
|
|
batch_sampler=batch_sampler,
|
|
|
|
collate_fn=collate_fn_train,
|
|
|
|
collate_fn=collate_fn_train,
|
|
|
|
num_workers=config.collator.num_workers)
|
|
|
|
num_workers=config.collator.num_workers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# dev
|
|
|
|
|
|
|
|
config.data.manifest = config.data.dev_manifest
|
|
|
|
|
|
|
|
dev_dataset = ManifestDataset.from_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config.collator.augmentation_config = ""
|
|
|
|
|
|
|
|
config.collator.keep_transcription_text = False
|
|
|
|
|
|
|
|
collate_fn_dev = SpeechCollator.from_config(config)
|
|
|
|
self.valid_loader = DataLoader(
|
|
|
|
self.valid_loader = DataLoader(
|
|
|
|
dev_dataset,
|
|
|
|
dev_dataset,
|
|
|
|
batch_size=int(config.collator.batch_size),
|
|
|
|
batch_size=int(config.collator.batch_size),
|
|
|
@ -240,6 +238,16 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
drop_last=False,
|
|
|
|
drop_last=False,
|
|
|
|
collate_fn=collate_fn_dev,
|
|
|
|
collate_fn=collate_fn_dev,
|
|
|
|
num_workers=config.collator.num_workers)
|
|
|
|
num_workers=config.collator.num_workers)
|
|
|
|
|
|
|
|
logger.info("Setup train/valid Dataloader!")
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
# test
|
|
|
|
|
|
|
|
config.data.manifest = config.data.test_manifest
|
|
|
|
|
|
|
|
test_dataset = ManifestDataset.from_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config.collator.augmentation_config = ""
|
|
|
|
|
|
|
|
config.collator.keep_transcription_text = True
|
|
|
|
|
|
|
|
collate_fn_test = SpeechCollator.from_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
self.test_loader = DataLoader(
|
|
|
|
self.test_loader = DataLoader(
|
|
|
|
test_dataset,
|
|
|
|
test_dataset,
|
|
|
|
batch_size=config.decoding.batch_size,
|
|
|
|
batch_size=config.decoding.batch_size,
|
|
|
@ -247,7 +255,7 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
drop_last=False,
|
|
|
|
drop_last=False,
|
|
|
|
collate_fn=collate_fn_test,
|
|
|
|
collate_fn=collate_fn_test,
|
|
|
|
num_workers=config.collator.num_workers)
|
|
|
|
num_workers=config.collator.num_workers)
|
|
|
|
logger.info("Setup train/valid/test Dataloader!")
|
|
|
|
logger.info("Setup test Dataloader!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|