add utt to train and test 0607

pull/657/head
Haoxin Ma 4 years ago
parent c8368410e2
commit f3c9f32c9a

@ -43,7 +43,8 @@ class DeepSpeech2Trainer(Trainer):
def train_batch(self, batch_index, batch_data, msg): def train_batch(self, batch_index, batch_data, msg):
start = time.time() start = time.time()
loss = self.model(*batch_data) utt, audio, audio_len, text, text_len = batch_data
loss = self.model(audio, audio_len, text, text_len)
loss.backward() loss.backward()
layer_tools.print_grads(self.model, print_func=None) layer_tools.print_grads(self.model, print_func=None)
self.optimizer.step() self.optimizer.step()
@ -73,7 +74,8 @@ class DeepSpeech2Trainer(Trainer):
num_seen_utts = 1 num_seen_utts = 1
total_loss = 0.0 total_loss = 0.0
for i, batch in enumerate(self.valid_loader): for i, batch in enumerate(self.valid_loader):
loss = self.model(*batch) utt, audio, audio_len, text, text_len = batch
loss = self.model(audio, audio_len, text, text_len)
if paddle.isfinite(loss): if paddle.isfinite(loss):
num_utts = batch[1].shape[0] num_utts = batch[1].shape[0]
num_seen_utts += num_utts num_seen_utts += num_utts
@ -191,7 +193,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
trans.append(''.join([chr(i) for i in ids])) trans.append(''.join([chr(i) for i in ids]))
return trans return trans
def compute_metrics(self, utt, audio, audio_len, texts, texts_len): def compute_metrics(self, audio, audio_len, texts, texts_len):
cfg = self.config.decoding cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
@ -240,7 +242,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch) utt, audio, audio_len, texts, texts_len = batch
metrics = self.compute_metrics(audio, audio_len, texts, texts_len)
errors_sum += metrics['errors_sum'] errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs'] len_refs += metrics['len_refs']
num_ins += metrics['num_ins'] num_ins += metrics['num_ins']

@ -161,7 +161,7 @@ class DeepSpeech2Model(nn.Layer):
reduction=True, # sum reduction=True, # sum
batch_average=True) # sum / batch_size batch_average=True) # sum / batch_size
def forward(self, utt, audio, audio_len, text, text_len): def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss """Compute Model loss
Args: Args:

Loading…
Cancel
Save