Merge pull request #2532 from Zth9730/wav2vec2.0

[s2t] fix wav2vec2 report loss bug
Hui Zhang authored 3 years ago; committed by GitHub
commit 964c22c677

@@ -13,6 +13,7 @@
 # limitations under the License.
 """Contains wav2vec2 model."""
 import json
+import math
 import os
 import time
 from collections import defaultdict
@@ -46,25 +47,20 @@ logger = Log(__name__).getlog()
 class Wav2Vec2ASRTrainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
-        self.avg_train_loss = 0
+        self.avg_train_loss = 0.0
 
-    def update_average(self, batch_index, loss, avg_loss):
+    def update_average(self, batch_index, loss):
         """Update running average of the loss.
         Arguments
         ---------
+        batch_index : int
+            current batch index
         loss : paddle.tensor
             detached loss, a single float value.
-        avg_loss : float
-            current running average.
-        Returns
-        -------
-        avg_loss : float
-            The average loss.
         """
-        if paddle.isfinite(loss):
-            avg_loss -= avg_loss / (batch_index + 1)
-            avg_loss += float(loss) / (batch_index + 1)
-        return avg_loss
+        if math.isfinite(loss):
+            self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
+            self.avg_train_loss += loss / (batch_index + 1)
 
     def train_batch(self, batch_index, batch, msg):
         train_conf = self.config
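
Note on the hunk above: update_average previously took and returned a free-floating avg_loss and guarded with paddle.isfinite on a tensor; the patch keeps the running average on self.avg_train_loss and expects a detached Python float, guarded with math.isfinite. The quantity it maintains is the standard incremental mean, avg_i = avg_{i-1} + (x_i - avg_{i-1}) / (i + 1). A minimal standalone sketch of that recurrence (running_average is a hypothetical helper, not part of the patch):

import math

def running_average(values):
    # avg -= avg / (i + 1); avg += x / (i + 1)
    # is equivalent to avg += (x - avg) / (i + 1), the incremental mean.
    # Non-finite values are skipped, but the denominator still advances,
    # mirroring how update_average is keyed to batch_index.
    avg = 0.0
    for i, x in enumerate(values):
        if math.isfinite(x):
            avg -= avg / (i + 1)
            avg += x / (i + 1)
    return avg

assert abs(running_average([1.0, 2.0, 3.0]) - 2.0) < 1e-12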
@@ -80,8 +76,8 @@ class Wav2Vec2ASRTrainer(Trainer):
 
             # loss div by `batch_size * accum_grad`
             loss /= train_conf.accum_grad
-            self.avg_train_loss = self.update_average(batch_index, loss,
-                                                      self.avg_train_loss)
+            # update self.avg_train_loss
+            self.update_average(batch_index, float(loss))
 
         # loss backward
         if (batch_index + 1) % train_conf.accum_grad != 0:
@@ -106,7 +102,7 @@ class Wav2Vec2ASRTrainer(Trainer):
                 self.lr_scheduler.step()
             self.iteration += 1
 
-        losses_np = {'loss': float(self.avg_train_loss) * train_conf.accum_grad}
+        losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad}
         iteration_time = time.time() - start
         for k, v in losses_np.items():
             report(k, v)
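
Taken together, the last two hunks keep the reported number on the original per-batch loss scale: train_batch divides each loss by accum_grad before backward (for gradient accumulation), so the running average is multiplied by accum_grad again at reporting time. A minimal sketch of that round trip, assuming a hypothetical LossReporter class in place of the real trainer (not the PaddleSpeech API):

import math

class LossReporter:
    def __init__(self, accum_grad):
        self.accum_grad = accum_grad
        self.avg_train_loss = 0.0

    def update_average(self, batch_index, loss):
        # Same recurrence as the patched update_average.
        if math.isfinite(loss):
            self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
            self.avg_train_loss += loss / (batch_index + 1)

    def train_batch(self, batch_index, raw_loss):
        scaled = raw_loss / self.accum_grad      # loss div by accum_grad
        self.update_average(batch_index, float(scaled))
        # Multiply back so the log shows the per-batch scale.
        return {'loss': self.avg_train_loss * self.accum_grad}

reporter = LossReporter(accum_grad=2)
print(reporter.train_batch(0, 4.0))  # {'loss': 4.0}
print(reporter.train_batch(1, 2.0))  # {'loss': 3.0}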
