Merge pull request #945 from PaddlePaddle/ds2

ds2 exp with eval mode
pull/946/head
Jackwaterveg, 3 years ago, committed by GitHub
commit 00a50a0101

@@ -167,6 +167,11 @@ class DeepSpeech2Trainer(Trainer):
         logger.info(f"{model}")
         layer_tools.print_params(model, logger.info)
+        self.model = model
+        logger.info("Setup model!")
+
+        if not self.train:
+            return
 
         grad_clip = ClipGradByGlobalNormWithLog(
             config.training.global_grad_clip)
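With this change, setup_model registers the model and logs it in every mode, but returns before building the gradient clipper, optimizer, and LR scheduler when the trainer is not in training mode. A minimal, self-contained sketch of the pattern (TinyTrainer and its attributes are hypothetical stand-ins, not PaddleSpeech code):

class TinyTrainer:
    """Hypothetical stand-in illustrating the early return in setup_model."""

    def __init__(self, train: bool):
        self._train = train

    @property
    def train(self):
        return self._train

    def setup_model(self):
        self.model = object()      # stand-in for the DeepSpeech2 network
        if not self.train:
            return                 # eval/test run: no optimizer is needed
        self.optimizer = "sgd"     # stand-in for the real optimizer setup


t = TinyTrainer(train=False)
t.setup_model()
assert hasattr(t, "model") and not hasattr(t, "optimizer")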
@@ -180,74 +185,77 @@ class DeepSpeech2Trainer(Trainer):
             weight_decay=paddle.regularizer.L2Decay(
                 config.training.weight_decay),
             grad_clip=grad_clip)
-        self.model = model
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
-        logger.info("Setup model/optimizer/lr_scheduler!")
+        logger.info("Setup optimizer/lr_scheduler!")
 
     def setup_dataloader(self):
         config = self.config.clone()
         config.defrost()
-        config.collator.keep_transcription_text = False
 
-        config.data.manifest = config.data.train_manifest
-        train_dataset = ManifestDataset.from_config(config)
-
-        config.data.manifest = config.data.dev_manifest
-        dev_dataset = ManifestDataset.from_config(config)
-
-        config.data.manifest = config.data.test_manifest
-        test_dataset = ManifestDataset.from_config(config)
-
-        if self.parallel:
-            batch_sampler = SortagradDistributedBatchSampler(
-                train_dataset,
-                batch_size=config.collator.batch_size,
-                num_replicas=None,
-                rank=None,
-                shuffle=True,
-                drop_last=True,
-                sortagrad=config.collator.sortagrad,
-                shuffle_method=config.collator.shuffle_method)
-        else:
-            batch_sampler = SortagradBatchSampler(
-                train_dataset,
-                shuffle=True,
-                batch_size=config.collator.batch_size,
-                drop_last=True,
-                sortagrad=config.collator.sortagrad,
-                shuffle_method=config.collator.shuffle_method)
-
-        collate_fn_train = SpeechCollator.from_config(config)
-
-        config.collator.augmentation_config = ""
-        collate_fn_dev = SpeechCollator.from_config(config)
-
-        config.collator.keep_transcription_text = True
-        config.collator.augmentation_config = ""
-        collate_fn_test = SpeechCollator.from_config(config)
-
-        self.train_loader = DataLoader(
-            train_dataset,
-            batch_sampler=batch_sampler,
-            collate_fn=collate_fn_train,
-            num_workers=config.collator.num_workers)
-        self.valid_loader = DataLoader(
-            dev_dataset,
-            batch_size=int(config.collator.batch_size),
-            shuffle=False,
-            drop_last=False,
-            collate_fn=collate_fn_dev,
-            num_workers=config.collator.num_workers)
-        self.test_loader = DataLoader(
-            test_dataset,
-            batch_size=config.decoding.batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=collate_fn_test,
-            num_workers=config.collator.num_workers)
-        logger.info("Setup train/valid/test Dataloader!")
+        if self.train:
+            # train
+            config.data.manifest = config.data.train_manifest
+            train_dataset = ManifestDataset.from_config(config)
+            if self.parallel:
+                batch_sampler = SortagradDistributedBatchSampler(
+                    train_dataset,
+                    batch_size=config.collator.batch_size,
+                    num_replicas=None,
+                    rank=None,
+                    shuffle=True,
+                    drop_last=True,
+                    sortagrad=config.collator.sortagrad,
+                    shuffle_method=config.collator.shuffle_method)
+            else:
+                batch_sampler = SortagradBatchSampler(
+                    train_dataset,
+                    shuffle=True,
+                    batch_size=config.collator.batch_size,
+                    drop_last=True,
+                    sortagrad=config.collator.sortagrad,
+                    shuffle_method=config.collator.shuffle_method)
+
+            config.collator.keep_transcription_text = False
+            collate_fn_train = SpeechCollator.from_config(config)
+            self.train_loader = DataLoader(
+                train_dataset,
+                batch_sampler=batch_sampler,
+                collate_fn=collate_fn_train,
+                num_workers=config.collator.num_workers)
+
+            # dev
+            config.data.manifest = config.data.dev_manifest
+            dev_dataset = ManifestDataset.from_config(config)
+
+            config.collator.augmentation_config = ""
+            config.collator.keep_transcription_text = False
+            collate_fn_dev = SpeechCollator.from_config(config)
+            self.valid_loader = DataLoader(
+                dev_dataset,
+                batch_size=int(config.collator.batch_size),
+                shuffle=False,
+                drop_last=False,
+                collate_fn=collate_fn_dev,
+                num_workers=config.collator.num_workers)
+            logger.info("Setup train/valid Dataloader!")
+        else:
+            # test
+            config.data.manifest = config.data.test_manifest
+            test_dataset = ManifestDataset.from_config(config)
+
+            config.collator.augmentation_config = ""
+            config.collator.keep_transcription_text = True
+            collate_fn_test = SpeechCollator.from_config(config)
+            self.test_loader = DataLoader(
+                test_dataset,
+                batch_size=config.decoding.batch_size,
+                shuffle=False,
+                drop_last=False,
+                collate_fn=collate_fn_test,
+                num_workers=config.collator.num_workers)
+            logger.info("Setup test Dataloader!")
 
 
 class DeepSpeech2Tester(DeepSpeech2Trainer):
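The reworked setup_dataloader builds only what the current mode needs: train and valid loaders (with transcription text converted to token ids) when self.train is true, otherwise a single test loader with augmentation disabled and raw transcription text kept for scoring. A rough, dependency-free sketch of that branching (the dict-based config and loader placeholders are illustrative, not the real ManifestDataset/SpeechCollator API):

def setup_dataloaders(cfg: dict, is_train: bool) -> dict:
    """Illustrative branching: which loaders get built depends on the mode."""
    loaders = {}
    if is_train:
        cfg["keep_transcription_text"] = False   # targets as token ids
        loaders["train"] = {"manifest": cfg["train_manifest"],
                            "batch_size": cfg["batch_size"]}
        loaders["valid"] = {"manifest": cfg["dev_manifest"],
                            "batch_size": cfg["batch_size"]}
    else:
        cfg["augmentation_config"] = ""          # no augmentation at test time
        cfg["keep_transcription_text"] = True    # keep raw text for WER/CER
        loaders["test"] = {"manifest": cfg["test_manifest"],
                           "batch_size": cfg["decoding_batch_size"]}
    return loaders


cfg = {"train_manifest": "train.json", "dev_manifest": "dev.json",
       "test_manifest": "test.json", "batch_size": 32, "decoding_batch_size": 8}
print(setup_dataloaders(cfg, is_train=False))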

@@ -172,7 +172,7 @@ class U2Trainer(Trainer):
                         dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts
 
-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine

@@ -173,7 +173,7 @@ class U2Trainer(Trainer):
                         dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts
 
-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine

@@ -184,7 +184,7 @@ class U2STTrainer(Trainer):
                         dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts
 
-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
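The train() → do_train() rename across U2Trainer and U2STTrainer (and the base Trainer below) is presumably needed because train becomes a read-only property on the base class; a method and a property cannot share a name on the same class, as this illustrative snippet shows:

class Clashing:
    def __init__(self):
        self._train = True

    @property
    def train(self):
        return self._train

    def train(self):   # this def overwrites the property in the class namespace
        return "training loop"


c = Clashing()
print(c.train())                 # the property is gone; only the method survives
print(callable(Clashing.train))  # True: "train" now names a plain function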

@@ -134,6 +134,10 @@ class Trainer():
             logger.info(
                 f"Benchmark reset batch-size: {self.args.benchmark_batch_size}")
 
+    @property
+    def train(self):
+        return self._train
+
     @contextmanager
     def eval(self):
         self._train = False
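Together, the new train property and the existing eval() context manager let call sites switch the Trainer into evaluation mode without touching _train directly. A small sketch of that pattern (the body of the real eval() is only partially visible in this hunk; restoring training mode on exit is an assumption here):

from contextlib import contextmanager


class ModeHolder:
    def __init__(self):
        self._train = True

    @property
    def train(self):
        return self._train

    @contextmanager
    def eval(self):
        self._train = False     # mirrors the hunk above
        try:
            yield
        finally:
            self._train = True  # assumption: training mode is restored on exit


m = ModeHolder()
with m.eval():
    assert m.train is False     # read-only view of the flag while evaluating
assert m.train is True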
@@ -248,7 +252,7 @@ class Trainer():
             sys.exit(
                 f"Reach benchmark-max-step: {self.args.benchmark_max_step}")
 
-    def train(self):
+    def do_train(self):
         """The training process control by epoch."""
         self.before_train()
@@ -321,7 +325,7 @@ class Trainer():
         """
         try:
             with Timer("Training Done: {}"):
-                self.train()
+                self.do_train()
         except KeyboardInterrupt:
             exit(-1)
         finally:
@@ -432,7 +436,7 @@ class Trainer():
             beginning of the experiment.
         """
         config_file = self.config_dir / "config.yaml"
-        if self._train and config_file.exists():
+        if self.train and config_file.exists():
             time_stamp = time.strftime("%Y_%m_%d_%H_%M_%s", time.gmtime())
             target_path = self.config_dir / ".".join(
                 [time_stamp, "config.yaml"])
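The last hunk only swaps the private _train flag for the new property; the surrounding logic still moves an existing config.yaml aside under a timestamped name before a training run writes a new one. A standalone sketch of that behaviour (backup_config is a hypothetical helper; the hunk ends before showing what is done with target_path, so the rename is assumed, and %S is used here where the original format string has %s):

import time
from pathlib import Path


def backup_config(config_dir: Path, is_train: bool) -> None:
    """Hypothetical helper mirroring the guarded backup above."""
    config_file = config_dir / "config.yaml"
    if is_train and config_file.exists():
        time_stamp = time.strftime("%Y_%m_%d_%H_%M_%S", time.gmtime())
        target_path = config_dir / ".".join([time_stamp, "config.yaml"])
        config_file.rename(target_path)   # assumption: old config is moved aside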
