@@ -317,7 +317,7 @@ class DeepSpeech2Trainer(Trainer):
             use_dB_normalization=config.data.use_dB_normalization,
             target_dB=config.data.target_dB,
             random_seed=config.data.random_seed,
-            keep_transcription_text=True)
+            keep_transcription_text=False)

         if self.parallel:
             batch_sampler = DeepSpeech2DistributedBatchSampler(
@@ -342,14 +342,14 @@ class DeepSpeech2Trainer(Trainer):
         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
-            collate_fn=SpeechCollator(is_training=True),
+            collate_fn=collate_fn,
             num_workers=config.data.num_workers, )
         self.valid_loader = DataLoader(
             dev_dataset,
             batch_size=config.data.batch_size,
             shuffle=False,
             drop_last=False,
-            collate_fn=SpeechCollator(is_training=True))
+            collate_fn=collate_fn)
         self.logger.info("Setup train/valid Dataloader!")
@@ -415,7 +415,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.logger.info(
             f"Test Total Examples: {len(self.test_loader.dataset)}")
         self.model.eval()
         losses = defaultdict(list)

         cfg = self.config
         # decoders only accept string encoded in utf-8
         vocab_list = self.test_loader.dataset.vocab_list
@@ -432,10 +432,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         for i, batch in enumerate(self.test_loader):
             audio, text, audio_len, text_len = batch
             outputs = self.model.predict(audio, audio_len)
             loss = self.compute_losses(batch, outputs)
             losses['test_loss'].append(float(loss))

             metrics = self.compute_metrics(batch, outputs)

             errors_sum += metrics['errors_sum']
             len_refs += metrics['len_refs']
             num_ins += metrics['num_ins']
@@ -443,14 +441,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             self.logger.info("Error rate [%s] (%d/?) = %f" %
                              (error_rate_type, num_ins, errors_sum / len_refs))

         # write visual log
         losses = {k: np.mean(v) for k, v in losses.items()}

         # logging
         msg = "Test: "
         msg += "epoch: {}, ".format(self.epoch)
         msg += "step: {}, ".format(self.iteration)
         msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses.items())
         msg += ", Final error rate [%s] (%d/%d) = %f" % (
             error_rate_type, num_ins, num_ins, errors_sum / len_refs)
         self.logger.info(msg)