|
|
|
@ -43,7 +43,8 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
|
|
|
|
|
def train_batch(self, batch_index, batch_data, msg):
|
|
|
|
|
start = time.time()
|
|
|
|
|
loss = self.model(*batch_data)
|
|
|
|
|
utt, audio, audio_len, text, text_len = batch_data
|
|
|
|
|
loss = self.model(audio, audio_len, text, text_len)
|
|
|
|
|
loss.backward()
|
|
|
|
|
layer_tools.print_grads(self.model, print_func=None)
|
|
|
|
|
self.optimizer.step()
|
|
|
|
@ -73,9 +74,10 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
num_seen_utts = 1
|
|
|
|
|
total_loss = 0.0
|
|
|
|
|
for i, batch in enumerate(self.valid_loader):
|
|
|
|
|
loss = self.model(*batch)
|
|
|
|
|
utt, audio, audio_len, text, text_len = batch
|
|
|
|
|
loss = self.model(audio, audio_len, text, text_len)
|
|
|
|
|
if paddle.isfinite(loss):
|
|
|
|
|
num_utts = batch[0].shape[0]
|
|
|
|
|
num_utts = batch[1].shape[0]
|
|
|
|
|
num_seen_utts += num_utts
|
|
|
|
|
total_loss += float(loss) * num_utts
|
|
|
|
|
valid_losses['val_loss'].append(float(loss))
|
|
|
|
@ -191,7 +193,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
trans.append(''.join([chr(i) for i in ids]))
|
|
|
|
|
return trans
|
|
|
|
|
|
|
|
|
|
def compute_metrics(self, audio, audio_len, texts, texts_len):
|
|
|
|
|
def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout = None):
|
|
|
|
|
cfg = self.config.decoding
|
|
|
|
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
|
|
|
|
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
|
|
|
|
@ -213,11 +215,13 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
cutoff_top_n=cfg.cutoff_top_n,
|
|
|
|
|
num_processes=cfg.num_proc_bsearch)
|
|
|
|
|
|
|
|
|
|
for target, result in zip(target_transcripts, result_transcripts):
|
|
|
|
|
for utt, target, result in zip(utts, target_transcripts, result_transcripts):
|
|
|
|
|
errors, len_ref = errors_func(target, result)
|
|
|
|
|
errors_sum += errors
|
|
|
|
|
len_refs += len_ref
|
|
|
|
|
num_ins += 1
|
|
|
|
|
if fout:
|
|
|
|
|
fout.write(utt + " " + result + "\n")
|
|
|
|
|
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
|
|
|
|
|
(target, result))
|
|
|
|
|
logger.info("Current error rate [%s] = %f" %
|
|
|
|
@ -238,9 +242,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
|
cfg = self.config
|
|
|
|
|
error_rate_type = None
|
|
|
|
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
|
|
|
|
|
|
|
|
|
with open(self.args.result_file, 'w') as fout:
|
|
|
|
|
for i, batch in enumerate(self.test_loader):
|
|
|
|
|
metrics = self.compute_metrics(*batch)
|
|
|
|
|
utts, audio, audio_len, texts, texts_len = batch
|
|
|
|
|
metrics = self.compute_metrics(utts, audio, audio_len, texts, texts_len, fout)
|
|
|
|
|
errors_sum += metrics['errors_sum']
|
|
|
|
|
len_refs += metrics['len_refs']
|
|
|
|
|
num_ins += metrics['num_ins']
|
|
|
|
|