Merge pull request #3167 from zxcd/amp

[ASR] add amp for U2 conformer
Hui Zhang committed via GitHub, 1 year ago
commit 8371d14f5d

@@ -178,6 +178,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
 - 🧩 *Cascaded models application*: as an extension of typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields such as natural language processing (NLP) and computer vision (CV).
 ### Recent Update
+- 👑 2023.04.25: Add [AMP for U2 conformer](https://github.com/PaddlePaddle/PaddleSpeech/pull/3167) (a config sketch follows this list).
 - 🔥 2023.04.06: Add [subtitle file (.srt format) generation example](./demos/streaming_asr_server).
 - 🔥 2023.03.14: Add SVS (Singing Voice Synthesis) examples with the Opencpop dataset, including [DiffSinger](./examples/opencpop/svs1), [PWGAN](./examples/opencpop/voc1) and [HiFiGAN](./examples/opencpop/voc5); the results are being continuously improved.
 - 👑 2023.03.09: Add [Wav2vec2ASR-zh](./examples/aishell/asr3).
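For orientation, the AMP switches this PR introduces are plain config keys; below is a minimal sketch of how they might appear in a U2 training config. The key names and defaults are taken from the `config.get(...)` calls in the `U2Trainer.setup_model` hunk further down; the surrounding dict layout is an assumption.

```python
# Hypothetical AMP section of a U2 training config; key names and defaults
# mirror the config.get(...) calls in U2Trainer.setup_model below.
amp_config = {
    "use_amp": True,        # enable mixed-precision training
    "amp_level": "O1",      # "O1": autocast selected ops; "O2": also decorate the model
    "scale_loss": 32768.0,  # initial loss scaling for paddle.amp.GradScaler
}
```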

@@ -183,6 +183,7 @@
 - 🧩 Cascaded models application: as an extension of traditional speech tasks, we combine natural language processing, computer vision and other tasks to build industrial-grade applications closer to real-world needs.
 ### Recent Update
+- 👑 2023.04.25: Add [AMP training for U2 conformer](https://github.com/PaddlePaddle/PaddleSpeech/pull/3167).
 - 👑 2023.04.06: Add [subtitle file (.srt format) generation](./demos/streaming_asr_server).
 - 🔥 2023.03.14: Add SVS (Singing Voice Synthesis) examples with the Opencpop dataset, including [DiffSinger](./examples/opencpop/svs1), [PWGAN](./examples/opencpop/voc1) and [HiFiGAN](./examples/opencpop/voc5); the results are being continuously improved.
 - 👑 2023.03.09: Add [Wav2vec2ASR-zh](./examples/aishell/asr3).

@@ -10,7 +10,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
 [Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) | inference/python |-|
 [Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |- | python |-|
 [Conformer U2PP Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.3.0.model.tar.gz) | WenetSpeech Dataset | Char-based | 540 MB | Encoder:Conformer, Decoder:BiTransformer, Decoding method: Attention rescoring| 0.047198 (aishell test\_-1) 0.059212 (aishell test\_16) |-| 10000 h |- | python |[FP32](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.3.0.model.tar.gz) </br>[INT8](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_quant_1.3.0.model.tar.gz) |
-[Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python |-|
+[Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.5.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.051968 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python |-|
 [Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.0.1.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0460 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) | python |-|
 [Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1) | python |-|
 [Ds2 Offline Librispeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_offline_librispeech_ckpt_1.0.1.model.tar.gz)| Librispeech Dataset | Char-based | 1.3 GB | 2 Conv + 5 bidirectional LSTM layers| - |0.0467| 960 h | [Ds2 Offline Librispeech ASR0](../../examples/librispeech/asr0) | inference/python |-|

@@ -13,15 +13,15 @@ paddlespeech version: 1.0.1
 ## Conformer Streaming
 paddle version: 2.2.2
-paddlespeech version: 0.2.0
+paddlespeech version: 1.4.1
 Need to set `decoding.decoding_chunk_size=16` when decoding.
 | Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | CER |
 | --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention | 16, -1 | - | 0.0551 |
-| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | ctc_greedy_search | 16, -1 | - | 0.0629 |
-| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | ctc_prefix_beam_search | 16, -1 | - | 0.0629 |
-| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 0.0544 |
+| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention | 16, -1 | - | 0.056102 |
+| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | ctc_greedy_search | 16, -1 | - | 0.058160 |
+| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | ctc_prefix_beam_search | 16, -1 | - | 0.058160 |
+| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 0.051968 |
 ## Transformer
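As a hedged illustration of the decode-time settings behind the "Chunk Size & Left Chunks" column and the `decoding.decoding_chunk_size=16` note above; `num_decoding_left_chunks` is an assumed key name following U2 decode-config conventions:

```python
# Sketch of the decoding options implied by "16, -1" in the table above.
decoding = {
    "decoding_chunk_size": 16,       # decode in 16-frame streaming chunks
    "num_decoding_left_chunks": -1,  # -1: attend to all previous chunks
}
```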

@@ -228,6 +228,16 @@ asr_dynamic_pretrained_models = {
             'ckpt_path':
             'exp/chunk_conformer/checkpoints/avg_30',
         },
+        '1.4': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.5.0.model.tar.gz',
+            'md5':
+            'a0adb2b204902982718bc1d8917f7038',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'exp/chunk_conformer/checkpoints/avg_30',
+        },
     },
     "transformer_librispeech-en-16k": {
         '1.0': {
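A hedged sketch of how an entry in `asr_dynamic_pretrained_models` is consumed through the Python API; the model tag is an assumed match for the aishell entry above, and `zh.wav` is a placeholder input file:

```python
from paddlespeech.cli.asr.infer import ASRExecutor

asr = ASRExecutor()
# The executor resolves a model tag to an entry like the '1.4' block above
# (url/md5/cfg_path/ckpt_path), downloads the archive, and runs inference.
result = asr(
    model="conformer_online_aishell",  # assumed tag for this dict entry
    lang="zh",
    sample_rate=16000,
    audio_file="zh.wav")  # placeholder audio file
print(result)
```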

@@ -23,6 +23,7 @@ import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
+from paddle.nn.utils import clip_grad_norm_
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import DataLoaderFactory
@@ -47,14 +48,16 @@ class U2Trainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)

-    def train_batch(self, batch_index, batch_data, msg):
+    def train_batch(self, batch_index, batch_data, scaler, msg):
         train_conf = self.config
         start = time.time()

         # forward
         utt, audio, audio_len, text, text_len = batch_data
-        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
-                                                    text_len)
+        with paddle.amp.auto_cast(
+                level=self.amp_level, enable=True if scaler else False):
+            loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
+                                                        text_len)

         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
@@ -77,12 +80,26 @@ class U2Trainer(Trainer):
         # processes.
         context = nullcontext

         with context():
-            loss.backward()
+            if scaler:
+                scaler.scale(loss).backward()
+            else:
+                loss.backward()
             layer_tools.print_grads(self.model, print_func=None)

         # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
-            self.optimizer.step()
+            # do global grad clip
+            if train_conf.global_grad_clip != 0:
+                if scaler:
+                    scaler.unscale_(self.optimizer)
+                # need paddlepaddle==develop or paddlepaddle>=2.5
+                clip_grad_norm_(self.model.parameters(),
+                                train_conf.global_grad_clip)
+            if scaler:
+                scaler.step(self.optimizer)
+                scaler.update()
+            else:
+                self.optimizer.step()
             self.optimizer.clear_grad()
             self.lr_scheduler.step()
             self.iteration += 1
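Condensed, the hunk above implements the standard PaddlePaddle dynamic-graph AMP step; a minimal sketch, assuming `model` and `optimizer` already exist and leaving out gradient accumulation and logging:

```python
import paddle
from paddle.nn.utils import clip_grad_norm_  # needs paddlepaddle>=2.5

scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)

def amp_train_step(model, optimizer, batch, max_norm=5.0):
    # Forward under autocast so eligible ops run in fp16.
    with paddle.amp.auto_cast(level='O1'):
        loss = model(*batch)
    # Scale the loss so small fp16 gradients do not underflow to zero.
    scaler.scale(loss).backward()
    # Return gradients to their true scale before clipping them.
    scaler.unscale_(optimizer)
    clip_grad_norm_(model.parameters(), max_norm)
    # step() skips the update on overflow; update() re-tunes the loss scale.
    scaler.step(optimizer)
    scaler.update()
    optimizer.clear_grad()
```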
@@ -173,7 +190,8 @@ class U2Trainer(Trainer):
                     report("epoch", self.epoch)
                     report('step', self.iteration)
                     report("lr", self.lr_scheduler())
-                    self.train_batch(batch_index, batch, msg)
+                    self.train_batch(batch_index, batch, self.scaler,
+                                     msg)
                     self.after_train_batch()
                     report('iter', batch_index + 1)
                     if not self.use_streamdata:
@@ -253,6 +271,19 @@ class U2Trainer(Trainer):
                 model_conf.output_dim = self.test_loader.vocab_size

         model = U2Model.from_config(model_conf)

+        # For Mixed Precision Training
+        self.use_amp = self.config.get("use_amp", True)
+        self.amp_level = self.config.get("amp_level", "O1")
+        if self.train and self.use_amp:
+            self.scaler = paddle.amp.GradScaler(
+                init_loss_scaling=self.config.get(
+                    "scale_loss", 32768.0))  # AMP default loss scaling: 32768.0
+            # Set amp_level
+            if self.amp_level == 'O2':
+                model = paddle.amp.decorate(models=model, level=self.amp_level)
+        else:
+            self.scaler = None
+
         if self.parallel:
             model = paddle.DataParallel(model)
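For the `O2` branch above: `paddle.amp.decorate` casts the layer's parameters to fp16, which is why O2 is typically paired with a GradScaler and fp32 master weights for the update. A minimal sketch with a stand-in `Linear` layer:

```python
import paddle

model = paddle.nn.Linear(80, 256)  # stand-in for the U2 model
# O2 ("pure fp16") casts parameters to float16; fp32 master weights are
# typically kept for the optimizer update.
model = paddle.amp.decorate(models=model, level='O2')
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
```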
@@ -290,7 +321,6 @@ class U2Trainer(Trainer):
             scheduler_type = train_config.scheduler
             scheduler_conf = train_config.scheduler_conf
             return {
-                "grad_clip": train_config.global_grad_clip,
                 "weight_decay": optim_conf.weight_decay,
                 "learning_rate": lr_scheduler
                 if lr_scheduler else optim_conf.lr,
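Why `grad_clip` leaves the optimizer config here: an optimizer-attached clip would act on still-scaled gradients under AMP, so the PR clips explicitly in `train_batch` after `scaler.unscale_`. A before/after sketch; the threshold 5.0 and the `Linear` layer are illustrative:

```python
import paddle

model = paddle.nn.Linear(80, 256)  # stand-in layer

# Before: clipping attached to the optimizer, applied at step() time.
optimizer = paddle.optimizer.Adam(
    parameters=model.parameters(),
    grad_clip=paddle.nn.ClipGradByGlobalNorm(5.0))

# After: no grad_clip kwarg; the trainer unscales first and then calls
# paddle.nn.utils.clip_grad_norm_(model.parameters(), 5.0) itself.
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
```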

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import sys
 import time
 from collections import OrderedDict
@@ -110,6 +111,7 @@ class Trainer():
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
         self._train = True
+        self.scaler = None

         # print deps version
         all_version()
@ -187,8 +189,13 @@ class Trainer():
infos.update({
"step": self.iteration,
"epoch": self.epoch,
"lr": self.optimizer.get_lr()
"lr": self.optimizer.get_lr(),
})
if self.scaler:
scaler_path = os.path.join(self.checkpoint_dir,
"{}".format(self.epoch)) + '.scaler'
paddle.save(self.scaler.state_dict(), scaler_path)
self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model,
self.optimizer, infos)
@@ -211,6 +218,13 @@ class Trainer():
             # lr will restore from optimizer ckpt
             self.iteration = infos["step"]
             self.epoch = infos["epoch"]
+            scaler_path = os.path.join(self.checkpoint_dir,
+                                       "{}".format(self.epoch)) + '.scaler'
+            if os.path.exists(scaler_path):
+                scaler_state_dict = paddle.load(scaler_path)
+                self.scaler.load_state_dict(scaler_state_dict)
             scratch = False
             logger.info(
                 f"Restore ckpt: epoch {self.epoch}, step {self.iteration}!")
