log valid loss, time dataset process

pull/578/head
Hui Zhang 5 years ago
parent 926b1876c7
commit b355b67f48

@@ -46,7 +46,6 @@ class DeepSpeech2Trainer(Trainer):
super().__init__(config, args)
def train_batch(self, batch_index, batch_data, msg):
self.model.train()
start = time.time()
loss = self.model(*batch_data)
@@ -100,6 +99,8 @@ class DeepSpeech2Trainer(Trainer):
self.visualizer.add_scalar("valid/{}".format(k), v,
self.iteration)
+ return valid_losses
def setup_model(self):
config = self.config
model = DeepSpeech2Model(

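For context on the two hunks above: `valid()` now both writes each averaged loss to the visualizer and returns the dict, so the training loop can forward it to `save()`. A minimal sketch of that log-and-return pattern, assuming the visualizer is a VisualDL `LogWriter` (the helper name, log dir, and sample values below are illustrative):

```python
# Log-and-return sketch; any object exposing add_scalar(tag, value, step)
# would work in place of the VisualDL LogWriter.
from visualdl import LogWriter

def log_valid_losses(writer, valid_losses: dict, iteration: int) -> dict:
    for k, v in valid_losses.items():
        writer.add_scalar(tag="valid/{}".format(k), value=v, step=iteration)
    return valid_losses  # returned so the caller can stash it in a checkpoint

writer = LogWriter(logdir="exp/log")  # hypothetical log dir
print(log_valid_losses(writer, {"loss": 12.3, "ctc_loss": 8.1}, iteration=1000))
```

Returning the dict instead of only logging it is what lets checkpoints carry validation metrics; see the `save(infos=...)` hunks further down.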
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for U2 model."""
+ import os
+ import cProfile
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
@@ -48,4 +50,7 @@ if __name__ == "__main__":
with open(args.dump_config, 'w') as f:
print(config, file=f)
- main(config, args)
+ # Setting for profiling
+ pr = cProfile.Profile()
+ pr.runcall(main, config, args)
+ pr.dump_stats(os.path.join('.', 'test.profile'))

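`Profile.runcall(func, *args)` executes the callable under the profiler and `dump_stats` writes the raw counters to disk, so the run can be inspected offline with the standard-library `pstats` module. A small sketch of the full round trip (the profiled function body and the output path are placeholders):

```python
import cProfile
import pstats

def main():
    # stand-in for the real entry point
    sum(i * i for i in range(1_000_000))

pr = cProfile.Profile()
pr.runcall(main)                      # run main() under the profiler
pr.dump_stats("test.profile")         # raw stats file, as in the diff

stats = pstats.Stats("test.profile")  # load the dump back
stats.sort_stats("cumulative").print_stats(10)  # top 10 by cumulative time
```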
@@ -77,8 +77,6 @@ class U2Trainer(Trainer):
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
self.model.train()
start = time.time()
loss, attention_loss, ctc_loss = self.model(*batch_data)
@@ -134,6 +132,7 @@ class U2Trainer(Trainer):
self.logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch:
+ self.model.train()
try:
data_start_time = time.time()
for batch_index, batch in enumerate(self.train_loader):
@@ -149,8 +148,8 @@ class U2Trainer(Trainer):
self.logger.error(e)
raise e
- self.valid()
- self.save()
+ valid_losses = self.valid()
+ self.save(infos=valid_losses)
self.new_epoch()
@mp_tools.rank_zero_only
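The loop change above captures the value `valid()` now returns and threads it into `save()`, so each checkpoint records the validation losses of the epoch it closes. A stripped-down sketch of the control flow, with the trainer internals stubbed out and fake numbers:

```python
# Toy version of the epoch loop after this change.
class ToyTrainer:
    def __init__(self, n_epoch):
        self.epoch, self.n_epoch = 0, n_epoch

    def valid(self):
        return {"loss": 1.0 / (self.epoch + 1)}  # pretend validation metric

    def save(self, infos=None):
        print(f"checkpoint @ epoch {self.epoch}, infos={infos}")

    def run(self):
        while self.epoch < self.n_epoch:
            # ... one epoch of train_batch() calls ...
            valid_losses = self.valid()    # capture instead of discarding
            self.save(infos=valid_losses)  # metrics ride along with weights
            self.epoch += 1

ToyTrainer(n_epoch=2).run()
```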
@@ -182,6 +181,7 @@ class U2Trainer(Trainer):
for k, v in valid_losses.items():
self.visualizer.add_scalar("valid/{}".format(k), v,
self.iteration)
+ return valid_losses
def setup_dataloader(self):
config = self.config.clone()

@@ -290,19 +290,34 @@ class ManifestDataset(Dataset):
where transcription part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
+ start_time = time.time()
if isinstance(audio_file, str) and audio_file.startswith('tar:'):
speech_segment = SpeechSegment.from_file(
self._subfile_from_tar(audio_file), transcript)
else:
speech_segment = SpeechSegment.from_file(audio_file, transcript)
+ load_wav_time = time.time() - start_time
+ logger.debug(f"load wav time: {load_wav_time}")
+ # audio augment
+ start_time = time.time()
self._augmentation_pipeline.transform_audio(speech_segment)
+ audio_aug_time = time.time() - start_time
+ logger.debug(f"audio augmentation time: {audio_aug_time}")
+ start_time = time.time()
specgram, transcript_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text)
if self._normalizer:
specgram = self._normalizer.apply(specgram)
+ feature_time = time.time() - start_time
+ logger.debug(f"audio & text feature time: {feature_time}")
+ # specgram augment
+ start_time = time.time()
specgram = self._augmentation_pipeline.transform_feature(specgram)
+ feature_aug_time = time.time() - start_time
+ logger.debug(f"audio feature augmentation time: {feature_aug_time}")
return specgram, transcript_part
def _instance_reader_creator(self, manifest):

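The instrumentation above repeats the same three lines (take a start time, compute the delta, `logger.debug` it) for every stage of the utterance pipeline. If this logging sticks around, a tiny context manager could express each stage in one line; a sketch (the `timed` helper is not part of the repo):

```python
import time
import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)

@contextmanager
def timed(name):
    """Log wall-clock time of the enclosed block at DEBUG level."""
    start = time.time()
    try:
        yield
    finally:
        logger.debug(f"{name} time: {time.time() - start}")

# usage mirroring one of the instrumented stages above:
with timed("load wav"):
    time.sleep(0.01)  # stand-in for SpeechSegment.from_file(...)
```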
@@ -821,7 +821,8 @@ class U2Model(U2BaseModel):
mean, istd = load_cmvn(configs['cmvn_file'],
configs['cmvn_file_type'])
global_cmvn = GlobalCMVN(
- paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
+ paddle.to_tensor(mean, dtype=paddle.float),
+ paddle.to_tensor(istd, dtype=paddle.float))
else:
global_cmvn = None

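For reference, `GlobalCMVN` stores a per-feature mean and inverse standard deviation, and its forward pass is essentially `(x - mean) * istd`; the hunk above builds those buffers in the target dtype up front instead of casting after construction. A sketch of the same idea with toy statistics (`paddle.to_tensor` also accepts the string form `'float32'`):

```python
import numpy as np
import paddle

# toy stats; the real values come from load_cmvn(configs['cmvn_file'], ...)
mean = np.zeros(80, dtype=np.float64)
istd = np.ones(80, dtype=np.float64)   # inverse standard deviation

# build the buffers in float32 directly rather than casting afterwards
mean_t = paddle.to_tensor(mean, dtype="float32")
istd_t = paddle.to_tensor(istd, dtype="float32")

x = paddle.randn([4, 100, 80])         # (batch, time, feature)
x_norm = (x - mean_t) * istd_t         # what GlobalCMVN.forward computes
print(x_norm.dtype)                    # paddle.float32
```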
@@ -128,15 +128,15 @@ class Trainer():
dist.init_parallel_env()
@mp_tools.rank_zero_only
- def save(self, tag=None, infos=None):
+ def save(self, tag=None, infos: dict=None):
"""Save checkpoint (model parameters and optimizer states).
"""
- if infos is None:
- infos = {
- "step": self.iteration,
- "epoch": self.epoch,
- "lr": self.optimizer.get_lr(),
- }
+ infos = infos if infos else dict()
+ infos.update({
+ "step": self.iteration,
+ "epoch": self.epoch,
+ "lr": self.optimizer.get_lr()
+ })
checkpoint.save_parameters(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model,
self.optimizer, infos)
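Note the behavioral shift in `save()`: previously a caller-supplied `infos` dict was stored as-is and the step/epoch/lr record was only a fallback; now the trainer fields are always merged in (and win on key collisions). Isolated as plain functions with hypothetical values:

```python
# The two behaviors of save()'s infos handling, side by side.
def infos_before(infos):
    if infos is None:  # trainer fields were only a fallback
        infos = {"step": 100, "epoch": 2, "lr": 1e-3}
    return infos

def infos_after(infos):
    infos = infos if infos else dict()
    infos.update({"step": 100, "epoch": 2, "lr": 1e-3})  # always recorded
    return infos

print(infos_before({"loss": 3.2}))  # {'loss': 3.2} -- step/epoch/lr lost
print(infos_after({"loss": 3.2}))   # {'loss': 3.2, 'step': 100, 'epoch': 2, 'lr': 0.001}
```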
@@ -185,6 +185,7 @@ class Trainer():
self.logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch:
+ self.model.train()
try:
data_start_time = time.time()
for batch_index, batch in enumerate(self.train_loader):
@@ -200,8 +201,8 @@ class Trainer():
self.logger.error(e)
raise e
- self.valid()
- self.save()
+ valid_losses = self.valid()
+ self.save(infos=valid_losses)
self.lr_scheduler.step()
self.new_epoch()

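Finally, `self.model.train()` moves inside the epoch loop here and in `U2Trainer` because the per-epoch `valid()` call typically flips the model into eval mode; with the call outside the loop, every epoch after the first validation would run with dropout and similar layers disabled. A minimal demonstration of the mode flag:

```python
import paddle.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))

model.train()
print(model.training)  # True  -- dropout active

model.eval()           # what a validation pass typically does
print(model.training)  # False -- dropout now a no-op

# without a fresh train() at the top of the next epoch, training would
# silently continue in eval mode; hence the move into the while loop
model.train()
print(model.training)  # True again
```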