From daf9abdaa2109b64bfbe59fbfe03280ee98ef0a1 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 15 Jun 2021 09:30:20 +0000 Subject: [PATCH] format --- deepspeech/exps/deepspeech2/model.py | 14 +++++++++++--- deepspeech/exps/u2/model.py | 17 +++++++++++++---- deepspeech/io/dataset.py | 3 +-- deepspeech/models/u2.py | 1 - 4 files changed, 25 insertions(+), 10 deletions(-) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 468bc6521..e3a22463b 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -193,7 +193,13 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout = None): + def compute_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + fout=None): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors @@ -215,7 +221,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): cutoff_top_n=cfg.cutoff_top_n, num_processes=cfg.num_proc_bsearch) - for utt, target, result in zip(utts, target_transcripts, result_transcripts): + for utt, target, result in zip(utts, target_transcripts, + result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref @@ -245,7 +252,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): with open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): utts, audio, audio_len, texts, texts_len = batch - metrics = self.compute_metrics(utts, audio, audio_len, texts, texts_len, fout) + metrics = self.compute_metrics(utts, audio, audio_len, texts, + texts_len, fout) errors_sum += metrics['errors_sum'] len_refs += metrics['len_refs'] num_ins += metrics['num_ins'] diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 334d6bc8e..8fabd9ffd 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -78,7 +78,8 @@ class U2Trainer(Trainer): start = time.time() utt, audio, audio_len, text, text_len = batch_data - loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) + loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, + text_len) # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad loss.backward() @@ -121,7 +122,8 @@ class U2Trainer(Trainer): total_loss = 0.0 for i, batch in enumerate(self.valid_loader): utt, audio, audio_len, text, text_len = batch - loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) + loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, + text_len) if paddle.isfinite(loss): num_utts = batch[1].shape[0] num_seen_utts += num_utts @@ -368,7 +370,13 @@ class U2Tester(U2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None): + def compute_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + fout=None): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors @@ -395,7 +403,8 @@ class U2Tester(U2Trainer): simulate_streaming=cfg.simulate_streaming) decode_time = time.time() - start_time - for utt, target, result in zip(utts, target_transcripts, result_transcripts): + for utt, target, result in zip(utts, target_transcripts, + result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index 1cf3827d3..bd5f630d2 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -347,6 +347,5 @@ class ManifestDataset(Dataset): def __getitem__(self, idx): instance = self._manifest[idx] - feat, text =self.process_utterance(instance["feat"], - instance["text"]) + feat, text = self.process_utterance(instance["feat"], instance["text"]) return instance["utt"], feat, text diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index bcfddaef0..238e2d35c 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -905,7 +905,6 @@ class U2InferModel(U2Model): def __init__(self, configs: dict): super().__init__(configs) - def forward(self, feats, feats_lengths,