diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 05b55f75..ce8b56ac 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -43,7 +43,8 @@ class DeepSpeech2Trainer(Trainer): def train_batch(self, batch_index, batch_data, msg): start = time.time() - loss = self.model(*batch_data) + utt, audio, audio_len, text, text_len = batch_data + loss = self.model(audio, audio_len, text, text_len) loss.backward() layer_tools.print_grads(self.model, print_func=None) self.optimizer.step() @@ -73,7 +74,8 @@ class DeepSpeech2Trainer(Trainer): num_seen_utts = 1 total_loss = 0.0 for i, batch in enumerate(self.valid_loader): - loss = self.model(*batch) + utt, audio, audio_len, text, text_len = batch + loss = self.model(audio, audio_len, text, text_len) if paddle.isfinite(loss): num_utts = batch[1].shape[0] num_seen_utts += num_utts @@ -191,7 +193,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, utt, audio, audio_len, texts, texts_len): + def compute_metrics(self, audio, audio_len, texts, texts_len): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors @@ -240,7 +242,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): errors_sum, len_refs, num_ins = 0.0, 0, 0 for i, batch in enumerate(self.test_loader): - metrics = self.compute_metrics(*batch) + utt, audio, audio_len, texts, texts_len = batch + metrics = self.compute_metrics(audio, audio_len, texts, texts_len) errors_sum += metrics['errors_sum'] len_refs += metrics['len_refs'] num_ins += metrics['num_ins'] diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py index ab617a53..0ff5514d 100644 --- a/deepspeech/models/deepspeech2.py +++ b/deepspeech/models/deepspeech2.py @@ -161,7 +161,7 @@ class DeepSpeech2Model(nn.Layer): reduction=True, # sum batch_average=True) # sum / batch_size - def forward(self, utt, audio, audio_len, text, text_len): + def forward(self, audio, audio_len, text, text_len): """Compute Model loss Args: