Merge pull request #2532 from Zth9730/wav2vec2.0

[s2t] fix wav2vec2 report loss bug
Hui Zhang committed 2 years ago via GitHub
commit 964c22c677

@@ -13,6 +13,7 @@
 # limitations under the License.
 """Contains wav2vec2 model."""
 import json
+import math
 import os
 import time
 from collections import defaultdict
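Note: the new `import math` exists so the finiteness check below can run on a plain Python float rather than on a tensor. A minimal sketch of the difference (assumes a working paddle install; the value is illustrative):

import math

import paddle

loss = paddle.to_tensor(float("nan"))
print(paddle.isfinite(loss))       # 0-dim bool Tensor, stays on the device
print(math.isfinite(float(loss)))  # plain Python False, after pulling the value to host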
@@ -46,25 +47,20 @@ logger = Log(__name__).getlog()
 class Wav2Vec2ASRTrainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
-        self.avg_train_loss = 0
+        self.avg_train_loss = 0.0

-    def update_average(self, batch_index, loss, avg_loss):
+    def update_average(self, batch_index, loss):
         """Update running average of the loss.
         Arguments
         ---------
         batch_index : int
             current batch index
         loss : paddle.tensor
             detached loss, a single float value.
-        avg_loss : float
-            current running average.
-        Returns
-        -------
-        avg_loss : float
-            The average loss.
         """
-        if paddle.isfinite(loss):
-            avg_loss -= avg_loss / (batch_index + 1)
-            avg_loss += float(loss) / (batch_index + 1)
-        return avg_loss
+        if math.isfinite(loss):
+            self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
+            self.avg_train_loss += loss / (batch_index + 1)

     def train_batch(self, batch_index, batch, msg):
         train_conf = self.config
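The two in-place updates are the standard incremental-mean recurrence: with batch_index counting from 0, avg_n = avg_{n-1} + (loss_n - avg_{n-1}) / (n + 1). A standalone sketch (the helper below only mirrors the trainer's recurrence, it is not the trainer code) showing it reproduces the plain arithmetic mean:

import math

def running_average(avg, batch_index, loss):
    # drop a 1/(n+1) share of the old average, add a 1/(n+1) share of the new
    # loss; non-finite losses are skipped, leaving the average unchanged
    if math.isfinite(loss):
        avg -= avg / (batch_index + 1)
        avg += loss / (batch_index + 1)
    return avg

losses = [2.0, 4.0, 6.0]
avg = 0.0
for i, loss in enumerate(losses):
    avg = running_average(avg, i, loss)

print(avg)                        # 4.0
print(sum(losses) / len(losses))  # 4.0 -- identical to the plain mean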
@@ -80,8 +76,8 @@ class Wav2Vec2ASRTrainer(Trainer):
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
-        self.avg_train_loss = self.update_average(batch_index, loss,
-                                                  self.avg_train_loss)
+        # update self.avg_train_loss
+        self.update_average(batch_index, float(loss))

         # loss backward
         if (batch_index + 1) % train_conf.accum_grad != 0:
@@ -106,7 +102,7 @@ class Wav2Vec2ASRTrainer(Trainer):
             self.lr_scheduler.step()
             self.iteration += 1

-        losses_np = {'loss': float(self.avg_train_loss) * train_conf.accum_grad}
+        losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad}
         iteration_time = time.time() - start
         for k, v in losses_np.items():
             report(k, v)
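Because each batch loss is divided by `accum_grad` before it enters the running average, the reported value multiplies the average back by `accum_grad` so the log stays on a per-batch scale. A quick round-trip check with made-up numbers:

accum_grad = 4
raw_losses = [8.0, 6.0, 10.0, 12.0]   # hypothetical per-batch losses

# what the trainer accumulates: losses already divided by accum_grad
scaled = [loss / accum_grad for loss in raw_losses]
avg_scaled = sum(scaled) / len(scaled)

# what gets reported: the average rescaled back to per-batch units
print(avg_scaled * accum_grad)            # 9.0
print(sum(raw_losses) / len(raw_losses))  # 9.0 -- same scale as the raw losses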
