From 6f651d762ef3ca25529878fa60281c6fca178662 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 5 Jan 2022 09:49:21 +0000 Subject: [PATCH 1/3] fix batch sampler set_epoch when epcoh start --- paddlespeech/s2t/exps/u2/model.py | 8 ++++++-- paddlespeech/s2t/io/dataloader.py | 13 ++++++++----- paddlespeech/s2t/modules/ctc.py | 3 --- paddlespeech/s2t/training/scheduler.py | 9 +++++---- paddlespeech/s2t/training/trainer.py | 2 +- 5 files changed, 20 insertions(+), 15 deletions(-) diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index d0cea0316..db50a3615 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -240,7 +240,9 @@ class U2Trainer(Trainer): preprocess_conf=config.preprocess_config, n_iter_processes=config.num_workers, subsampling_factor=1, - num_encs=1) + num_encs=1, + dist_sampler=True, + shortest_first=False) self.valid_loader = BatchDataLoader( json_file=config.dev_manifest, @@ -259,7 +261,9 @@ class U2Trainer(Trainer): preprocess_conf=config.preprocess_config, n_iter_processes=config.num_workers, subsampling_factor=1, - num_encs=1) + num_encs=1, + dist_sampler=True, + shortest_first=False) logger.info("Setup train/valid Dataloader!") else: decode_batch_size = config.get('decode', dict()).get( diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index 455303f70..920de34fc 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -78,7 +78,8 @@ class BatchDataLoader(): load_aux_input: bool=False, load_aux_output: bool=False, num_encs: int=1, - dist_sampler: bool=False): + dist_sampler: bool=False, + shortest_first: bool=False): self.json_file = json_file self.train_mode = train_mode self.use_sortagrad = sortagrad == -1 or sortagrad > 0 @@ -97,6 +98,7 @@ class BatchDataLoader(): self.load_aux_input = load_aux_input self.load_aux_output = load_aux_output self.dist_sampler = dist_sampler + self.shortest_first = shortest_first # read json data with jsonlines.open(json_file, 'r') as reader: @@ -113,7 +115,7 @@ class BatchDataLoader(): maxlen_out, minibatches, # for debug min_batch_size=mini_batch_size, - shortest_first=self.use_sortagrad, + shortest_first=self.shortest_first or self.use_sortagrad, count=batch_count, batch_bins=batch_bins, batch_frames_in=batch_frames_in, @@ -149,13 +151,13 @@ class BatchDataLoader(): self.reader) if self.dist_sampler: - self.sampler = DistributedBatchSampler( + self.batch_sampler = DistributedBatchSampler( dataset=self.dataset, batch_size=1, shuffle=not self.use_sortagrad if self.train_mode else False, drop_last=False, ) else: - self.sampler = BatchSampler( + self.batch_sampler = BatchSampler( dataset=self.dataset, batch_size=1, shuffle=not self.use_sortagrad if self.train_mode else False, @@ -163,7 +165,7 @@ class BatchDataLoader(): self.dataloader = DataLoader( dataset=self.dataset, - batch_sampler=self.sampler, + batch_sampler=self.batch_sampler, collate_fn=batch_collate, num_workers=self.n_iter_processes, ) @@ -194,5 +196,6 @@ class BatchDataLoader(): echo += f"load_aux_input: {self.load_aux_input}, " echo += f"load_aux_output: {self.load_aux_output}, " echo += f"dist_sampler: {self.dist_sampler}, " + echo += f"shortest_first: {self.shortest_first}, " echo += f"file: {self.json_file}" return echo diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index ffc9f0387..4a2e4f24f 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -39,9 +39,6 @@ except ImportError: 
except Exception as e: logger.info("paddlespeech_ctcdecoders not installed!") -#try: -#except Exception as e: -# logger.info("ctcdecoder not installed!") __all__ = ['CTCDecoder'] diff --git a/paddlespeech/s2t/training/scheduler.py b/paddlespeech/s2t/training/scheduler.py index 0222246e8..b22f7ef85 100644 --- a/paddlespeech/s2t/training/scheduler.py +++ b/paddlespeech/s2t/training/scheduler.py @@ -67,18 +67,19 @@ class WarmupLR(LRScheduler): super().__init__(learning_rate, last_epoch, verbose) def __repr__(self): - return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})" + return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps}, lr={self.base_lr}, last_epoch={self.last_epoch})" def get_lr(self): + # self.last_epoch start from zero step_num = self.last_epoch + 1 return self.base_lr * self.warmup_steps**0.5 * min( step_num**-0.5, step_num * self.warmup_steps**-1.5) def set_step(self, step: int=None): ''' - It will update the learning rate in optimizer according to current ``epoch`` . + It will update the learning rate in optimizer according to current ``epoch`` . The new learning rate will take effect on next ``optimizer.step`` . - + Args: step (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1. Returns: @@ -94,7 +95,7 @@ class ConstantLR(LRScheduler): learning_rate (float): The initial learning rate. It is a python float number. last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . - + Returns: ``ConstantLR`` instance to schedule learning rate. """ diff --git a/paddlespeech/s2t/training/trainer.py b/paddlespeech/s2t/training/trainer.py index 4b2011eca..cac5e5704 100644 --- a/paddlespeech/s2t/training/trainer.py +++ b/paddlespeech/s2t/training/trainer.py @@ -222,7 +222,7 @@ class Trainer(): batch_sampler = self.train_loader.batch_sampler if isinstance(batch_sampler, paddle.io.DistributedBatchSampler): logger.debug( - f"train_loader.batch_sample set epoch: {self.epoch}") + f"train_loader.batch_sample.set_epoch: {self.epoch}") batch_sampler.set_epoch(self.epoch) def before_train(self): From 3a2db414e67a024292898fba1b5bfab63aecc37f Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 5 Jan 2022 09:51:13 +0000 Subject: [PATCH 2/3] format code --- .../s2t/exps/deepspeech2/bin/deploy/runtime.py | 2 +- .../s2t/exps/deepspeech2/bin/deploy/server.py | 2 +- paddlespeech/s2t/exps/deepspeech2/bin/export.py | 1 + .../s2t/exps/deepspeech2/bin/test_export.py | 1 + paddlespeech/s2t/exps/deepspeech2/bin/train.py | 2 +- paddlespeech/s2t/exps/u2/bin/export.py | 1 + paddlespeech/s2t/exps/u2/bin/train.py | 2 +- paddlespeech/s2t/exps/u2_kaldi/model.py | 1 + paddlespeech/s2t/exps/u2_st/bin/export.py | 1 + paddlespeech/s2t/exps/u2_st/model.py | 15 +++++++++------ paddlespeech/s2t/models/ds2/deepspeech2.py | 1 + paddlespeech/s2t/models/ds2_online/deepspeech2.py | 1 + paddlespeech/s2t/models/u2/u2.py | 1 + paddlespeech/s2t/models/u2_st/u2_st.py | 1 + paddlespeech/s2t/modules/ctc.py | 1 - 15 files changed, 22 insertions(+), 11 deletions(-) diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py index ccb85906a..5755a5f10 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py @@ -19,8 +19,8 @@ import paddle from paddle.inference import Config 
from paddle.inference import create_predictor from paddle.io import DataLoader - from yacs.config import CfgNode + from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.dataset import ManifestDataset from paddlespeech.s2t.models.ds2 import DeepSpeech2Model diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py index 85c2466f5..0d0b4f219 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py @@ -17,8 +17,8 @@ import functools import numpy as np import paddle from paddle.io import DataLoader - from yacs.config import CfgNode + from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.dataset import ManifestDataset from paddlespeech.s2t.models.ds2 import DeepSpeech2Model diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/export.py b/paddlespeech/s2t/exps/deepspeech2/bin/export.py index 090b5fabf..ee013d79e 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/export.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/export.py @@ -13,6 +13,7 @@ # limitations under the License. """Export for DeepSpeech2 model.""" from yacs.config import CfgNode + from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Tester as Tester from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.utility import print_arguments diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py index 176028ed8..707eb9e1b 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py @@ -13,6 +13,7 @@ # limitations under the License. """Evaluation for DeepSpeech2 model.""" from yacs.config import CfgNode + from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2ExportTester as ExportTester from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.utility import print_arguments diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/train.py b/paddlespeech/s2t/exps/deepspeech2/bin/train.py index 5e8c0fffe..09e8662f1 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/train.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/train.py @@ -13,8 +13,8 @@ # limitations under the License. """Trainer for DeepSpeech2 model.""" from paddle import distributed as dist - from yacs.config import CfgNode + from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Trainer as Trainer from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.utility import print_arguments diff --git a/paddlespeech/s2t/exps/u2/bin/export.py b/paddlespeech/s2t/exps/u2/bin/export.py index 3907cebdd..592b12379 100644 --- a/paddlespeech/s2t/exps/u2/bin/export.py +++ b/paddlespeech/s2t/exps/u2/bin/export.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Export for U2 model.""" from yacs.config import CfgNode + from paddlespeech.s2t.exps.u2.model import U2Tester as Tester from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.utility import print_arguments diff --git a/paddlespeech/s2t/exps/u2/bin/train.py b/paddlespeech/s2t/exps/u2/bin/train.py index d562278f5..53c223283 100644 --- a/paddlespeech/s2t/exps/u2/bin/train.py +++ b/paddlespeech/s2t/exps/u2/bin/train.py @@ -16,8 +16,8 @@ import cProfile import os from paddle import distributed as dist - from yacs.config import CfgNode + from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.utility import print_arguments diff --git a/paddlespeech/s2t/exps/u2_kaldi/model.py b/paddlespeech/s2t/exps/u2_kaldi/model.py index 780c5c081..d7a9f460b 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/model.py +++ b/paddlespeech/s2t/exps/u2_kaldi/model.py @@ -42,6 +42,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig logger = Log(__name__).getlog() + class U2Trainer(Trainer): def __init__(self, config, args): super().__init__(config, args) diff --git a/paddlespeech/s2t/exps/u2_st/bin/export.py b/paddlespeech/s2t/exps/u2_st/bin/export.py index 1bc4e1f3c..c641152fe 100644 --- a/paddlespeech/s2t/exps/u2_st/bin/export.py +++ b/paddlespeech/s2t/exps/u2_st/bin/export.py @@ -13,6 +13,7 @@ # limitations under the License. """Export for U2 model.""" from yacs.config import CfgNode + from paddlespeech.s2t.exps.u2_st.model import U2STTester as Tester from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.utility import print_arguments diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index ca2c2c1da..ecb9a08ba 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -208,8 +208,7 @@ class U2STTrainer(Trainer): k.split(',')) == 2 else "" msg += "," msg = msg[:-1] # remove the last "," - if (batch_index + 1 - ) % self.config.log_interval == 0: + if (batch_index + 1) % self.config.log_interval == 0: logger.info(msg) except Exception as e: logger.error(e) @@ -260,7 +259,8 @@ class U2STTrainer(Trainer): batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0, - preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False + preprocess_conf=config. + preprocess_config, # aug will be off when train_mode=False n_iter_processes=config.num_workers, subsampling_factor=1, load_aux_output=load_transcript, @@ -281,7 +281,8 @@ class U2STTrainer(Trainer): batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0, - preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False + preprocess_conf=config. 
+ preprocess_config, # aug will be off when train_mode=False n_iter_processes=config.num_workers, subsampling_factor=1, load_aux_output=load_transcript, @@ -290,7 +291,8 @@ class U2STTrainer(Trainer): logger.info("Setup train/valid Dataloader!") else: # test dataset, return raw text - decode_batch_size = config.get('decode',dict()).get('decode_batch_size', 1) + decode_batch_size = config.get('decode', dict()).get( + 'decode_batch_size', 1) self.test_loader = BatchDataLoader( json_file=config.test_manifest, train_mode=False, @@ -305,7 +307,8 @@ class U2STTrainer(Trainer): batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0, - preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False + preprocess_conf=config. + preprocess_config, # aug will be off when train_mode=False n_iter_processes=config.num_workers, subsampling_factor=1, num_encs=1, diff --git a/paddlespeech/s2t/models/ds2/deepspeech2.py b/paddlespeech/s2t/models/ds2/deepspeech2.py index ddc3612d9..15cadd38e 100644 --- a/paddlespeech/s2t/models/ds2/deepspeech2.py +++ b/paddlespeech/s2t/models/ds2/deepspeech2.py @@ -119,6 +119,7 @@ class DeepSpeech2Model(nn.Layer): before softmax) and a ctc cost layer. :rtype: tuple of LayerOutput """ + def __init__(self, feat_size, dict_size, diff --git a/paddlespeech/s2t/models/ds2_online/deepspeech2.py b/paddlespeech/s2t/models/ds2_online/deepspeech2.py index aae77f748..6451118a8 100644 --- a/paddlespeech/s2t/models/ds2_online/deepspeech2.py +++ b/paddlespeech/s2t/models/ds2_online/deepspeech2.py @@ -243,6 +243,7 @@ class DeepSpeech2ModelOnline(nn.Layer): before softmax) and a ctc cost layer. :rtype: tuple of LayerOutput """ + def __init__( self, feat_size, diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 26e81acf6..dc3072c62 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -59,6 +59,7 @@ logger = Log(__name__).getlog() class U2BaseModel(ASRInterface, nn.Layer): """CTC-Attention hybrid Encoder-Decoder model""" + def __init__(self, vocab_size: int, encoder: TransformerEncoder, diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py index 1c5596bac..bcade95ee 100644 --- a/paddlespeech/s2t/models/u2_st/u2_st.py +++ b/paddlespeech/s2t/models/u2_st/u2_st.py @@ -51,6 +51,7 @@ logger = Log(__name__).getlog() class U2STBaseModel(nn.Layer): """CTC-Attention hybrid Encoder-Decoder model""" + def __init__(self, vocab_size: int, encoder: TransformerEncoder, diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index 4a2e4f24f..6e9655799 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -39,7 +39,6 @@ except ImportError: except Exception as e: logger.info("paddlespeech_ctcdecoders not installed!") - __all__ = ['CTCDecoder'] From 45832f6770ea54b87606e0f663b426af69caa05e Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 5 Jan 2022 09:56:18 +0000 Subject: [PATCH 3/3] fix default dist_samlper to False --- paddlespeech/s2t/exps/u2/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index db50a3615..710f3b62e 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -241,7 +241,7 @@ class U2Trainer(Trainer): n_iter_processes=config.num_workers, subsampling_factor=1, num_encs=1, - dist_sampler=True, + dist_sampler=False, shortest_first=False) self.valid_loader = BatchDataLoader( @@ -262,7 
+262,7 @@ class U2Trainer(Trainer): n_iter_processes=config.num_workers, subsampling_factor=1, num_encs=1, - dist_sampler=True, + dist_sampler=False, shortest_first=False) logger.info("Setup train/valid Dataloader!") else:
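Reviewer note on PATCH 1/3: renaming self.sampler to self.batch_sampler in BatchDataLoader lines the attribute up with what Trainer.before_train reads (train_loader.batch_sampler) before calling set_epoch on a paddle.io.DistributedBatchSampler at the start of each epoch. A minimal sketch of that call pattern, using only the paddle.io API that appears in the diff; the toy dataset and the bare loop are illustrative, not the project's Trainer:

    import paddle
    from paddle.io import DataLoader, Dataset, DistributedBatchSampler

    class ToyDataset(Dataset):
        def __getitem__(self, idx):
            return paddle.to_tensor([idx])

        def __len__(self):
            return 16

    dataset = ToyDataset()
    batch_sampler = DistributedBatchSampler(dataset, batch_size=4, shuffle=True)
    loader = DataLoader(dataset, batch_sampler=batch_sampler)

    for epoch in range(3):
        # Re-seed the sampler each epoch; without this call the sampler keeps
        # epoch 0's permutation, so every epoch replays the same batch order.
        batch_sampler.set_epoch(epoch)
        for batch in loader:
            pass  # training step goes here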
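The new shortest_first option on BatchDataLoader is passed to make_batchset as self.shortest_first or self.use_sortagrad, so length-sorted batching is used either when the flag is set explicitly or when sortagrad is active; the U2 trainer ends up passing shortest_first=False, and after PATCH 3/3 dist_sampler=False as well, for both the train and valid loaders. The toy stand-in below shows only the sorting idea; it is not the real make_batchset signature, and the field names are invented for the example:

    def make_minibatches(utts, batch_size, shortest_first=False):
        # Sort by input length (ascending) before slicing into minibatches,
        # so the earliest batches hold the shortest utterances.
        if shortest_first:
            utts = sorted(utts, key=lambda u: u["input_len"])
        return [utts[i:i + batch_size] for i in range(0, len(utts), batch_size)]

    utts = [{"name": "a", "input_len": 320}, {"name": "b", "input_len": 80},
            {"name": "c", "input_len": 200}, {"name": "d", "input_len": 150}]
    print(make_minibatches(utts, batch_size=2, shortest_first=True))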
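PATCH 1/3 also extends WarmupLR.__repr__ to report lr and last_epoch and notes that last_epoch starts from zero, so get_lr evaluates step_num = last_epoch + 1. The expression in the hunk, base_lr * warmup_steps**0.5 * min(step_num**-0.5, step_num * warmup_steps**-1.5), is the Transformer-style warmup: a linear ramp that reaches base_lr exactly at step_num == warmup_steps, followed by inverse-square-root decay. A standalone check of that shape; the base_lr and warmup_steps values below are examples only:

    def warmup_lr(step_num, base_lr=0.002, warmup_steps=25000):
        # Same expression as WarmupLR.get_lr, with step_num = last_epoch + 1.
        return base_lr * warmup_steps**0.5 * min(step_num**-0.5,
                                                 step_num * warmup_steps**-1.5)

    # Ramps linearly to base_lr at warmup_steps, then decays as step_num**-0.5.
    assert abs(warmup_lr(25000) - 0.002) < 1e-12
    for s in (1, 12500, 25000, 100000):
        print(s, warmup_lr(s))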