From 64f177cc6b802e724323bcb79065c2045cfdfdec Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Mon, 12 Apr 2021 08:22:05 +0000
Subject: [PATCH] add u2 bins

---
 .pre-commit-config.yaml                       |  12 +-
 deepspeech/exps/u2/__init__.py                |  13 +
 deepspeech/exps/u2/bin/export.py              |  58 +++
 deepspeech/exps/u2/bin/test.py                |  59 +++
 deepspeech/exps/u2/bin/train.py               |  60 +++
 deepspeech/exps/u2/config.py                  |  40 ++
 deepspeech/exps/u2/model.py                   | 432 ++++++++++++++++++
 deepspeech/frontend/augmentor/augmentation.py |   2 +-
 deepspeech/io/dataset.py                      | 143 ++++--
 deepspeech/models/u2.py                       | 107 ++++-
 deepspeech/training/trainer.py                |   2 +
 deepspeech/utils/socket_server.py             |   1 +
 examples/tiny/s1/conf/conformer.yaml          | 177 +++----
 examples/tiny/s1/local/train.sh               |  18 +
 requirements.txt                              |  11 +-
 15 files changed, 992 insertions(+), 143 deletions(-)
 create mode 100644 deepspeech/exps/u2/__init__.py
 create mode 100644 deepspeech/exps/u2/bin/export.py
 create mode 100644 deepspeech/exps/u2/bin/test.py
 create mode 100644 deepspeech/exps/u2/bin/train.py
 create mode 100644 deepspeech/exps/u2/config.py
 create mode 100644 deepspeech/exps/u2/model.py
 create mode 100644 examples/tiny/s1/local/train.sh

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 072225ccb..13a9daaa5 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,7 +14,17 @@
         files: \.md$
     - id: trailing-whitespace
         files: \.md$
-- repo: https://github.com/Lucas-C/pre-commit-hooks
+    - id: requirements-txt-fixer
+    - id: check-yaml
+    - id: check-json
+    - id: pretty-format-json
+    - id: check-merge-conflict
+    - id: flake8
+      args:
+      - --ignore=E501,E228,E226,E261,E266,E128,E402,W503
+      - --builtins=G,request
+      - --jobs=1
+- repo: https://github.com/Lucas-C/pre-commit-hooks
     sha: v1.0.1
     hooks:
     - id: forbid-crlf
diff --git a/deepspeech/exps/u2/__init__.py b/deepspeech/exps/u2/__init__.py
new file mode 100644
index 000000000..185a92b8d
--- /dev/null
+++ b/deepspeech/exps/u2/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/deepspeech/exps/u2/bin/export.py b/deepspeech/exps/u2/bin/export.py
new file mode 100644
index 000000000..f9e9eb210
--- /dev/null
+++ b/deepspeech/exps/u2/bin/export.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Export for U2 model.""" + +import io +import logging +import argparse +import functools + +from paddle import distributed as dist + +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.utility import print_arguments +from deepspeech.utils.error_rate import char_errors, word_errors + +from deepspeech.exps.u2.config import get_cfg_defaults +from deepspeech.exps.u2.model import U2Tester as Tester + + +def main_sp(config, args): + exp = Tester(config, args) + exp.setup() + exp.run_export() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args) + + # https://yaml.org/type/float.html + config = get_cfg_defaults() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/deepspeech/exps/u2/bin/test.py b/deepspeech/exps/u2/bin/test.py new file mode 100644 index 000000000..068822962 --- /dev/null +++ b/deepspeech/exps/u2/bin/test.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evaluation for U2 model.""" + +import io +import logging +import argparse +import functools + +from paddle import distributed as dist + +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.utility import print_arguments +from deepspeech.utils.error_rate import char_errors, word_errors + +# TODO(hui zhang): dynamic load +from deepspeech.exps.u2.config import get_cfg_defaults +from deepspeech.exps.u2.model import U2Tester as Tester + + +def main_sp(config, args): + exp = Tester(config, args) + exp.setup() + exp.run_test() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args) + + # https://yaml.org/type/float.html + config = get_cfg_defaults() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/deepspeech/exps/u2/bin/train.py b/deepspeech/exps/u2/bin/train.py new file mode 100644 index 000000000..2742d94d8 --- /dev/null +++ b/deepspeech/exps/u2/bin/train.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Trainer for U2 model."""
+
+import io
+import logging
+import argparse
+import functools
+
+from paddle import distributed as dist
+
+from deepspeech.utils.utility import print_arguments
+from deepspeech.training.cli import default_argument_parser
+
+from deepspeech.exps.u2.config import get_cfg_defaults
+from deepspeech.exps.u2.model import U2Trainer as Trainer
+
+
+def main_sp(config, args):
+    exp = Trainer(config, args)
+    exp.setup()
+    exp.run()
+
+
+def main(config, args):
+    if args.device == "gpu" and args.nprocs > 1:
+        dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
+    else:
+        main_sp(config, args)
+
+
+if __name__ == "__main__":
+    parser = default_argument_parser()
+    args = parser.parse_args()
+    print_arguments(args)
+
+    # https://yaml.org/type/float.html
+    config = get_cfg_defaults()
+    if args.config:
+        config.merge_from_file(args.config)
+    if args.opts:
+        config.merge_from_list(args.opts)
+    config.freeze()
+    print(config)
+    if args.dump_config:
+        with open(args.dump_config, 'w') as f:
+            print(config, file=f)
+
+    main(config, args)
diff --git a/deepspeech/exps/u2/config.py b/deepspeech/exps/u2/config.py
new file mode 100644
index 000000000..48ec05efb
--- /dev/null
+++ b/deepspeech/exps/u2/config.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from yacs.config import CfgNode
+
+from deepspeech.io.dataset import ManifestDataset
+from deepspeech.models.u2 import U2Model
+from deepspeech.exps.u2.model import U2Trainer
+from deepspeech.exps.u2.model import U2Tester
+
+_C = CfgNode()
+
+_C.data = CfgNode()
+ManifestDataset.params(_C.data)
+
+_C.model = CfgNode()
+U2Model.params(_C.model)
+
+_C.training = CfgNode()
+U2Trainer.params(_C.training)
+
+_C.decoding = CfgNode()
+U2Tester.params(_C.decoding)
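+
+# each section above is pre-populated in place by the owning class's
+# params() defaults, so the config tree mirrors the yaml recipe layout:
+# data / model / training / decoding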
+
+
+def get_cfg_defaults():
+    """Get a yacs CfgNode object with default values for my_project."""
+    # Return a clone so that the defaults will not be altered
+    # This is for the "local variable" use pattern
+    return _C.clone()
diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
new file mode 100644
index 000000000..9d9f9961a
--- /dev/null
+++ b/deepspeech/exps/u2/model.py
@@ -0,0 +1,432 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains U2 model."""
+
+import io
+import sys
+import os
+import time
+import logging
+import numpy as np
+from collections import defaultdict
+from functools import partial
+from pathlib import Path
+from typing import Optional
+
+import paddle
+from paddle import distributed as dist
+from paddle.io import DataLoader
+from yacs.config import CfgNode
+
+from deepspeech.training import Trainer
+from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
+from deepspeech.training.scheduler import WarmupLR
+
+from deepspeech.utils import mp_tools
+from deepspeech.utils import layer_tools
+from deepspeech.utils import error_rate
+
+from deepspeech.io.collator import SpeechCollator
+from deepspeech.io.sampler import SortagradDistributedBatchSampler
+from deepspeech.io.sampler import SortagradBatchSampler
+from deepspeech.io.dataset import ManifestDataset
+
+from deepspeech.modules.loss import CTCLoss
+
+from deepspeech.models.u2 import U2Model
+
+logger = logging.getLogger(__name__)
+
+
+class U2Trainer(Trainer):
+    @classmethod
+    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
+        # training config
+        default = CfgNode(
+            dict(
+                n_epoch=50,  # train epochs
+                log_interval=100,  # steps
+                accum_grad=1,  # accum grad by # steps
+                global_grad_clip=5.0,  # the global norm clip
+            ))
+        default.optim = 'adam'
+        default.optim_conf = CfgNode(
+            dict(
+                lr=5e-4,  # learning rate
+                weight_decay=1e-6,  # the coeff of weight decay
+            ))
+        default.scheduler = 'warmuplr'
+        default.scheduler_conf = CfgNode(
+            dict(
+                warmup_steps=25000,
+                lr_decay=1.0,  # learning rate decay
+            ))
+
+        if config is not None:
+            config.merge_from_other_cfg(default)
+        return default
+
+    def __init__(self, config, args):
+        super().__init__(config, args)
+
+    def train_batch(self, batch_data):
+        train_conf = self.config.training
+        self.model.train()
+
+        start = time.time()
+        loss = self.model(*batch_data)
+        loss.backward()
+        layer_tools.print_grads(self.model, print_func=None)
+        if self.iteration % train_conf.accum_grad == 0:
+            self.optimizer.step()
+            self.optimizer.clear_grad()
+
+        iteration_time = time.time() - start
+
+        losses_np = {
+            'train_loss': float(loss),
+            'train_loss_div_batchsize':
+            float(loss) / self.config.data.batch_size
+        }
+        msg = "Train: Rank: {}, ".format(dist.get_rank())
+        msg += "epoch: {}, ".format(self.epoch)
+        msg += "step: {}, ".format(self.iteration)
+        msg += "time: {:>.3f}s, ".format(iteration_time)
+        msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                         for k, v in losses_np.items())
+        if self.iteration % train_conf.log_interval == 0:
+            self.logger.info(msg)
+
+        if dist.get_rank() == 0 and self.visualizer:
+            for k, v in losses_np.items():
+                self.visualizer.add_scalar("train/{}".format(k), v,
+                                           self.iteration)
+
+    @mp_tools.rank_zero_only
+    @paddle.no_grad()
+    def valid(self):
+        self.model.eval()
+        self.logger.info(
+            f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+        valid_losses = defaultdict(list)
+        for i, batch in enumerate(self.valid_loader):
+            loss = self.model(*batch)
+
+            valid_losses['val_loss'].append(float(loss))
+            valid_losses['val_loss_div_batchsize'].append(
+                float(loss) / self.config.data.batch_size)
+
+        # write visual log
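+        # average the accumulated per-batch losses over the whole dev set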
+        valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
+
+        # logging
+        msg = f"Valid: Rank: {dist.get_rank()}, "
+        msg += "epoch: {}, ".format(self.epoch)
+        msg += "step: {}, ".format(self.iteration)
+        msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                         for k, v in valid_losses.items())
+        self.logger.info(msg)
+
+        if self.visualizer:
+            for k, v in valid_losses.items():
+                self.visualizer.add_scalar("valid/{}".format(k), v,
+                                           self.iteration)
+
+    def setup_dataloader(self):
+        config = self.config.clone()
+        config.data.keep_transcription_text = False
+
+        # train/valid dataset, return token ids
+        config.data.manifest = config.data.train_manifest
+        train_dataset = ManifestDataset.from_config(config)
+
+        config.data.manifest = config.data.dev_manifest
+        config.data.augmentation_config = ""
+        dev_dataset = ManifestDataset.from_config(config)
+
+        collate_fn = SpeechCollator(keep_transcription_text=False)
+        if self.parallel:
+            batch_sampler = SortagradDistributedBatchSampler(
+                train_dataset,
+                batch_size=config.data.batch_size,
+                num_replicas=None,
+                rank=None,
+                shuffle=True,
+                drop_last=True,
+                sortagrad=config.data.sortagrad,
+                shuffle_method=config.data.shuffle_method)
+        else:
+            batch_sampler = SortagradBatchSampler(
+                train_dataset,
+                shuffle=True,
+                batch_size=config.data.batch_size,
+                drop_last=True,
+                sortagrad=config.data.sortagrad,
+                shuffle_method=config.data.shuffle_method)
+        self.train_loader = DataLoader(
+            train_dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=collate_fn,
+            num_workers=config.data.num_workers, )
+        self.valid_loader = DataLoader(
+            dev_dataset,
+            batch_size=config.data.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=collate_fn)
+
+        # test dataset, return raw text
+        config.data.keep_transcription_text = True
+        config.data.augmentation_config = ""
+        config.data.manifest = config.data.test_manifest
+        test_dataset = ManifestDataset.from_config(config)
+        self.test_loader = DataLoader(
+            test_dataset,
+            batch_size=config.decoding.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=SpeechCollator(keep_transcription_text=True))
+        self.logger.info("Setup train/valid/test Dataloader!")
+
+    def setup_model(self):
+        config = self.config.clone()
+        model_conf = config.model
+        model_conf.input_dim = self.train_loader.dataset.feature_size
+        model_conf.output_dim = self.train_loader.dataset.vocab_size
+        model = U2Model.from_config(model_conf)
+
+        if self.parallel:
+            model = paddle.DataParallel(model)
+
+        layer_tools.print_params(model, self.logger.info)
+
+        train_config = config.training
+        optim_type = train_config.optim
+        optim_conf = train_config.optim_conf
+        scheduler_type = train_config.scheduler
+        scheduler_conf = train_config.scheduler_conf
+
+        grad_clip = ClipGradByGlobalNormWithLog(train_config.global_grad_clip)
+        weight_decay = paddle.regularizer.L2Decay(optim_conf.weight_decay)
+
+        if scheduler_type == 'expdecaylr':
+            lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
+                learning_rate=optim_conf.lr,
+                gamma=scheduler_conf.lr_decay,
+                verbose=True)
+        elif scheduler_type == 'warmuplr':
+            lr_scheduler = WarmupLR(
+                learning_rate=optim_conf.lr,
+                warmup_steps=scheduler_conf.warmup_steps,
+                verbose=True)
+        else:
+            raise ValueError(f"Not support scheduler: {scheduler_type}")
+
+        if optim_type == 'adam':
+            optimizer = paddle.optimizer.Adam(
+                learning_rate=lr_scheduler,
+                parameters=model.parameters(),
+                weight_decay=weight_decay,
+                grad_clip=grad_clip)
+        else:
+            raise ValueError(f"Not support optim: {optim_type}")
+
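+        # expose what the training loop and checkpointer need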
+        self.model = model
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        self.logger.info("Setup model/optimizer/lr_scheduler!")
+
+
+class U2Tester(U2Trainer):
+    @classmethod
+    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
+        # decoding config
+        default = CfgNode(
+            dict(
+                alpha=2.5,  # Coef of LM for beam search.
+                beta=0.3,  # Coef of WC for beam search.
+                cutoff_prob=1.0,  # Cutoff probability for pruning.
+                cutoff_top_n=40,  # Cutoff number for pruning.
+                lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',  # Filepath for language model.
+                decoding_method='ctc_beam_search',  # Decoding method. Options: ctc_beam_search, ctc_greedy
+                error_rate_type='wer',  # Error rate type for evaluation. Options: 'wer', 'cer'
+                num_proc_bsearch=8,  # number of CPUs for beam search.
+                beam_size=500,  # Beam search width.
+                batch_size=128,  # decoding batch size
+            ))
+
+        if config is not None:
+            config.merge_from_other_cfg(default)
+        return default
+
+    def __init__(self, config, args):
+        super().__init__(config, args)
+
+    def ordid2token(self, texts, texts_len):
+        """Convert tensors of ord() ids back to strings via chr()."""
+        trans = []
+        for text, n in zip(texts, texts_len):
+            n = n.numpy().item()
+            ids = text[:n]
+            trans.append(''.join([chr(i) for i in ids]))
+        return trans
+
+    def compute_metrics(self, audio, texts, audio_len, texts_len):
+        cfg = self.config.decoding
+        errors_sum, len_refs, num_ins = 0.0, 0, 0
+        errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
+        error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
+
+        vocab_list = self.test_loader.dataset.vocab_list
+
+        target_transcripts = self.ordid2token(texts, texts_len)
+        result_transcripts = self.model.decode(
+            audio,
+            audio_len,
+            vocab_list,
+            decoding_method=cfg.decoding_method,
+            lang_model_path=cfg.lang_model_path,
+            beam_alpha=cfg.alpha,
+            beam_beta=cfg.beta,
+            beam_size=cfg.beam_size,
+            cutoff_prob=cfg.cutoff_prob,
+            cutoff_top_n=cfg.cutoff_top_n,
+            num_processes=cfg.num_proc_bsearch)
+
+        for target, result in zip(target_transcripts, result_transcripts):
+            errors, len_ref = errors_func(target, result)
+            errors_sum += errors
+            len_refs += len_ref
+            num_ins += 1
+            self.logger.info(
+                "\nTarget Transcription: %s\nOutput Transcription: %s" %
+                (target, result))
+            self.logger.info("Current error rate [%s] = %f" % (
+                cfg.error_rate_type, error_rate_func(target, result)))
+
+        return dict(
+            errors_sum=errors_sum,
+            len_refs=len_refs,
+            num_ins=num_ins,
+            error_rate=errors_sum / len_refs,
+            error_rate_type=cfg.error_rate_type)
+
+    @mp_tools.rank_zero_only
+    @paddle.no_grad()
+    def test(self):
+        self.model.eval()
+        self.logger.info(
+            f"Test Total Examples: {len(self.test_loader.dataset)}")
+
+        error_rate_type = None
+        errors_sum, len_refs, num_ins = 0.0, 0, 0
+
+        for i, batch in enumerate(self.test_loader):
+            metrics = self.compute_metrics(*batch)
+            errors_sum += metrics['errors_sum']
+            len_refs += metrics['len_refs']
+            num_ins += metrics['num_ins']
+            error_rate_type = metrics['error_rate_type']
+            self.logger.info("Error rate [%s] (%d/?) = %f" %
+                             (error_rate_type, num_ins, errors_sum / len_refs))
+
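+        # the values above are running totals; the final aggregate follows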
+        # logging
+        msg = "Test: "
+        msg += "epoch: {}, ".format(self.epoch)
+        msg += "step: {}, ".format(self.iteration)
+        msg += "Final error rate [%s] (%d/%d) = %f" % (
+            error_rate_type, num_ins, num_ins, errors_sum / len_refs)
+        self.logger.info(msg)
+
+    def run_test(self):
+        self.resume_or_load()
+        try:
+            self.test()
+        except KeyboardInterrupt:
+            exit(-1)
+
+    def export(self):
+        from deepspeech.models.u2 import U2InferModel
+        infer_model = U2InferModel.from_pretrained(self.test_loader.dataset,
+                                                   self.config.model.clone(),
+                                                   self.args.checkpoint_path)
+        infer_model.eval()
+        feat_dim = self.test_loader.dataset.feature_size
+        static_model = paddle.jit.to_static(
+            infer_model,
+            input_spec=[
+                paddle.static.InputSpec(
+                    shape=[None, feat_dim, None],
+                    dtype='float32'),  # audio, [B,D,T]
+                paddle.static.InputSpec(shape=[None],
+                                        dtype='int64'),  # audio_length, [B]
+            ])
+        logger.info(f"Export code: {static_model.forward.code}")
+        paddle.jit.save(static_model, self.args.export_path)
+
+    def run_export(self):
+        try:
+            self.export()
+        except KeyboardInterrupt:
+            exit(-1)
+
+    def setup(self):
+        """Setup the experiment.
+        """
+        paddle.set_device(self.args.device)
+
+        self.setup_output_dir()
+        self.setup_checkpointer()
+        self.setup_logger()
+
+        self.setup_dataloader()
+        self.setup_model()
+
+        self.iteration = 0
+        self.epoch = 0
+
+    def setup_output_dir(self):
+        """Create a directory used for output.
+        """
+        # output dir
+        if self.args.output:
+            output_dir = Path(self.args.output).expanduser()
+        else:
+            output_dir = Path(
+                self.args.checkpoint_path).expanduser().parent.parent
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        self.output_dir = output_dir
+
+    def setup_logger(self):
+        """Initialize a text logger to log the experiment.
+
+        Each process has its own text logger. The logging message is written to
+        the standard output and a text file named ``worker_n.log`` in the
+        output directory, where ``n`` means the rank of the process.
+        """
+        fmt = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
+
+        logger.setLevel("INFO")
+
+        # global logger
+        stdout = True
+        save_path = ""
+        logging.basicConfig(
+            level=logging.DEBUG if stdout else logging.INFO,
+            format=fmt,
+            datefmt='%Y/%m/%d %H:%M:%S',
+            filename=save_path if not stdout else None)
+        self.logger = logger
diff --git a/deepspeech/frontend/augmentor/augmentation.py b/deepspeech/frontend/augmentor/augmentation.py
index e50084a00..6c5d76ba9 100644
--- a/deepspeech/frontend/augmentor/augmentation.py
+++ b/deepspeech/frontend/augmentor/augmentation.py
@@ -83,7 +83,7 @@ class AugmentationPipeline():
     :raises ValueError: If the augmentation json config is in incorrect format".
""" - def __init__(self, augmentation_config, random_seed=0): + def __init__(self, augmentation_config: str, random_seed=0): self._rng = random.Random(random_seed) self._augmentors, self._rates = self._parse_pipeline_from( augmentation_config) diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index b950029b4..3db407dcf 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -20,6 +20,7 @@ import logging import numpy as np from collections import namedtuple from functools import partial +from yacs.config import CfgNode from paddle.io import Dataset @@ -37,6 +38,97 @@ __all__ = [ class ManifestDataset(Dataset): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + default = CfgNode( + dict( + train_manifest="", + dev_manifest="", + test_manifest="", + unit_type="char", + vocab_filepath="", + spm_model_prefix="", + mean_std_filepath="", + augmentation_config="", + max_input_len=27.0, + min_input_len=0.0, + max_output_len=float('inf'), + min_output_len=0.0, + max_output_input_ratio=float('inf'), + min_output_input_ratio=0.0, + stride_ms=10.0, # ms + window_ms=20.0, # ms + n_fft=None, # fft points + max_freq=None, # None for samplerate/2 + raw_wav=True, # use raw_wav or kaldi feature + specgram_type='linear', # 'linear', 'mfcc', 'fbank' + feat_dim=0, # 'mfcc', 'fbank' + delat_delta=False, # 'mfcc', 'fbank' + target_sample_rate=16000, # target sample rate + use_dB_normalization=True, + target_dB=-20, + random_seed=0, + keep_transcription_text=False, + batch_size=32, # batch size + num_workers=0, # data loader workers + sortagrad=False, # sorted in first epoch when True + shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle' + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + @classmethod + def from_config(cls, config): + """Build a ManifestDataset object from a config. + + Args: + config (yacs.config.CfgNode): configs object. + + Returns: + ManifestDataset: dataet object. 
+        """
+        assert 'manifest' in config.data
+        assert 'keep_transcription_text' in config.data
+
+        if isinstance(config.data.augmentation_config, (str, bytes)):
+            if config.data.augmentation_config:
+                aug_file = io.open(
+                    config.data.augmentation_config, mode='r', encoding='utf8')
+            else:
+                aug_file = io.StringIO(initial_value='{}', newline='')
+        else:
+            aug_file = config.data.augmentation_config
+            assert isinstance(aug_file, io.StringIO)
+
+        dataset = cls(
+            manifest_path=config.data.manifest,
+            unit_type=config.data.unit_type,
+            vocab_filepath=config.data.vocab_filepath,
+            mean_std_filepath=config.data.mean_std_filepath,
+            spm_model_prefix=config.data.spm_model_prefix,
+            augmentation_config=aug_file.read(),
+            max_input_len=config.data.max_input_len,
+            min_input_len=config.data.min_input_len,
+            max_output_len=config.data.max_output_len,
+            min_output_len=config.data.min_output_len,
+            max_output_input_ratio=config.data.max_output_input_ratio,
+            min_output_input_ratio=config.data.min_output_input_ratio,
+            stride_ms=config.data.stride_ms,
+            window_ms=config.data.window_ms,
+            n_fft=config.data.n_fft,
+            max_freq=config.data.max_freq,
+            target_sample_rate=config.data.target_sample_rate,
+            specgram_type=config.data.specgram_type,
+            feat_dim=config.data.feat_dim,
+            delta_delta=config.data.delta_delta,
+            use_dB_normalization=config.data.use_dB_normalization,
+            target_dB=config.data.target_dB,
+            random_seed=config.data.random_seed,
+            keep_transcription_text=config.data.keep_transcription_text)
+        return dataset
+
     def __init__(self,
                  manifest_path,
                  unit_type,
@@ -98,7 +190,8 @@ class ManifestDataset(Dataset):
         self._max_output_input_ratio = max_output_input_ratio,
         self._min_output_input_ratio = min_output_input_ratio,
 
-        self._normalizer = FeatureNormalizer(mean_std_filepath)
+        self._normalizer = FeatureNormalizer(
+            mean_std_filepath) if mean_std_filepath else None
         self._audio_augmentation_pipeline = AugmentationPipeline(
             augmentation_config=augmentation_config, random_seed=random_seed)
         self._speech_featurizer = SpeechFeaturizer(
@@ -134,51 +227,6 @@ class ManifestDataset(Dataset):
             min_output_input_ratio=min_output_input_ratio)
         self._manifest.sort(key=lambda x: x["feat_shape"][0])
 
-    @classmethod
-    def from_config(cls, config):
-        """Build a ManifestDataset object from a config.
-
-        Args:
-            config (yacs.config.CfgNode): configs object.
-
-        Returns:
-            ManifestDataset: dataet object.
-        """
-        assert manifest in config.data
-        assert keep_transcription_text in config.data
-        if isinstance(config.data.augmentation_config, (str, bytes)):
-            aug_file = io.open(
-                config.data.augmentation_config, mode='r', encoding='utf8')
-        else:
-            aug_file = config.data.augmentation_config
-            assert isinstance(aug_file, io.StringIO)
-        dataset = cls(
-            manifest_path=config.data.manifest,
-            unit_type=config.data.unit_type,
-            vocab_filepath=config.data.vocab_filepath,
-            mean_std_filepath=config.data.mean_std_filepath,
-            spm_model_prefix=config.data.spm_model_prefix,
-            augmentation_config=aug_file.read(),
-            max_input_len=config.data.max_input_len,
-            min_input_len=config.data.min_input_len,
-            max_output_len=config.data.max_output_len,
-            min_output_len=config.data.min_output_len,
-            max_output_input_ratio=config.data.max_output_input_ratio,
-            min_output_input_ratio=config.data.min_output_input_ratio,
-            stride_ms=config.data.stride_ms,
-            window_ms=config.data.window_ms,
-            n_fft=config.data.n_fft,
-            max_freq=config.data.max_freq,
-            target_sample_rate=config.data.target_sample_rate,
-            specgram_type=config.data.specgram_type,
-            feat_dim=config.data.feat_dim,
-            delta_delta=config.data.delat_delta,
-            use_dB_normalization=config.data.use_dB_normalization,
-            target_dB=config.data.target_dB,
-            random_seed=config.data.random_seed,
-            keep_transcription_text=config.data.keep_transcription_text)
-        return dataset
-
     @property
     def manifest(self):
         return self._manifest
@@ -252,7 +300,8 @@ class ManifestDataset(Dataset):
         self._audio_augmentation_pipeline.transform_audio(speech_segment)
         specgram, transcript_part = self._speech_featurizer.featurize(
             speech_segment, self._keep_transcription_text)
-        specgram = self._normalizer.apply(specgram)
+        if self._normalizer:
+            specgram = self._normalizer.apply(specgram)
         return specgram, transcript_part
 
     def _instance_reader_creator(self, manifest):
diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py
index abeabb76c..f563024d0 100644
--- a/deepspeech/models/u2.py
+++ b/deepspeech/models/u2.py
@@ -60,6 +60,54 @@ __all__ = ['U2TransformerModel', "U2ConformerModel"]
 
 class U2BaseModel(nn.Module):
     """CTC-Attention hybrid Encoder-Decoder model"""
+    @classmethod
+    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
+        # network architecture
+        default = CfgNode()
+        default.cmvn_file = ""
+        default.cmvn_file_type = "npz"
+        default.input_dim = 0
+        default.output_dim = 0
+        # encoder related
+        default.encoder = 'conformer'
+        default.encoder_conf = CfgNode(
+            dict(
+                output_size=256,  # dimension of attention
+                attention_heads=4,
+                linear_units=2048,  # the number of units of position-wise feed forward
+                num_blocks=12,  # the number of encoder blocks
+                dropout_rate=0.1,
+                positional_dropout_rate=0.1,
+                attention_dropout_rate=0.0,
+                input_layer='conv2d',  # encoder input type, you can choose conv2d, conv2d6 and conv2d8
+                normalize_before=True,
+                cnn_module_kernel=15,
+                use_cnn_module=True,
+                activation_type='swish',
+                pos_enc_layer_type='rel_pos',
+                selfattention_layer_type='rel_selfattn', ))
+        # decoder related
+        default.decoder = 'transformer'
+        default.decoder_conf = CfgNode(
+            dict(
+                attention_heads=4,
+                linear_units=2048,
+                num_blocks=6,
+                dropout_rate=0.1,
+                positional_dropout_rate=0.1,
+                self_attention_dropout_rate=0.0,
+                src_attention_dropout_rate=0.0, ))
+        # hybrid CTC/attention
+        default.model_conf = CfgNode(
+            dict(
+                ctc_weight=0.3,
+                lsm_weight=0.1,  # label smoothing option
+                length_normalized_loss=False, ))
+
+        if config is not None:
+            config.merge_from_other_cfg(default)
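+            # merge_from_other_cfg folds these defaults into the caller's
+            # config in place, so a passed-in CfgNode comes back completed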
+        return default
+
     def __init__(self,
                  vocab_size: int,
                  encoder: TransformerEncoder,
@@ -669,6 +717,8 @@ class U2Model(U2BaseModel):
 
         input_dim = configs['input_dim']
         vocab_size = configs['output_dim']
+        assert input_dim != 0, input_dim
+        assert vocab_size != 0, vocab_size
 
         encoder_type = configs.get('encoder', 'transformer')
         logger.info(f"U2 Encoder type: {encoder_type}")
@@ -679,7 +729,7 @@ class U2Model(U2BaseModel):
             encoder = ConformerEncoder(
                 input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
         else:
-            raise ValueError("not support encoder type:{encoder_type}")
+            raise ValueError(f"not support encoder type: {encoder_type}")
 
         decoder = TransformerDecoder(vocab_size,
                                      encoder.output_size(),
@@ -688,18 +738,18 @@ class U2Model(U2BaseModel):
         return vocab_size, encoder, decoder, ctc
 
     @classmethod
-    def from_pretrained(cls, dataset, config, checkpoint_path):
-        """Build a DeepSpeech2Model model from a pretrained model.
+    def from_config(cls, configs: dict):
+        """Init model from a config.
 
         Args:
-            dataset (paddle.io.Dataset): [description]
-            config (yacs.config.CfgNode): model configs
-            checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name
+            configs (dict): config dict.
+
+        Raises:
+            ValueError: raised when the encoder type is not supported.
 
         Returns:
-            DeepSpeech2Model: The model built from pretrained result.
+            U2Model: the model built from configs.
         """
-
         vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)
 
         model = cls(vocab_size=vocab_size,
@@ -707,9 +757,44 @@ class U2Model(U2BaseModel):
                     decoder=decoder,
                     ctc=ctc,
                     **configs['model_conf'])
+        return model
+
+    @classmethod
+    def from_pretrained(cls, dataset, config, checkpoint_path):
+        """Build a U2Model from a pretrained model.
+
+        Args:
+            dataset (paddle.io.Dataset): used for input/output dims.
+            config (yacs.config.CfgNode): model configs
+            checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name
 
-        infos = checkpoint.load_parameters(
-            model, checkpoint_path=checkpoint_path)
-        logger.info(f"checkpoint info: {infos}")
+        Returns:
+            U2Model: The model built from pretrained result.
+        """
+        config.input_dim = dataset.feature_size
+        config.output_dim = dataset.vocab_size
+        model = cls.from_config(config)
+
+        if checkpoint_path:
+            infos = checkpoint.load_parameters(
+                model, checkpoint_path=checkpoint_path)
+            logger.info(f"checkpoint info: {infos}")
         layer_tools.summary(model)
         return model
+
+
+class U2InferModel(U2Model):
+    def __init__(self, configs: dict):
+        super().__init__(configs)
+
+    def forward(self, audio, audio_len):
+        """export model function
+
+        Args:
+            audio (Tensor): [B, T, D]
+            audio_len (Tensor): [B]
+
+        Returns:
+            probs: probs after softmax
+        """
+        raise NotImplementedError("U2Model infer")
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index 39bb1ccd0..6846fdc01 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -95,6 +95,8 @@ class Trainer():
         self.output_dir = None
         self.checkpoint_dir = None
         self.logger = None
+        self.iteration = 0
+        self.epoch = 0
 
     def setup(self):
         """Setup the experiment.
diff --git a/deepspeech/utils/socket_server.py b/deepspeech/utils/socket_server.py index 2a0a62d01..8a4f7dbc5 100644 --- a/deepspeech/utils/socket_server.py +++ b/deepspeech/utils/socket_server.py @@ -16,6 +16,7 @@ import os import random import time from time import gmtime, strftime +import socket import socketserver import struct import wave diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml index 9582219fd..6ec976f74 100644 --- a/examples/tiny/s1/conf/conformer.yaml +++ b/examples/tiny/s1/conf/conformer.yaml @@ -6,7 +6,7 @@ data: vocab_filepath: data/vocab.txt unit_type: 'spm' spm_model_prefix: 'bpe_unigram_200' - mean_std_filepath: data/mean_std.npz + mean_std_filepath: "" augmentation_config: conf/augmentation.config batch_size: 4 max_input_len: 27.0 @@ -23,7 +23,7 @@ data: max_freq: None n_fft: None stride_ms: 10.0 - window_ms: 20.0 + window_ms: 25.0 use_dB_normalization: True target_dB: -20 random_seed: 0 @@ -33,86 +33,109 @@ data: num_workers: 0 -# network architecture -# encoder related -encoder: conformer -encoder_conf: - output_size: 256 # dimension of attention - attention_heads: 4 - linear_units: 2048 # the number of units of position-wise feed forward - num_blocks: 12 # the number of encoder blocks - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - attention_dropout_rate: 0.0 - input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 - normalize_before: true - cnn_module_kernel: 15 - use_cnn_module: True - activation_type: 'swish' - pos_enc_layer_type: 'rel_pos' - selfattention_layer_type: 'rel_selfattn' +# # feature extraction +# collate_conf: +# # waveform level config +# wav_distortion_conf: +# wav_dither: 0.1 +# wav_distortion_rate: 0.0 +# distortion_methods: [] +# speed_perturb: true +# feature_extraction_conf: +# feature_type: 'fbank' +# mel_bins: 80 +# frame_shift: 10 +# frame_length: 25 +# using_pitch: false +# # spec level config +# # spec_swap: false +# feature_dither: 0.0 # add dither [-feature_dither,feature_dither] on fbank feature +# spec_aug: true +# spec_aug_conf: +# warp_for_time: False +# num_t_mask: 2 +# num_f_mask: 2 +# max_t: 50 +# max_f: 10 +# max_w: 80 + + +# # dataset related +# dataset_conf: +# max_length: 40960 +# min_length: 0 +# batch_type: 'static' # static or dynamic +# # the size of batch_size should be set according to your gpu memory size, here we used 2080ti gpu whose memory size is 11GB +# batch_size: 16 +# sort: true + -# decoder related -decoder: transformer -decoder_conf: - attention_heads: 4 - linear_units: 2048 - num_blocks: 6 - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - self_attention_dropout_rate: 0.0 - src_attention_dropout_rate: 0.0 -# hybrid CTC/attention -model_conf: - ctc_weight: 0.3 - lsm_weight: 0.1 # label smoothing option - length_normalized_loss: false -# feature extraction -collate_conf: - # waveform level config - wav_distortion_conf: - wav_dither: 0.1 - wav_distortion_rate: 0.0 - distortion_methods: [] - speed_perturb: true - feature_extraction_conf: - feature_type: 'fbank' - mel_bins: 80 - frame_shift: 10 - frame_length: 25 - using_pitch: false - # spec level config - # spec_swap: false - feature_dither: 0.0 # add dither [-feature_dither,feature_dither] on fbank feature - spec_aug: true - spec_aug_conf: - warp_for_time: False - num_t_mask: 2 - num_f_mask: 2 - max_t: 50 - max_f: 10 - max_w: 80 +# network architecture +model: + cmvn_file: "data/mean_std.npz" + cmvn_file_type: "npz" + # encoder related + encoder: conformer + encoder_conf: + 
output_size: 256        # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
+    normalize_before: true
+    use_cnn_module: True
+    cnn_module_kernel: 15
+    activation_type: 'swish'
+    pos_enc_layer_type: 'rel_pos'
+    selfattention_layer_type: 'rel_selfattn'
+
+  # decoder related
+  decoder: transformer
+  decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
 
-# dataset related
-dataset_conf:
-    max_length: 40960
-    min_length: 0
-    batch_type: 'static' # static or dynamic
-    # the size of batch_size should be set according to your gpu memory size, here we used 2080ti gpu whose memory size is 11GB
-    batch_size: 16
-    sort: true
 
+  # hybrid CTC/attention
+  model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
 
-grad_clip: 5
-accum_grad: 4
-max_epoch: 240
-log_interval: 100
-optim: adam
-optim_conf:
-    lr: 0.002
-scheduler: warmuplr # pytorch v1.1.0+ required
-scheduler_conf:
-    warmup_steps: 25000
\ No newline at end of file
+
+training:
+  n_epoch: 20
+  accum_grad: 4
+  global_grad_clip: 5.0
+  optim: adam
+  optim_conf:
+    lr: 0.002
+    weight_decay: 1e-06
+  scheduler: warmuplr # pytorch v1.1.0+ required
+  scheduler_conf:
+    warmup_steps: 25000
+    lr_decay: 1.0
+  log_interval: 100
+
+decoding:
+  batch_size: 128
+  error_rate_type: wer
+  decoding_method: ctc_beam_search
+  lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
+  alpha: 2.5
+  beta: 0.3
+  beam_size: 500
+  cutoff_prob: 1.0
+  cutoff_top_n: 40
+  num_proc_bsearch: 8
diff --git a/examples/tiny/s1/local/train.sh b/examples/tiny/s1/local/train.sh
new file mode 100644
index 000000000..c2e62c613
--- /dev/null
+++ b/examples/tiny/s1/local/train.sh
@@ -0,0 +1,18 @@
+#! /usr/bin/env bash
+
+export FLAGS_sync_nccl_allreduce=0
+
+CUDA_VISIBLE_DEVICES=0 \
+python3 -u ${BIN_DIR}/train.py \
+--device 'gpu' \
+--nprocs 1 \
+--config conf/conformer.yaml \
+--output ckpt
+
+if [ $? -ne 0 ]; then
+    echo "Failed in training!"
+    exit 1
+fi
+
+
+exit 0
diff --git a/requirements.txt b/requirements.txt
index 1ef11e17d..bfc45d0b8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,9 @@
-scipy==1.2.1
+pre-commit
+python_speech_features
 resampy==0.2.2
+scipy==1.2.1
+sentencepiece
 SoundFile==0.9.0.post1
-python_speech_features
 tensorboardX
-sentencepiece
-yacs
 typeguard
-pre-commit
-#paddlepaddle-gpu==2.0.0
+yacs
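-- 
Usage sketch for the new U2 bins (assuming BIN_DIR points at
deepspeech/exps/u2/bin, as local/train.sh above expects; the checkpoint
tag below is illustrative, use whatever name the trainer saved):

    cd examples/tiny/s1
    # train: wraps ${BIN_DIR}/train.py, see local/train.sh above
    bash local/train.sh
    # evaluate a trained checkpoint on the test manifest
    python3 -u ${BIN_DIR}/test.py --device gpu --config conf/conformer.yaml \
        --output ckpt --checkpoint_path ckpt/checkpoints/<tag>
    # export a static-graph model for inference deployment
    python3 -u ${BIN_DIR}/export.py --config conf/conformer.yaml \
        --checkpoint_path ckpt/checkpoints/<tag> --export_path ckpt/export/u2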