diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index adc615321..fcc303a3a 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -46,7 +46,6 @@ class DeepSpeech2Trainer(Trainer):
         super().__init__(config, args)
 
     def train_batch(self, batch_index, batch_data, msg):
-        self.model.train()
         start = time.time()
         loss = self.model(*batch_data)
@@ -100,6 +99,8 @@ class DeepSpeech2Trainer(Trainer):
                 self.visualizer.add_scalar("valid/{}".format(k), v,
                                            self.iteration)
 
+        return valid_losses
+
     def setup_model(self):
         config = self.config
         model = DeepSpeech2Model(
diff --git a/deepspeech/exps/u2/bin/test.py b/deepspeech/exps/u2/bin/test.py
index 7d69c392a..7d8e719cc 100644
--- a/deepspeech/exps/u2/bin/test.py
+++ b/deepspeech/exps/u2/bin/test.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Evaluation for U2 model."""
+import os
+import cProfile
 
 from deepspeech.training.cli import default_argument_parser
 from deepspeech.utils.utility import print_arguments
@@ -48,4 +50,7 @@ if __name__ == "__main__":
     with open(args.dump_config, 'w') as f:
         print(config, file=f)
 
-    main(config, args)
+    # Set up profiling around the whole evaluation run
+    pr = cProfile.Profile()
+    pr.runcall(main, config, args)
+    pr.dump_stats(os.path.join('.', 'test.profile'))
diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index a020e997b..d02818c20 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -77,8 +77,6 @@ class U2Trainer(Trainer):
 
     def train_batch(self, batch_index, batch_data, msg):
         train_conf = self.config.training
-        self.model.train()
-
         start = time.time()
         loss, attention_loss, ctc_loss = self.model(*batch_data)
@@ -134,6 +132,7 @@ class U2Trainer(Trainer):
         self.logger.info(
             f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
+            self.model.train()
             try:
                 data_start_time = time.time()
                 for batch_index, batch in enumerate(self.train_loader):
@@ -149,8 +148,8 @@ class U2Trainer(Trainer):
                 self.logger.error(e)
                 raise e
 
-            self.valid()
-            self.save()
+            valid_losses = self.valid()
+            self.save(infos=valid_losses)
             self.new_epoch()
 
     @mp_tools.rank_zero_only
@@ -182,6 +181,7 @@ class U2Trainer(Trainer):
         for k, v in valid_losses.items():
             self.visualizer.add_scalar("valid/{}".format(k), v,
                                        self.iteration)
+        return valid_losses
 
     def setup_dataloader(self):
         config = self.config.clone()
diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py
index 2bd8ddb8a..7812cbdc2 100644
--- a/deepspeech/io/dataset.py
+++ b/deepspeech/io/dataset.py
@@ -290,19 +290,34 @@ class ManifestDataset(Dataset):
             where transcription part could be token ids or text.
         :rtype: tuple of (2darray, list)
         """
+        start_time = time.time()
         if isinstance(audio_file, str) and audio_file.startswith('tar:'):
             speech_segment = SpeechSegment.from_file(
                 self._subfile_from_tar(audio_file), transcript)
         else:
             speech_segment = SpeechSegment.from_file(audio_file, transcript)
+        load_wav_time = time.time() - start_time
+        logger.debug(f"load wav time: {load_wav_time}")
+        # audio augment
+        start_time = time.time()
         self._augmentation_pipeline.transform_audio(speech_segment)
+        audio_aug_time = time.time() - start_time
+        logger.debug(f"audio augmentation time: {audio_aug_time}")
+
+        start_time = time.time()
         specgram, transcript_part = self._speech_featurizer.featurize(
             speech_segment, self._keep_transcription_text)
         if self._normalizer:
             specgram = self._normalizer.apply(specgram)
+        feature_time = time.time() - start_time
+        logger.debug(f"audio & text feature time: {feature_time}")
+        # specgram augment
+        start_time = time.time()
         specgram = self._augmentation_pipeline.transform_feature(specgram)
+        feature_aug_time = time.time() - start_time
+        logger.debug(f"audio feature augmentation time: {feature_aug_time}")
         return specgram, transcript_part
 
     def _instance_reader_creator(self, manifest):
diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py
index a7b426356..af7a11e1c 100644
--- a/deepspeech/models/u2.py
+++ b/deepspeech/models/u2.py
@@ -821,7 +821,8 @@ class U2Model(U2BaseModel):
             mean, istd = load_cmvn(configs['cmvn_file'],
                                    configs['cmvn_file_type'])
             global_cmvn = GlobalCMVN(
-                paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
+                paddle.to_tensor(mean, dtype=paddle.float32),
+                paddle.to_tensor(istd, dtype=paddle.float32))
         else:
             global_cmvn = None
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index a2d0ee489..7a9748cf6 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -128,15 +128,15 @@ class Trainer():
         dist.init_parallel_env()
 
     @mp_tools.rank_zero_only
-    def save(self, tag=None, infos=None):
+    def save(self, tag=None, infos: dict = None):
         """Save checkpoint (model parameters and optimizer states).
         """
-        if infos is None:
-            infos = {
-                "step": self.iteration,
-                "epoch": self.epoch,
-                "lr": self.optimizer.get_lr(),
-            }
+        infos = infos if infos else dict()
+        infos.update({
+            "step": self.iteration,
+            "epoch": self.epoch,
+            "lr": self.optimizer.get_lr()
+        })
         checkpoint.save_parameters(self.checkpoint_dir,
                                    self.iteration if tag is None else tag,
                                    self.model, self.optimizer, infos)
@@ -185,6 +185,7 @@ class Trainer():
         self.logger.info(
             f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
+            self.model.train()
             try:
                 data_start_time = time.time()
                 for batch_index, batch in enumerate(self.train_loader):
@@ -200,8 +201,8 @@ class Trainer():
                 self.logger.error(e)
                 raise e
 
-            self.valid()
-            self.save()
+            valid_losses = self.valid()
+            self.save(infos=valid_losses)
             self.lr_scheduler.step()
             self.new_epoch()
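
The profiling hook added to deepspeech/exps/u2/bin/test.py dumps raw stats to ./test.profile rather than printing them. As a minimal sketch (not part of the diff itself), the dump can be inspected afterwards with the standard-library pstats module; the sort key and the 20-entry limit below are illustrative choices:

    import pstats

    # Load the dump written by pr.dump_stats(os.path.join('.', 'test.profile'))
    stats = pstats.Stats('test.profile')

    # Strip long directory prefixes, sort by cumulative time, and print the
    # 20 most expensive calls to locate evaluation hot spots.
    stats.strip_dirs().sort_stats('cumulative').print_stats(20)

Sorting by 'cumulative' surfaces the high-level calls that dominate wall time; 'tottime' would instead highlight the individual functions doing the work.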