freeze wav2vec2=True, change loss report and update README.md

pull/2518/head
tianhao zhang 2 years ago
parent 3d994f5c23
commit 2ae94bd277

@ -88,6 +88,12 @@ data/
|-- test.meta |-- test.meta
`-- train.meta `-- train.meta
``` ```
Stage 0 also downloads the pre-trained [wav2vec2](https://paddlespeech.bj.bcebos.com/wav2vec/wav2vec2-large-960h-lv60-self.pdparams) model.
```bash
mkdir -p exp/wav2vec2
wget -P exp/wav2vec2 https://paddlespeech.bj.bcebos.com/wav2vec/wav2vec2-large-960h-lv60-self.pdparams
```
## Stage 1: Model Training ## Stage 1: Model Training
If you want to train the model, you can use stage 1 in `run.sh`. The code is shown below. If you want to train the model, you can use stage 1 in `run.sh`. The code is shown below.
```bash ```bash

@ -1,7 +1,7 @@
############################################ ############################################
# Network Architecture # # Network Architecture #
############################################ ############################################
freeze_wav2vec2: False freeze_wav2vec2: True
normalize_wav: True normalize_wav: True
output_norm: True output_norm: True
dnn_blocks: 2 dnn_blocks: 2

@ -48,6 +48,24 @@ class Wav2Vec2ASRTrainer(Trainer):
super().__init__(config, args) super().__init__(config, args)
self.avg_train_loss = 0 self.avg_train_loss = 0
def update_average(self, batch_index, loss, avg_loss):
    """Fold one batch loss into the running mean of the training loss.

    Arguments
    ---------
    batch_index : int
        Index of the current batch; ``batch_index + 1`` is used as the
        divisor of the incremental mean.
    loss : paddle.tensor
        Detached loss holding a single float value.
    avg_loss : float
        Running average accumulated so far.

    Returns
    -------
    avg_loss : float
        The updated running average, or the incoming value unchanged
        when ``loss`` is not finite.
    """
    # A NaN/Inf loss is skipped so it cannot poison the running average.
    if not paddle.isfinite(loss):
        return avg_loss

    seen = batch_index + 1
    # Incremental mean: drop 1/seen of the old average, then add 1/seen of
    # the new sample (same operation order as a two-step in-place update).
    decayed = avg_loss - avg_loss / seen
    return decayed + float(loss) / seen
def train_batch(self, batch_index, batch, msg): def train_batch(self, batch_index, batch, msg):
train_conf = self.config train_conf = self.config
start = time.time() start = time.time()
@ -59,11 +77,11 @@ class Wav2Vec2ASRTrainer(Trainer):
wav = wav[:, :, 0] wav = wav[:, :, 0]
wav = self.speech_augmentation(wav, wavs_lens_rate) wav = self.speech_augmentation(wav, wavs_lens_rate)
loss = self.model(wav, wavs_lens_rate, target, target_lens_rate) loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
# print(wav, wavs_lens_rate, target, target_lens_rate)
# loss div by `batch_size * accum_grad` # loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad loss /= train_conf.accum_grad
losses_np = {'loss': float(loss) * train_conf.accum_grad} self.avg_train_loss = self.update_average(batch_index, loss,
self.avg_train_loss)
# loss backward # loss backward
if (batch_index + 1) % train_conf.accum_grad != 0: if (batch_index + 1) % train_conf.accum_grad != 0:
@ -87,6 +105,8 @@ class Wav2Vec2ASRTrainer(Trainer):
self.optimizer.clear_grad() self.optimizer.clear_grad()
self.lr_scheduler.step() self.lr_scheduler.step()
self.iteration += 1 self.iteration += 1
losses_np = {'loss': float(self.avg_train_loss) * train_conf.accum_grad}
iteration_time = time.time() - start iteration_time = time.time() - start
for k, v in losses_np.items(): for k, v in losses_np.items():
report(k, v) report(k, v)

Loading…
Cancel
Save