fix wav2vec2 report loss bug

pull/2532/head
tianhao zhang 3 years ago
parent 49c0cf9e31
commit 86f65f0b8e

@@ -13,6 +13,7 @@
 # limitations under the License.
 """Contains wav2vec2 model."""
 import json
+import math
 import os
 import time
 from collections import defaultdict
@@ -46,25 +47,20 @@ logger = Log(__name__).getlog()
 class Wav2Vec2ASRTrainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
-        self.avg_train_loss = 0
+        self.avg_train_loss = 0.0

-    def update_average(self, batch_index, loss, avg_loss):
+    def update_average(self, batch_index, loss):
         """Update running average of the loss.
         Arguments
         ---------
+        batch_index : int
+            current batch index
         loss : paddle.tensor
             detached loss, a single float value.
-        avg_loss : float
-            current running average.
-        Returns
-        -------
-        avg_loss : float
-            The average loss.
         """
-        if paddle.isfinite(loss):
-            avg_loss -= avg_loss / (batch_index + 1)
-            avg_loss += float(loss) / (batch_index + 1)
-        return avg_loss
+        if math.isfinite(loss):
+            self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
+            self.avg_train_loss += loss / (batch_index + 1)

     def train_batch(self, batch_index, batch, msg):
         train_conf = self.config
@@ -80,8 +76,8 @@ class Wav2Vec2ASRTrainer(Trainer):
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
-        self.avg_train_loss = self.update_average(batch_index, loss,
-                                                  self.avg_train_loss)
+        # update self.avg_train_loss
+        self.update_average(batch_index, float(loss))

         # loss backward
         if (batch_index + 1) % train_conf.accum_grad != 0:
@@ -106,7 +102,7 @@ class Wav2Vec2ASRTrainer(Trainer):
             self.lr_scheduler.step()
             self.iteration += 1

-        losses_np = {'loss': float(self.avg_train_loss) * train_conf.accum_grad}
+        losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad}
         iteration_time = time.time() - start
         for k, v in losses_np.items():
             report(k, v)

Loading…
Cancel
Save