Merge pull request #945 from PaddlePaddle/ds2

ds2 exp with eval mode
pull/946/head
Jackwaterveg 3 years ago committed by GitHub
commit 00a50a0101
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -167,6 +167,11 @@ class DeepSpeech2Trainer(Trainer):
logger.info(f"{model}") logger.info(f"{model}")
layer_tools.print_params(model, logger.info) layer_tools.print_params(model, logger.info)
self.model = model
logger.info("Setup model!")
if not self.train:
return
grad_clip = ClipGradByGlobalNormWithLog( grad_clip = ClipGradByGlobalNormWithLog(
config.training.global_grad_clip) config.training.global_grad_clip)
@ -180,26 +185,18 @@ class DeepSpeech2Trainer(Trainer):
weight_decay=paddle.regularizer.L2Decay( weight_decay=paddle.regularizer.L2Decay(
config.training.weight_decay), config.training.weight_decay),
grad_clip=grad_clip) grad_clip=grad_clip)
self.model = model
self.optimizer = optimizer self.optimizer = optimizer
self.lr_scheduler = lr_scheduler self.lr_scheduler = lr_scheduler
logger.info("Setup model/optimizer/lr_scheduler!") logger.info("Setup optimizer/lr_scheduler!")
def setup_dataloader(self): def setup_dataloader(self):
config = self.config.clone() config = self.config.clone()
config.defrost() config.defrost()
config.collator.keep_transcription_text = False if self.train:
# train
config.data.manifest = config.data.train_manifest config.data.manifest = config.data.train_manifest
train_dataset = ManifestDataset.from_config(config) train_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.dev_manifest
dev_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.test_manifest
test_dataset = ManifestDataset.from_config(config)
if self.parallel: if self.parallel:
batch_sampler = SortagradDistributedBatchSampler( batch_sampler = SortagradDistributedBatchSampler(
train_dataset, train_dataset,
@ -219,20 +216,21 @@ class DeepSpeech2Trainer(Trainer):
sortagrad=config.collator.sortagrad, sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method) shuffle_method=config.collator.shuffle_method)
config.collator.keep_transcription_text = False
collate_fn_train = SpeechCollator.from_config(config) collate_fn_train = SpeechCollator.from_config(config)
config.collator.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config)
config.collator.keep_transcription_text = True
config.collator.augmentation_config = ""
collate_fn_test = SpeechCollator.from_config(config)
self.train_loader = DataLoader( self.train_loader = DataLoader(
train_dataset, train_dataset,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
collate_fn=collate_fn_train, collate_fn=collate_fn_train,
num_workers=config.collator.num_workers) num_workers=config.collator.num_workers)
# dev
config.data.manifest = config.data.dev_manifest
dev_dataset = ManifestDataset.from_config(config)
config.collator.augmentation_config = ""
config.collator.keep_transcription_text = False
collate_fn_dev = SpeechCollator.from_config(config)
self.valid_loader = DataLoader( self.valid_loader = DataLoader(
dev_dataset, dev_dataset,
batch_size=int(config.collator.batch_size), batch_size=int(config.collator.batch_size),
@ -240,6 +238,16 @@ class DeepSpeech2Trainer(Trainer):
drop_last=False, drop_last=False,
collate_fn=collate_fn_dev, collate_fn=collate_fn_dev,
num_workers=config.collator.num_workers) num_workers=config.collator.num_workers)
logger.info("Setup train/valid Dataloader!")
else:
# test
config.data.manifest = config.data.test_manifest
test_dataset = ManifestDataset.from_config(config)
config.collator.augmentation_config = ""
config.collator.keep_transcription_text = True
collate_fn_test = SpeechCollator.from_config(config)
self.test_loader = DataLoader( self.test_loader = DataLoader(
test_dataset, test_dataset,
batch_size=config.decoding.batch_size, batch_size=config.decoding.batch_size,
@ -247,7 +255,7 @@ class DeepSpeech2Trainer(Trainer):
drop_last=False, drop_last=False,
collate_fn=collate_fn_test, collate_fn=collate_fn_test,
num_workers=config.collator.num_workers) num_workers=config.collator.num_workers)
logger.info("Setup train/valid/test Dataloader!") logger.info("Setup test Dataloader!")
class DeepSpeech2Tester(DeepSpeech2Trainer): class DeepSpeech2Tester(DeepSpeech2Trainer):

@ -172,7 +172,7 @@ class U2Trainer(Trainer):
dist.get_rank(), total_loss / num_seen_utts)) dist.get_rank(), total_loss / num_seen_utts))
return total_loss, num_seen_utts return total_loss, num_seen_utts
def train(self): def do_train(self):
"""The training process control by step.""" """The training process control by step."""
# !!!IMPORTANT!!! # !!!IMPORTANT!!!
# Try to export the model by script, if fails, we should refine # Try to export the model by script, if fails, we should refine

@ -173,7 +173,7 @@ class U2Trainer(Trainer):
dist.get_rank(), total_loss / num_seen_utts)) dist.get_rank(), total_loss / num_seen_utts))
return total_loss, num_seen_utts return total_loss, num_seen_utts
def train(self): def do_train(self):
"""The training process control by step.""" """The training process control by step."""
# !!!IMPORTANT!!! # !!!IMPORTANT!!!
# Try to export the model by script, if fails, we should refine # Try to export the model by script, if fails, we should refine

@ -184,7 +184,7 @@ class U2STTrainer(Trainer):
dist.get_rank(), total_loss / num_seen_utts)) dist.get_rank(), total_loss / num_seen_utts))
return total_loss, num_seen_utts return total_loss, num_seen_utts
def train(self): def do_train(self):
"""The training process control by step.""" """The training process control by step."""
# !!!IMPORTANT!!! # !!!IMPORTANT!!!
# Try to export the model by script, if fails, we should refine # Try to export the model by script, if fails, we should refine

@ -134,6 +134,10 @@ class Trainer():
logger.info( logger.info(
f"Benchmark reset batch-size: {self.args.benchmark_batch_size}") f"Benchmark reset batch-size: {self.args.benchmark_batch_size}")
@property
def train(self):
return self._train
@contextmanager @contextmanager
def eval(self): def eval(self):
self._train = False self._train = False
@ -248,7 +252,7 @@ class Trainer():
sys.exit( sys.exit(
f"Reach benchmark-max-step: {self.args.benchmark_max_step}") f"Reach benchmark-max-step: {self.args.benchmark_max_step}")
def train(self): def do_train(self):
"""The training process control by epoch.""" """The training process control by epoch."""
self.before_train() self.before_train()
@ -321,7 +325,7 @@ class Trainer():
""" """
try: try:
with Timer("Training Done: {}"): with Timer("Training Done: {}"):
self.train() self.do_train()
except KeyboardInterrupt: except KeyboardInterrupt:
exit(-1) exit(-1)
finally: finally:
@ -432,7 +436,7 @@ class Trainer():
beginning of the experiment. beginning of the experiment.
""" """
config_file = self.config_dir / "config.yaml" config_file = self.config_dir / "config.yaml"
if self._train and config_file.exists(): if self.train and config_file.exists():
time_stamp = time.strftime("%Y_%m_%d_%H_%M_%s", time.gmtime()) time_stamp = time.strftime("%Y_%m_%d_%H_%M_%s", time.gmtime())
target_path = self.config_dir / ".".join( target_path = self.config_dir / ".".join(
[time_stamp, "config.yaml"]) [time_stamp, "config.yaml"])

Loading…
Cancel
Save