|
|
|
@ -13,6 +13,7 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
"""Contains wav2vec2 model."""
|
|
|
|
|
import json
|
|
|
|
|
import math
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
from collections import defaultdict
|
|
|
|
@ -46,25 +47,20 @@ logger = Log(__name__).getlog()
|
|
|
|
|
class Wav2Vec2ASRTrainer(Trainer):
|
|
|
|
|
def __init__(self, config, args):
|
|
|
|
|
super().__init__(config, args)
|
|
|
|
|
self.avg_train_loss = 0
|
|
|
|
|
self.avg_train_loss = 0.0
|
|
|
|
|
|
|
|
|
|
def update_average(self, batch_index, loss, avg_loss):
|
|
|
|
|
def update_average(self, batch_index, loss):
|
|
|
|
|
"""Update running average of the loss.
|
|
|
|
|
Arguments
|
|
|
|
|
---------
|
|
|
|
|
batch_index : int
|
|
|
|
|
current batch index
|
|
|
|
|
loss : paddle.tensor
|
|
|
|
|
detached loss, a single float value.
|
|
|
|
|
avg_loss : float
|
|
|
|
|
current running average.
|
|
|
|
|
Returns
|
|
|
|
|
-------
|
|
|
|
|
avg_loss : float
|
|
|
|
|
The average loss.
|
|
|
|
|
"""
|
|
|
|
|
if paddle.isfinite(loss):
|
|
|
|
|
avg_loss -= avg_loss / (batch_index + 1)
|
|
|
|
|
avg_loss += float(loss) / (batch_index + 1)
|
|
|
|
|
return avg_loss
|
|
|
|
|
if math.isfinite(loss):
|
|
|
|
|
self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
|
|
|
|
|
self.avg_train_loss += loss / (batch_index + 1)
|
|
|
|
|
|
|
|
|
|
def train_batch(self, batch_index, batch, msg):
|
|
|
|
|
train_conf = self.config
|
|
|
|
@ -80,8 +76,8 @@ class Wav2Vec2ASRTrainer(Trainer):
|
|
|
|
|
# loss div by `batch_size * accum_grad`
|
|
|
|
|
loss /= train_conf.accum_grad
|
|
|
|
|
|
|
|
|
|
self.avg_train_loss = self.update_average(batch_index, loss,
|
|
|
|
|
self.avg_train_loss)
|
|
|
|
|
# update self.avg_train_loss
|
|
|
|
|
self.update_average(batch_index, float(loss))
|
|
|
|
|
|
|
|
|
|
# loss backward
|
|
|
|
|
if (batch_index + 1) % train_conf.accum_grad != 0:
|
|
|
|
@ -106,7 +102,7 @@ class Wav2Vec2ASRTrainer(Trainer):
|
|
|
|
|
self.lr_scheduler.step()
|
|
|
|
|
self.iteration += 1
|
|
|
|
|
|
|
|
|
|
losses_np = {'loss': float(self.avg_train_loss) * train_conf.accum_grad}
|
|
|
|
|
losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad}
|
|
|
|
|
iteration_time = time.time() - start
|
|
|
|
|
for k, v in losses_np.items():
|
|
|
|
|
report(k, v)
|
|
|
|
|