log valid loss, time dataset process

pull/578/head
Hui Zhang 5 years ago
parent 926b1876c7
commit b355b67f48

@@ -46,7 +46,6 @@ class DeepSpeech2Trainer(Trainer):
         super().__init__(config, args)
     def train_batch(self, batch_index, batch_data, msg):
-        self.model.train()
         start = time.time()
         loss = self.model(*batch_data)
@@ -100,6 +99,8 @@ class DeepSpeech2Trainer(Trainer):
             self.visualizer.add_scalar("valid/{}".format(k), v,
                                        self.iteration)
+        return valid_losses
     def setup_model(self):
         config = self.config
         model = DeepSpeech2Model(
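Taken together, these two hunks change the epoch loop's contract with valid(): instead of only logging the validation metrics, valid() now hands them back so the caller can record them in the checkpoint. A minimal sketch of the resulting flow (TinyTrainer and the loss value are placeholders, not the trainer's real attributes):

    # Sketch only: a stripped-down trainer showing the new valid()/save() contract.
    class TinyTrainer:
        def valid(self):
            valid_losses = {"val_loss": 0.0}      # placeholder for metrics averaged over the valid loader
            return valid_losses                   # new: metrics are handed back to the caller

        def save(self, tag=None, infos=None):
            print("checkpoint metadata:", infos)  # stand-in for checkpoint.save_parameters(...)

        def run_epoch_end(self):
            valid_losses = self.valid()           # was: self.valid()
            self.save(infos=valid_losses)         # was: self.save()

    TinyTrainer().run_epoch_end()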

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Evaluation for U2 model."""
+import os
+import cProfile
 from deepspeech.training.cli import default_argument_parser
 from deepspeech.utils.utility import print_arguments
@@ -48,4 +50,7 @@ if __name__ == "__main__":
     with open(args.dump_config, 'w') as f:
         print(config, file=f)
-    main(config, args)
+    # Setting for profiling
+    pr = cProfile.Profile()
+    pr.runcall(main, config, args)
+    pr.dump_stats(os.path.join('.', 'test.profile'))
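The test entry point is now run under cProfile so evaluation can be profiled end to end. For reference, the dumped stats file can be inspected with the standard-library pstats module; a small self-contained sketch of the same pattern (work() is a placeholder for main(config, args)):

    import cProfile
    import pstats

    def work():
        return sum(i * i for i in range(100000))  # stand-in for main(config, args)

    pr = cProfile.Profile()
    pr.runcall(work)                  # same pattern as pr.runcall(main, config, args)
    pr.dump_stats("test.profile")

    # Inspect the dump: top 10 entries by cumulative time.
    stats = pstats.Stats("test.profile")
    stats.sort_stats("cumulative").print_stats(10)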

@@ -77,8 +77,6 @@ class U2Trainer(Trainer):
     def train_batch(self, batch_index, batch_data, msg):
         train_conf = self.config.training
-        self.model.train()
         start = time.time()
         loss, attention_loss, ctc_loss = self.model(*batch_data)
@@ -134,6 +132,7 @@ class U2Trainer(Trainer):
         self.logger.info(
             f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
+            self.model.train()
             try:
                 data_start_time = time.time()
                 for batch_index, batch in enumerate(self.train_loader):
@@ -149,8 +148,8 @@ class U2Trainer(Trainer):
                     self.logger.error(e)
                     raise e
-            self.valid()
-            self.save()
+            valid_losses = self.valid()
+            self.save(infos=valid_losses)
             self.new_epoch()
     @mp_tools.rank_zero_only
@@ -182,6 +181,7 @@ class U2Trainer(Trainer):
             for k, v in valid_losses.items():
                 self.visualizer.add_scalar("valid/{}".format(k), v,
                                            self.iteration)
+        return valid_losses
     def setup_dataloader(self):
         config = self.config.clone()
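Moving self.model.train() out of train_batch and to the top of the epoch loop means the mode is set once per epoch rather than once per batch, and it is restored after valid() runs (validation typically switches the model to eval mode). A minimal sketch of the intended ordering; Model and the loop bounds are placeholders, not the real trainer:

    # Sketch of the new ordering; Model is a stand-in with train()/eval() modes.
    class Model:
        def train(self): self.training = True
        def eval(self): self.training = False

    model = Model()
    n_epoch, batches = 2, range(3)     # placeholders for config values and the data loader
    for epoch in range(n_epoch):
        model.train()                  # new: once per epoch, before the batches
        for batch in batches:
            pass                       # train_batch(...) no longer calls model.train()
        model.eval()                   # valid() runs in eval mode
        # ... compute valid_losses, save checkpoint; train() is restored next epoch ...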

@@ -290,19 +290,34 @@ class ManifestDataset(Dataset):
                  where transcription part could be token ids or text.
         :rtype: tuple of (2darray, list)
         """
+        start_time = time.time()
         if isinstance(audio_file, str) and audio_file.startswith('tar:'):
             speech_segment = SpeechSegment.from_file(
                 self._subfile_from_tar(audio_file), transcript)
         else:
             speech_segment = SpeechSegment.from_file(audio_file, transcript)
+        load_wav_time = time.time() - start_time
+        logger.debug(f"load wav time: {load_wav_time}")
         # audio augment
+        start_time = time.time()
         self._augmentation_pipeline.transform_audio(speech_segment)
+        audio_aug_time = time.time() - start_time
+        logger.debug(f"audio augmentation time: {audio_aug_time}")
+        start_time = time.time()
         specgram, transcript_part = self._speech_featurizer.featurize(
             speech_segment, self._keep_transcription_text)
         if self._normalizer:
             specgram = self._normalizer.apply(specgram)
+        feature_time = time.time() - start_time
+        logger.debug(f"audio & text feature time: {feature_time}")
         # specgram augment
+        start_time = time.time()
         specgram = self._augmentation_pipeline.transform_feature(specgram)
+        feature_aug_time = time.time() - start_time
+        logger.debug(f"audio feature augmentation time: {feature_aug_time}")
         return specgram, transcript_part
     def _instance_reader_creator(self, manifest):
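The dataset change brackets each stage of utterance processing (wav loading, audio augmentation, featurization and normalization, spectrogram augmentation) with wall-clock timing and logs it at debug level. The commit times each stage inline; the same instrumentation could be factored into a small context manager, sketched here under the assumption that a standard logging logger is available (the stage bodies are stand-ins):

    import logging
    import time
    from contextlib import contextmanager

    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG)

    @contextmanager
    def timed(stage):
        """Log the wall-clock time of one pipeline stage at debug level."""
        start = time.time()
        try:
            yield
        finally:
            logger.debug(f"{stage} time: {time.time() - start}")

    # Usage mirroring the stages timed in the diff:
    with timed("load wav"):
        time.sleep(0.01)               # stand-in for SpeechSegment.from_file(...)
    with timed("audio augmentation"):
        time.sleep(0.01)               # stand-in for transform_audio(...)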

@@ -821,7 +821,8 @@ class U2Model(U2BaseModel):
             mean, istd = load_cmvn(configs['cmvn_file'],
                                    configs['cmvn_file_type'])
             global_cmvn = GlobalCMVN(
-                paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
+                paddle.to_tensor(mean, dtype=paddle.float),
+                paddle.to_tensor(istd, dtype=paddle.float))
         else:
             global_cmvn = None
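The CMVN change builds the mean/istd tensors with an explicit dtype at construction time instead of converting afterwards. A hedged sketch of the same idea in Paddle 2.x, using the 'float32' string form of the dtype argument (the diff itself passes paddle.float, and the CMVN shapes here are placeholders):

    import numpy as np
    import paddle

    mean = np.zeros(80, dtype="float64")   # placeholder CMVN statistics
    istd = np.ones(80, dtype="float64")

    # Old style: create the tensor, then cast it.
    mean_t = paddle.to_tensor(mean).astype("float32")
    # New style: request the dtype when the tensor is created.
    istd_t = paddle.to_tensor(istd, dtype="float32")

    print(mean_t.dtype, istd_t.dtype)      # both float32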

@@ -128,15 +128,15 @@ class Trainer():
             dist.init_parallel_env()
     @mp_tools.rank_zero_only
-    def save(self, tag=None, infos=None):
+    def save(self, tag=None, infos: dict=None):
         """Save checkpoint (model parameters and optimizer states).
         """
-        if infos is None:
-            infos = {
-                "step": self.iteration,
-                "epoch": self.epoch,
-                "lr": self.optimizer.get_lr(),
-            }
+        infos = infos if infos else dict()
+        infos.update({
+            "step": self.iteration,
+            "epoch": self.epoch,
+            "lr": self.optimizer.get_lr()
+        })
         checkpoint.save_parameters(self.checkpoint_dir, self.iteration
                                    if tag is None else tag, self.model,
                                    self.optimizer, infos)
@@ -185,6 +185,7 @@ class Trainer():
         self.logger.info(
             f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
+            self.model.train()
             try:
                 data_start_time = time.time()
                 for batch_index, batch in enumerate(self.train_loader):
@@ -200,8 +201,8 @@ class Trainer():
                     self.logger.error(e)
                     raise e
-            self.valid()
-            self.save()
+            valid_losses = self.valid()
+            self.save(infos=valid_losses)
             self.lr_scheduler.step()
             self.new_epoch()
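With the new signature, save() merges caller-supplied metadata (here the validation losses) into the step/epoch/lr record instead of writing the defaults only when nothing is passed. A small sketch of the merge semantics, with a print standing in for checkpoint.save_parameters and hypothetical values:

    # Sketch of the revised save() merge logic; printing stands in for
    # checkpoint.save_parameters(...).
    def save(iteration, epoch, lr, tag=None, infos: dict = None):
        infos = infos if infos else dict()
        infos.update({"step": iteration, "epoch": epoch, "lr": lr})
        print("tag:", iteration if tag is None else tag, "infos:", infos)

    save(100, 1, 2e-4)                             # defaults only
    save(200, 2, 2e-4, infos={"val_loss": 11.7})   # valid_losses merged into the record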
