|
|
|
@ -78,7 +78,8 @@ class U2Trainer(Trainer):
|
|
|
|
|
start = time.time()
|
|
|
|
|
utt, audio, audio_len, text, text_len = batch_data
|
|
|
|
|
|
|
|
|
|
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len)
|
|
|
|
|
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
|
|
|
|
|
text_len)
|
|
|
|
|
# loss div by `batch_size * accum_grad`
|
|
|
|
|
loss /= train_conf.accum_grad
|
|
|
|
|
loss.backward()
|
|
|
|
@ -121,7 +122,8 @@ class U2Trainer(Trainer):
|
|
|
|
|
total_loss = 0.0
|
|
|
|
|
for i, batch in enumerate(self.valid_loader):
|
|
|
|
|
utt, audio, audio_len, text, text_len = batch
|
|
|
|
|
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len)
|
|
|
|
|
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
|
|
|
|
|
text_len)
|
|
|
|
|
if paddle.isfinite(loss):
|
|
|
|
|
num_utts = batch[1].shape[0]
|
|
|
|
|
num_seen_utts += num_utts
|
|
|
|
@ -368,7 +370,13 @@ class U2Tester(U2Trainer):
|
|
|
|
|
trans.append(''.join([chr(i) for i in ids]))
|
|
|
|
|
return trans
|
|
|
|
|
|
|
|
|
|
def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None):
|
|
|
|
|
def compute_metrics(self,
|
|
|
|
|
utts,
|
|
|
|
|
audio,
|
|
|
|
|
audio_len,
|
|
|
|
|
texts,
|
|
|
|
|
texts_len,
|
|
|
|
|
fout=None):
|
|
|
|
|
cfg = self.config.decoding
|
|
|
|
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
|
|
|
|
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
|
|
|
|
@ -395,7 +403,8 @@ class U2Tester(U2Trainer):
|
|
|
|
|
simulate_streaming=cfg.simulate_streaming)
|
|
|
|
|
decode_time = time.time() - start_time
|
|
|
|
|
|
|
|
|
|
for utt, target, result in zip(utts, target_transcripts, result_transcripts):
|
|
|
|
|
for utt, target, result in zip(utts, target_transcripts,
|
|
|
|
|
result_transcripts):
|
|
|
|
|
errors, len_ref = errors_func(target, result)
|
|
|
|
|
errors_sum += errors
|
|
|
|
|
len_refs += len_ref
|
|
|
|
|