pull/670/head
Hui Zhang 4 years ago
parent b98c7cd120
commit daf9abdaa2

@ -193,7 +193,13 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
trans.append(''.join([chr(i) for i in ids]))
return trans
def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout = None):
def compute_metrics(self,
utts,
audio,
audio_len,
texts,
texts_len,
fout=None):
cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
@ -215,7 +221,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
for utt, target, result in zip(utts, target_transcripts, result_transcripts):
for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
@ -245,7 +252,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch
metrics = self.compute_metrics(utts, audio, audio_len, texts, texts_len, fout)
metrics = self.compute_metrics(utts, audio, audio_len, texts,
texts_len, fout)
errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']

@ -78,7 +78,8 @@ class U2Trainer(Trainer):
start = time.time()
utt, audio, audio_len, text, text_len = batch_data
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len)
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
loss.backward()
@ -121,7 +122,8 @@ class U2Trainer(Trainer):
total_loss = 0.0
for i, batch in enumerate(self.valid_loader):
utt, audio, audio_len, text, text_len = batch
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len)
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)
if paddle.isfinite(loss):
num_utts = batch[1].shape[0]
num_seen_utts += num_utts
@ -368,7 +370,13 @@ class U2Tester(U2Trainer):
trans.append(''.join([chr(i) for i in ids]))
return trans
def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None):
def compute_metrics(self,
utts,
audio,
audio_len,
texts,
texts_len,
fout=None):
cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
@ -395,7 +403,8 @@ class U2Tester(U2Trainer):
simulate_streaming=cfg.simulate_streaming)
decode_time = time.time() - start_time
for utt, target, result in zip(utts, target_transcripts, result_transcripts):
for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref

@ -347,6 +347,5 @@ class ManifestDataset(Dataset):
def __getitem__(self, idx):
instance = self._manifest[idx]
feat, text =self.process_utterance(instance["feat"],
instance["text"])
feat, text = self.process_utterance(instance["feat"], instance["text"])
return instance["utt"], feat, text

@ -905,7 +905,6 @@ class U2InferModel(U2Model):
def __init__(self, configs: dict):
super().__init__(configs)
def forward(self,
feats,
feats_lengths,

Loading…
Cancel
Save