Merge pull request #945 from PaddlePaddle/ds2

ds2 exp with eval mode
pull/946/head
Jackwaterveg committed 3 years ago via GitHub
commit 00a50a0101

@@ -167,6 +167,11 @@ class DeepSpeech2Trainer(Trainer):
         logger.info(f"{model}")
         layer_tools.print_params(model, logger.info)
+        self.model = model
+        logger.info("Setup model!")
+
+        if not self.train:
+            return
 
         grad_clip = ClipGradByGlobalNormWithLog(
             config.training.global_grad_clip)
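The early return added here means an evaluation run never builds an optimizer or LR scheduler. A minimal sketch of that pattern, with illustrative names only (SetupSketch is not repository code):

# Sketch only: model is always constructed, optimizer setup is skipped in eval mode.
class SetupSketch:
    def __init__(self, train):
        self.train = train
        self.model = None
        self.optimizer = None

    def setup_model(self):
        self.model = object()      # stands in for the DeepSpeech2 model
        if not self.train:
            return                 # eval/test run: no optimizer needed
        self.optimizer = object()  # stands in for paddle.optimizer.Adam

s = SetupSketch(train=False)
s.setup_model()
assert s.model is not None and s.optimizer is None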
@@ -180,74 +185,77 @@ class DeepSpeech2Trainer(Trainer):
             weight_decay=paddle.regularizer.L2Decay(
                 config.training.weight_decay),
             grad_clip=grad_clip)
-        self.model = model
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
-        logger.info("Setup model/optimizer/lr_scheduler!")
+        logger.info("Setup optimizer/lr_scheduler!")
 
     def setup_dataloader(self):
         config = self.config.clone()
         config.defrost()
-        config.collator.keep_transcription_text = False
 
-        config.data.manifest = config.data.train_manifest
-        train_dataset = ManifestDataset.from_config(config)
-
-        config.data.manifest = config.data.dev_manifest
-        dev_dataset = ManifestDataset.from_config(config)
-
-        config.data.manifest = config.data.test_manifest
-        test_dataset = ManifestDataset.from_config(config)
-
-        if self.parallel:
-            batch_sampler = SortagradDistributedBatchSampler(
-                train_dataset,
-                batch_size=config.collator.batch_size,
-                num_replicas=None,
-                rank=None,
-                shuffle=True,
-                drop_last=True,
-                sortagrad=config.collator.sortagrad,
-                shuffle_method=config.collator.shuffle_method)
-        else:
-            batch_sampler = SortagradBatchSampler(
-                train_dataset,
-                shuffle=True,
-                batch_size=config.collator.batch_size,
-                drop_last=True,
-                sortagrad=config.collator.sortagrad,
-                shuffle_method=config.collator.shuffle_method)
-
-        collate_fn_train = SpeechCollator.from_config(config)
-
-        config.collator.augmentation_config = ""
-        collate_fn_dev = SpeechCollator.from_config(config)
-
-        config.collator.keep_transcription_text = True
-        config.collator.augmentation_config = ""
-        collate_fn_test = SpeechCollator.from_config(config)
-
-        self.train_loader = DataLoader(
-            train_dataset,
-            batch_sampler=batch_sampler,
-            collate_fn=collate_fn_train,
-            num_workers=config.collator.num_workers)
-        self.valid_loader = DataLoader(
-            dev_dataset,
-            batch_size=int(config.collator.batch_size),
-            shuffle=False,
-            drop_last=False,
-            collate_fn=collate_fn_dev,
-            num_workers=config.collator.num_workers)
-        self.test_loader = DataLoader(
-            test_dataset,
-            batch_size=config.decoding.batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=collate_fn_test,
-            num_workers=config.collator.num_workers)
-        logger.info("Setup train/valid/test Dataloader!")
+        if self.train:
+            # train
+            config.data.manifest = config.data.train_manifest
+            train_dataset = ManifestDataset.from_config(config)
+            if self.parallel:
+                batch_sampler = SortagradDistributedBatchSampler(
+                    train_dataset,
+                    batch_size=config.collator.batch_size,
+                    num_replicas=None,
+                    rank=None,
+                    shuffle=True,
+                    drop_last=True,
+                    sortagrad=config.collator.sortagrad,
+                    shuffle_method=config.collator.shuffle_method)
+            else:
+                batch_sampler = SortagradBatchSampler(
+                    train_dataset,
+                    shuffle=True,
+                    batch_size=config.collator.batch_size,
+                    drop_last=True,
+                    sortagrad=config.collator.sortagrad,
+                    shuffle_method=config.collator.shuffle_method)
+            config.collator.keep_transcription_text = False
+            collate_fn_train = SpeechCollator.from_config(config)
+            self.train_loader = DataLoader(
+                train_dataset,
+                batch_sampler=batch_sampler,
+                collate_fn=collate_fn_train,
+                num_workers=config.collator.num_workers)
+
+            # dev
+            config.data.manifest = config.data.dev_manifest
+            dev_dataset = ManifestDataset.from_config(config)
+            config.collator.augmentation_config = ""
+            config.collator.keep_transcription_text = False
+            collate_fn_dev = SpeechCollator.from_config(config)
+            self.valid_loader = DataLoader(
+                dev_dataset,
+                batch_size=int(config.collator.batch_size),
+                shuffle=False,
+                drop_last=False,
+                collate_fn=collate_fn_dev,
+                num_workers=config.collator.num_workers)
+            logger.info("Setup train/valid Dataloader!")
+        else:
+            # test
+            config.data.manifest = config.data.test_manifest
+            test_dataset = ManifestDataset.from_config(config)
+            config.collator.augmentation_config = ""
+            config.collator.keep_transcription_text = True
+            collate_fn_test = SpeechCollator.from_config(config)
+            self.test_loader = DataLoader(
+                test_dataset,
+                batch_size=config.decoding.batch_size,
+                shuffle=False,
+                drop_last=False,
+                collate_fn=collate_fn_test,
+                num_workers=config.collator.num_workers)
+            logger.info("Setup test Dataloader!")
 
 
 class DeepSpeech2Tester(DeepSpeech2Trainer):
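After this change a trainer in eval mode only ever materializes the test loader, so downstream code should not assume train_loader or valid_loader exist. A hedged sketch of that behaviour (LoaderFactory and its string placeholders are illustrative, not repository code):

# Sketch of the loader gating introduced above.
class LoaderFactory:
    def __init__(self, train):
        self.train = train
        self.train_loader = None
        self.valid_loader = None
        self.test_loader = None

    def setup(self):
        if self.train:
            # training run: train/dev loaders, transcription text converted to ids
            self.train_loader = "train DataLoader (keep_transcription_text=False)"
            self.valid_loader = "dev DataLoader (augmentation disabled)"
        else:
            # evaluation run: only the test loader, raw transcription text kept
            self.test_loader = "test DataLoader (keep_transcription_text=True)"

factory = LoaderFactory(train=False)
factory.setup()
assert factory.test_loader is not None and factory.train_loader is None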

@@ -172,7 +172,7 @@ class U2Trainer(Trainer):
                         dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts
 
-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
@@ -173,7 +173,7 @@ class U2Trainer(Trainer):
                         dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts
 
-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine

@@ -184,7 +184,7 @@ class U2STTrainer(Trainer):
                         dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts
 
-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
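The rename from train() to do_train() appears to be driven by the base Trainer now exposing train as a read-only property (see the next hunk): keeping a method of the same name in the subclasses would shadow the property, so every training entry point moves to do_train(). A small sketch of the conflict being avoided, with an illustrative Runner class that is not repository code:

class Runner:
    def __init__(self):
        self._train = True

    @property
    def train(self):
        # read-only mode flag; a method named train() would shadow this
        return self._train

    def do_train(self):
        return "epoch loop runs here"

r = Runner()
assert r.train is True
assert r.do_train() == "epoch loop runs here"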

@@ -134,6 +134,10 @@ class Trainer():
         logger.info(
             f"Benchmark reset batch-size: {self.args.benchmark_batch_size}")
 
+    @property
+    def train(self):
+        return self._train
+
+    @contextmanager
+    def eval(self):
+        self._train = False
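The hunk is cut off inside eval(), but as a @contextmanager it presumably yields and then restores the flag after the with-block. A self-contained sketch of that pattern (ModeHolder is illustrative, and the yield/restore step is an assumption, not shown in the diff):

from contextlib import contextmanager

class ModeHolder:
    def __init__(self):
        self._train = True

    @property
    def train(self):
        return self._train

    @contextmanager
    def eval(self):
        self._train = False
        try:
            yield
        finally:
            self._train = True  # assumed restore once the with-block exits

holder = ModeHolder()
with holder.eval():
    assert holder.train is False
assert holder.train is True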
@@ -248,7 +252,7 @@ class Trainer():
                 sys.exit(
                     f"Reach benchmark-max-step: {self.args.benchmark_max_step}")
 
-    def train(self):
+    def do_train(self):
         """The training process control by epoch."""
         self.before_train()
@@ -321,7 +325,7 @@ class Trainer():
         """
         try:
             with Timer("Training Done: {}"):
-                self.train()
+                self.do_train()
         except KeyboardInterrupt:
             exit(-1)
         finally:
@@ -432,7 +436,7 @@ class Trainer():
         beginning of the experiment.
         """
         config_file = self.config_dir / "config.yaml"
-        if self._train and config_file.exists():
+        if self.train and config_file.exists():
             time_stamp = time.strftime("%Y_%m_%d_%H_%M_%s", time.gmtime())
             target_path = self.config_dir / ".".join(
                 [time_stamp, "config.yaml"])
