From ab23eb5710a17c81acf60646313b588eb96a2e19 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 18 Aug 2021 12:56:13 +0000 Subject: [PATCH 01/10] fix for kaldi --- deepspeech/__init__.py | 39 -- deepspeech/exps/u2_kaldi/__init__.py | 13 + deepspeech/exps/u2_kaldi/bin/alignment.py | 54 ++ deepspeech/exps/u2_kaldi/bin/export.py | 48 ++ deepspeech/exps/u2_kaldi/bin/test.py | 55 ++ deepspeech/exps/u2_kaldi/bin/train.py | 69 ++ deepspeech/exps/u2_kaldi/model.py | 642 ++++++++++++++++++ deepspeech/io/dataloader.py | 22 +- deepspeech/io/reader.py | 7 +- deepspeech/io/sampler.py | 2 +- deepspeech/models/u2.py | 14 +- deepspeech/modules/activation.py | 2 +- deepspeech/training/optimizer.py | 54 +- deepspeech/training/scheduler.py | 51 +- examples/librispeech/s2/conf/transformer.yaml | 22 +- examples/librispeech/s2/local/train.sh | 1 + examples/librispeech/s2/path.sh | 2 +- speechnn/core/transformers/README.md | 1 - 18 files changed, 1009 insertions(+), 89 deletions(-) create mode 100644 deepspeech/exps/u2_kaldi/__init__.py create mode 100644 deepspeech/exps/u2_kaldi/bin/alignment.py create mode 100644 deepspeech/exps/u2_kaldi/bin/export.py create mode 100644 deepspeech/exps/u2_kaldi/bin/test.py create mode 100644 deepspeech/exps/u2_kaldi/bin/train.py create mode 100644 deepspeech/exps/u2_kaldi/model.py diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index 1316256e4..88f810751 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -407,42 +407,3 @@ class GLU(nn.Layer): if not hasattr(paddle.nn, 'GLU'): logger.warn("register user GLU to paddle.nn, remove this when fixed!") setattr(paddle.nn, 'GLU', GLU) - - -# TODO(Hui Zhang): remove this Layer -class ConstantPad2d(nn.Layer): - """Pads the input tensor boundaries with a constant value. - For N-dimensional padding, use paddle.nn.functional.pad(). - """ - - def __init__(self, padding: Union[tuple, list, int], value: float): - """ - Args: - paddle ([tuple]): the size of the padding. - If is int, uses the same padding in all boundaries. - If a 4-tuple, uses (padding_left, padding_right, padding_top, padding_bottom) - value ([flaot]): pad value - """ - self.padding = padding if isinstance(padding, - [tuple, list]) else [padding] * 4 - self.value = value - - def forward(self, xs: paddle.Tensor) -> paddle.Tensor: - return nn.functional.pad( - xs, - self.padding, - mode='constant', - value=self.value, - data_format='NCHW') - - -if not hasattr(paddle.nn, 'ConstantPad2d'): - logger.warn( - "register user ConstantPad2d to paddle.nn, remove this when fixed!") - setattr(paddle.nn, 'ConstantPad2d', ConstantPad2d) - -########### hcak paddle.jit ############# - -if not hasattr(paddle.jit, 'export'): - logger.warn("register user export to paddle.jit, remove this when fixed!") - setattr(paddle.jit, 'export', paddle.jit.to_static) diff --git a/deepspeech/exps/u2_kaldi/__init__.py b/deepspeech/exps/u2_kaldi/__init__.py new file mode 100644 index 000000000..185a92b8d --- /dev/null +++ b/deepspeech/exps/u2_kaldi/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/deepspeech/exps/u2_kaldi/bin/alignment.py b/deepspeech/exps/u2_kaldi/bin/alignment.py new file mode 100644 index 000000000..3bc70f5ad --- /dev/null +++ b/deepspeech/exps/u2_kaldi/bin/alignment.py @@ -0,0 +1,54 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Alignment for U2 model.""" +from deepspeech.exps.u2.model import get_cfg_defaults +from deepspeech.exps.u2.model import U2Tester as Tester +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.dynamic_import import dynamic_import +from deepspeech.utils.utility import print_arguments + + +def main_sp(config, args): + exp = Tester(config, args) + exp.setup() + exp.run_align() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + parser.add_arguments( + '--model-name', + type=str, + default='u2', + help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st') + args = parser.parse_args() + print_arguments(args, globals()) + + # https://yaml.org/type/float.html + config = get_cfg_defaults() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/deepspeech/exps/u2_kaldi/bin/export.py b/deepspeech/exps/u2_kaldi/bin/export.py new file mode 100644 index 000000000..91967627f --- /dev/null +++ b/deepspeech/exps/u2_kaldi/bin/export.py @@ -0,0 +1,48 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Export for U2 model.""" +from deepspeech.exps.u2.model import get_cfg_defaults +from deepspeech.exps.u2.model import U2Tester as Tester +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.utility import print_arguments + + +def main_sp(config, args): + exp = Tester(config, args) + exp.setup() + exp.run_export() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args, globals()) + + # https://yaml.org/type/float.html + config = get_cfg_defaults() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/deepspeech/exps/u2_kaldi/bin/test.py b/deepspeech/exps/u2_kaldi/bin/test.py new file mode 100644 index 000000000..48244a545 --- /dev/null +++ b/deepspeech/exps/u2_kaldi/bin/test.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evaluation for U2 model.""" +import cProfile + +from deepspeech.exps.u2.model import get_cfg_defaults +from deepspeech.exps.u2.model import U2Tester as Tester +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.utility import print_arguments + +# TODO(hui zhang): dynamic load + + +def main_sp(config, args): + exp = Tester(config, args) + exp.setup() + exp.run_test() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args, globals()) + + # https://yaml.org/type/float.html + config = get_cfg_defaults() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + # Setting for profiling + pr = cProfile.Profile() + pr.runcall(main, config, args) + pr.dump_stats('test.profile') diff --git a/deepspeech/exps/u2_kaldi/bin/train.py b/deepspeech/exps/u2_kaldi/bin/train.py new file mode 100644 index 000000000..45ad3dbac --- /dev/null +++ b/deepspeech/exps/u2_kaldi/bin/train.py @@ -0,0 +1,69 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Trainer for U2 model.""" +import cProfile +import os + +from paddle import distributed as dist +from yacs.config import CfgNode + +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.dynamic_import import dynamic_import +from deepspeech.utils.utility import print_arguments + +model_alias = { + "u2": "deepspeech.exps.u2.model:U2Trainer", + "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Trainer", +} + + +def main_sp(config, args): + trainer_cls = dynamic_import(args.model_name, model_alias) + exp = trainer_cls(config, args) + exp.setup() + exp.run() + + +def main(config, args): + if args.device == "gpu" and args.nprocs > 1: + dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) + else: + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + parser.add_argument( + '--model-name', + type=str, + default='u2_kaldi', + help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st') + args = parser.parse_args() + print_arguments(args, globals()) + + config = CfgNode() + config.set_new_allowed(True) + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + # Setting for profiling + pr = cProfile.Profile() + pr.runcall(main, config, args) + pr.dump_stats(os.path.join(args.output, 'train.profile')) diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py new file mode 100644 index 000000000..60f070a3b --- /dev/null +++ b/deepspeech/exps/u2_kaldi/model.py @@ -0,0 +1,642 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Contains U2 model.""" +import json +import os +import sys +import time +from collections import defaultdict +from pathlib import Path +from typing import Optional + +import numpy as np +import paddle +from paddle import distributed as dist +from yacs.config import CfgNode + +from deepspeech.io.dataloader import BatchDataLoader +from deepspeech.models.u2 import U2Model +from deepspeech.training.optimizer import OptimizerFactory +from deepspeech.training.scheduler import LRSchedulerFactory +from deepspeech.training.trainer import Trainer +from deepspeech.utils import ctc_utils +from deepspeech.utils import error_rate +from deepspeech.utils import layer_tools +from deepspeech.utils import mp_tools +from deepspeech.utils import text_grid +from deepspeech.utils import utility +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + _C = CfgNode() + + _C.model = U2Model.params() + + _C.training = U2Trainer.params() + + _C.decoding = U2Tester.params() + + config = _C.clone() + config.set_new_allowed(True) + return config + + +class U2Trainer(Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # training config + default = CfgNode( + dict( + n_epoch=50, # train epochs + log_interval=100, # steps + accum_grad=1, # accum grad by # steps + checkpoint=dict( + kbest_n=50, + latest_n=5, ), )) + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, config, args): + super().__init__(config, args) + + def train_batch(self, batch_index, batch_data, msg): + train_conf = self.config.training + start = time.time() + utt, audio, audio_len, text, text_len = batch_data + + loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, + text_len) + # loss div by `batch_size * accum_grad` + loss /= train_conf.accum_grad + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + losses_np = {'loss': float(loss) * train_conf.accum_grad} + if attention_loss: + losses_np['att_loss'] = float(attention_loss) + if ctc_loss: + losses_np['ctc_loss'] = float(ctc_loss) + + if (batch_index + 1) % train_conf.accum_grad == 0: + self.optimizer.step() + self.optimizer.clear_grad() + self.lr_scheduler.step() + self.iteration += 1 + + iteration_time = time.time() - start + + if (batch_index + 1) % train_conf.log_interval == 0: + msg += "train time: {:>.3f}s, ".format(iteration_time) + msg += "batch size: {}, ".format(self.config.collator.batch_size) + msg += "accum: {}, ".format(train_conf.accum_grad) + msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_np.items()) + logger.info(msg) + + if dist.get_rank() == 0 and self.visualizer: + losses_np_v = losses_np.copy() + losses_np_v.update({"lr": self.lr_scheduler()}) + self.visualizer.add_scalars("step", losses_np_v, + self.iteration - 1) + + @paddle.no_grad() + def valid(self): + self.model.eval() + logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + valid_losses = defaultdict(list) + num_seen_utts = 1 + total_loss = 0.0 + for i, batch in enumerate(self.valid_loader): + utt, audio, audio_len, text, text_len = batch + loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, + text_len) + if paddle.isfinite(loss): + num_utts = batch[1].shape[0] + num_seen_utts += num_utts + total_loss += float(loss) * num_utts 
+ valid_losses['val_loss'].append(float(loss)) + if attention_loss: + valid_losses['val_att_loss'].append(float(attention_loss)) + if ctc_loss: + valid_losses['val_ctc_loss'].append(float(ctc_loss)) + + if (i + 1) % self.config.training.log_interval == 0: + valid_dump = {k: np.mean(v) for k, v in valid_losses.items()} + valid_dump['val_history_loss'] = total_loss / num_seen_utts + + # logging + msg = f"Valid: Rank: {dist.get_rank()}, " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in valid_dump.items()) + logger.info(msg) + + logger.info('Rank {} Val info val_loss {}'.format( + dist.get_rank(), total_loss / num_seen_utts)) + return total_loss, num_seen_utts + + def train(self): + """The training process control by step.""" + # !!!IMPORTANT!!! + # Try to export the model by script, if fails, we should refine + # the code to satisfy the script export requirements + # script_model = paddle.jit.to_static(self.model) + # script_model_path = str(self.checkpoint_dir / 'init') + # paddle.jit.save(script_model, script_model_path) + + from_scratch = self.resume_or_scratch() + if from_scratch: + # save init model, i.e. 0 epoch + self.save(tag='init') + + self.lr_scheduler.step(self.iteration) + if self.parallel: + self.train_loader.batch_sampler.set_epoch(self.epoch) + + logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + while self.epoch < self.config.training.n_epoch: + self.model.train() + try: + data_start_time = time.time() + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train: Rank: {}, ".format(dist.get_rank()) + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(batch_index + 1, + len(self.train_loader)) + msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) + msg += "data time: {:>.3f}s, ".format(dataload_time) + self.train_batch(batch_index, batch, msg) + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = paddle.to_tensor(num_seen_utts) + # the default operator in all_reduce function is sum. 
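+                # All-reduce both counters so every rank computes the same
+                # global cv_loss below (it also feeds the checkpoint's
+                # val_loss info).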
+ dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = total_loss / num_seen_utts + + logger.info( + 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) + if self.visualizer: + self.visualizer.add_scalars( + 'epoch', {'cv_loss': cv_loss, + 'lr': self.lr_scheduler()}, self.epoch) + self.save(tag=self.epoch, infos={'val_loss': cv_loss}) + self.new_epoch() + + def setup_dataloader(self): + config = self.config.clone() + # train/valid dataset, return token ids + self.train_loader = BatchDataLoader( + json_file=config.data.train_manifest, + train_mode=True, + sortagrad=False, + batch_size=config.collator.batch_size, + maxlen_in=float('inf'), + maxlen_out=float('inf'), + minibatches=0, + mini_batch_size=1, + batch_count='auto', + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + preprocess_conf=config.collator.augmentation_config, + n_iter_processes=config.collator.num_workers, + subsampling_factor=1, + num_encs=1) + + self.valid_loader = BatchDataLoader( + json_file=config.data.dev_manifest, + train_mode=False, + sortagrad=False, + batch_size=config.collator.batch_size, + maxlen_in=float('inf'), + maxlen_out=float('inf'), + minibatches=0, + mini_batch_size=1, + batch_count='auto', + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + preprocess_conf=None, + n_iter_processes=1, + subsampling_factor=1, + num_encs=1) + + # test dataset, return raw text + self.test_loader = BatchDataLoader( + json_file=config.data.test_manifest, + train_mode=False, + sortagrad=False, + batch_size=config.collator.batch_size, + maxlen_in=float('inf'), + maxlen_out=float('inf'), + minibatches=0, + mini_batch_size=1, + batch_count='auto', + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + preprocess_conf=None, + n_iter_processes=1, + subsampling_factor=1, + num_encs=1) + + self.align_loader = BatchDataLoader( + json_file=config.data.test_manifest, + train_mode=False, + sortagrad=False, + batch_size=config.collator.batch_size, + maxlen_in=float('inf'), + maxlen_out=float('inf'), + minibatches=0, + mini_batch_size=1, + batch_count='auto', + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + preprocess_conf=None, + n_iter_processes=1, + subsampling_factor=1, + num_encs=1) + logger.info("Setup train/valid/test/align Dataloader!") + + def setup_model(self): + config = self.config + + # model + model_conf = config.model + model_conf.defrost() + model_conf.input_dim = self.train_loader.feat_dim + model_conf.output_dim = self.train_loader.vocab_size + model_conf.freeze() + model = U2Model.from_config(model_conf) + + if self.parallel: + model = paddle.DataParallel(model) + + logger.info(f"{model}") + layer_tools.print_params(model, logger.info) + + # lr + scheduler_conf = config.scheduler_conf + scheduler_args = { + "learning_rate": scheduler_conf.lr, + "warmup_steps": scheduler_conf.warmup_steps, + "gamma": scheduler_conf.lr_decay, + "d_model": model_conf.encoder_conf.output_size, + "verbose": False, + } + lr_scheduler = LRSchedulerFactory.from_args(config.scheduler, + scheduler_args) + + # opt + def optimizer_args( + config, + parameters, + lr_scheduler=None, ): + optim_conf = config.optim_conf + return { + "grad_clip": optim_conf.global_grad_clip, + "weight_decay": optim_conf.weight_decay, + "learning_rate": lr_scheduler, + "parameters": parameters, + } 
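+        # Note: with config.scheduler == 'warmuplr' (see conf/transformer.yaml),
+        # the factory above builds the espnet-style warmup schedule; roughly
+        #   lr(step) = lr * warmup_steps**0.5 * min(step**-0.5,
+        #                                           step * warmup_steps**-1.5)
+        # i.e. linear warmup for warmup_steps steps, then inverse-sqrt decay.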
+ + optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler) + optimizer = OptimizerFactory.from_args(config.optim, optimzer_args) + + self.model = model + self.lr_scheduler = lr_scheduler + self.optimizer = optimizer + logger.info("Setup model/optimizer/lr_scheduler!") + + +class U2Tester(U2Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # decoding config + default = CfgNode( + dict( + alpha=2.5, # Coef of LM for beam search. + beta=0.3, # Coef of WC for beam search. + cutoff_prob=1.0, # Cutoff probability for pruning. + cutoff_top_n=40, # Cutoff number for pruning. + lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. + decoding_method='attention', # Decoding method. Options: 'attention', 'ctc_greedy_search', + # 'ctc_prefix_beam_search', 'attention_rescoring' + error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' + num_proc_bsearch=8, # # of CPUs for beam search. + beam_size=10, # Beam search width. + batch_size=16, # decoding batch size + ctc_weight=0.0, # ctc weight for attention rescoring decode mode. + decoding_chunk_size=-1, # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. + num_decoding_left_chunks=-1, # number of left chunks for decoding. Defaults to -1. + simulate_streaming=False, # simulate streaming inference. Defaults to False. + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, config, args): + super().__init__(config, args) + + def ordid2token(self, texts, texts_len): + """ ord() id to chr() chr """ + trans = [] + for text, n in zip(texts, texts_len): + n = n.numpy().item() + ids = text[:n] + trans.append(''.join([chr(i) for i in ids])) + return trans + + def compute_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + fout=None): + cfg = self.config.decoding + errors_sum, len_refs, num_ins = 0.0, 0, 0 + errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors + error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer + + start_time = time.time() + text_feature = self.test_loader.collate_fn.text_feature + target_transcripts = self.ordid2token(texts, texts_len) + result_transcripts = self.model.decode( + audio, + audio_len, + text_feature=text_feature, + decoding_method=cfg.decoding_method, + lang_model_path=cfg.lang_model_path, + beam_alpha=cfg.alpha, + beam_beta=cfg.beta, + beam_size=cfg.beam_size, + cutoff_prob=cfg.cutoff_prob, + cutoff_top_n=cfg.cutoff_top_n, + num_processes=cfg.num_proc_bsearch, + ctc_weight=cfg.ctc_weight, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) + decode_time = time.time() - start_time + + for utt, target, result in zip(utts, target_transcripts, + result_transcripts): + errors, len_ref = errors_func(target, result) + errors_sum += errors + len_refs += len_ref + num_ins += 1 + if fout: + fout.write(utt + " " + result + "\n") + logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % + (target, result)) + logger.info("One example error rate [%s] = %f" % + (cfg.error_rate_type, error_rate_func(target, result))) + + return dict( + errors_sum=errors_sum, + len_refs=len_refs, + num_ins=num_ins, # num examples + error_rate=errors_sum / len_refs, + 
error_rate_type=cfg.error_rate_type, + num_frames=audio_len.sum().numpy().item(), + decode_time=decode_time) + + @mp_tools.rank_zero_only + @paddle.no_grad() + def test(self): + assert self.args.result_file + self.model.eval() + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + + stride_ms = self.test_loader.collate_fn.stride_ms + error_rate_type = None + errors_sum, len_refs, num_ins = 0.0, 0, 0 + num_frames = 0.0 + num_time = 0.0 + with open(self.args.result_file, 'w') as fout: + for i, batch in enumerate(self.test_loader): + metrics = self.compute_metrics(*batch, fout=fout) + num_frames += metrics['num_frames'] + num_time += metrics["decode_time"] + errors_sum += metrics['errors_sum'] + len_refs += metrics['len_refs'] + num_ins += metrics['num_ins'] + error_rate_type = metrics['error_rate_type'] + rtf = num_time / (num_frames * stride_ms) + logger.info( + "RTF: %f, Error rate [%s] (%d/?) = %f" % + (rtf, error_rate_type, num_ins, errors_sum / len_refs)) + + rtf = num_time / (num_frames * stride_ms) + msg = "Test: " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "RTF: {}, ".format(rtf) + msg += "Final error rate [%s] (%d/%d) = %f" % ( + error_rate_type, num_ins, num_ins, errors_sum / len_refs) + logger.info(msg) + + # test meta results + err_meta_path = os.path.splitext(self.args.result_file)[0] + '.err' + err_type_str = "{}".format(error_rate_type) + with open(err_meta_path, 'w') as f: + data = json.dumps({ + "epoch": + self.epoch, + "step": + self.iteration, + "rtf": + rtf, + error_rate_type: + errors_sum / len_refs, + "dataset_hour": (num_frames * stride_ms) / 1000.0 / 3600.0, + "process_hour": + num_time / 1000.0 / 3600.0, + "num_examples": + num_ins, + "err_sum": + errors_sum, + "ref_len": + len_refs, + "decode_method": + self.config.decoding.decoding_method, + }) + f.write(data + '\n') + + def run_test(self): + self.resume_or_scratch() + try: + self.test() + except KeyboardInterrupt: + sys.exit(-1) + + @paddle.no_grad() + def align(self): + if self.config.decoding.batch_size > 1: + logger.fatal('alignment mode must be running with batch_size == 1') + sys.exit(1) + + # xxx.align + assert self.args.result_file and self.args.result_file.endswith( + '.align') + + self.model.eval() + logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}") + + stride_ms = self.config.collate.stride_ms + token_dict = self.align_loader.collate_fn.vocab_list + with open(self.args.result_file, 'w') as fout: + # one example in batch + for i, batch in enumerate(self.align_loader): + key, feat, feats_length, target, target_length = batch + + # 1. Encoder + encoder_out, encoder_mask = self.model._forward_encoder( + feat, feats_length) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + ctc_probs = self.model.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + + # 2. alignment + ctc_probs = ctc_probs.squeeze(0) + target = target.squeeze(0) + alignment = ctc_utils.forced_align(ctc_probs, target) + logger.info("align ids", key[0], alignment) + fout.write('{} {}\n'.format(key[0], alignment)) + + # 3. 
gen praat + # segment alignment + align_segs = text_grid.segment_alignment(alignment) + logger.info("align tokens", key[0], align_segs) + # IntervalTier, List["start end token\n"] + subsample = utility.get_subsample(self.config) + tierformat = text_grid.align_to_tierformat( + align_segs, subsample, token_dict) + # write tier + align_output_path = os.path.join( + os.path.dirname(self.args.result_file), "align") + tier_path = os.path.join(align_output_path, key[0] + ".tier") + with open(tier_path, 'w') as f: + f.writelines(tierformat) + # write textgrid + textgrid_path = os.path.join(align_output_path, + key[0] + ".TextGrid") + second_per_frame = 1. / (1000. / + stride_ms) # 25ms window, 10ms stride + second_per_example = ( + len(alignment) + 1) * subsample * second_per_frame + text_grid.generate_textgrid( + maxtime=second_per_example, + intervals=tierformat, + output=textgrid_path) + + def run_align(self): + self.resume_or_scratch() + try: + self.align() + except KeyboardInterrupt: + sys.exit(-1) + + def load_inferspec(self): + """infer model and input spec. + + Returns: + nn.Layer: inference model + List[paddle.static.InputSpec]: input spec. + """ + from deepspeech.models.u2 import U2InferModel + infer_model = U2InferModel.from_pretrained(self.test_loader, + self.config.model.clone(), + self.args.checkpoint_path) + feat_dim = self.test_loader.feat_dim + input_spec = [ + paddle.static.InputSpec(shape=[1, None, feat_dim], + dtype='float32'), # audio, [B,T,D] + paddle.static.InputSpec(shape=[1], + dtype='int64'), # audio_length, [B] + ] + return infer_model, input_spec + + def export(self): + infer_model, input_spec = self.load_inferspec() + assert isinstance(input_spec, list), type(input_spec) + infer_model.eval() + static_model = paddle.jit.to_static(infer_model, input_spec=input_spec) + logger.info(f"Export code: {static_model.forward.code}") + paddle.jit.save(static_model, self.args.export_path) + + def run_export(self): + try: + self.export() + except KeyboardInterrupt: + sys.exit(-1) + + def setup(self): + """Setup the experiment. + """ + paddle.set_device(self.args.device) + + self.setup_output_dir() + self.setup_checkpointer() + + self.setup_dataloader() + self.setup_model() + + self.iteration = 0 + self.epoch = 0 + + def setup_output_dir(self): + """Create a directory used for output. + """ + # output dir + if self.args.output: + output_dir = Path(self.args.output).expanduser() + output_dir.mkdir(parents=True, exist_ok=True) + else: + output_dir = Path( + self.args.checkpoint_path).expanduser().parent.parent + output_dir.mkdir(parents=True, exist_ok=True) + + self.output_dir = output_dir diff --git a/deepspeech/io/dataloader.py b/deepspeech/io/dataloader.py index 15ab73157..115fe4617 100644 --- a/deepspeech/io/dataloader.py +++ b/deepspeech/io/dataloader.py @@ -11,6 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
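+# Note: BatchDataLoader below assumes espnet-style manifest entries; one entry
+# looks roughly like this (field values are illustrative only):
+#   {"input":  [{"name": "input1",  "shape": [1069, 83], "feat": "..."}],
+#    "output": [{"name": "target1", "shape": [18, 5002], "tokenid": "..."}]}
+# so input[0]['shape'] is (num_frames, feat_dim) and output[0]['shape'] is
+# (token_len, vocab_size), which is what feat_dim_and_vocab_size() reads.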
+from typing import Any +from typing import Dict +from typing import List +from typing import Text + +import numpy as np from paddle.io import DataLoader from deepspeech.frontend.utility import read_manifest @@ -25,6 +31,18 @@ __all__ = ["BatchDataLoader"] logger = Log(__name__).getlog() +def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]], + mode: Text="asr", + iaxis=0, + oaxis=0): + if mode == 'asr': + feat_dim = data_json[0]['input'][oaxis]['shape'][1] + vocab_size = data_json[0]['output'][oaxis]['shape'][1] + else: + raise ValueError(f"{mode} mode not support!") + return feat_dim, vocab_size + + class BatchDataLoader(): def __init__(self, json_file: str, @@ -62,6 +80,8 @@ class BatchDataLoader(): # read json data self.data_json = read_manifest(json_file) + self.feat_dim, self.vocab_size = feat_dim_and_vocab_size( + self.data_json, mode='asr') # make minibatch list (variable length) self.minibaches = make_batchset( @@ -106,7 +126,7 @@ class BatchDataLoader(): self.dataloader = DataLoader( dataset=self.dataset, batch_size=1, - shuffle=not use_sortagrad if train_mode else False, + shuffle=not self.use_sortagrad if train_mode else False, collate_fn=lambda x: x[0], num_workers=n_iter_processes, ) diff --git a/deepspeech/io/reader.py b/deepspeech/io/reader.py index b6dc61b79..95cdbb951 100644 --- a/deepspeech/io/reader.py +++ b/deepspeech/io/reader.py @@ -66,8 +66,9 @@ class LoadInputsAndTargets(): raise ValueError("Only asr are allowed: mode={}".format(mode)) if preprocess_conf is not None: - self.preprocessing = AugmentationPipeline(preprocess_conf) - logging.warning( + with open(preprocess_conf, 'r') as fin: + self.preprocessing = AugmentationPipeline(fin.read()) + logger.warning( "[Experimental feature] Some preprocessing will be done " "for the mini-batch creation using {}".format( self.preprocessing)) @@ -197,7 +198,7 @@ class LoadInputsAndTargets(): nonzero_sorted_idx = nonzero_idx if len(nonzero_sorted_idx) != len(xs[0]): - logging.warning( + logger.warning( "Target sequences include empty tokenid (batch {} -> {}).". 
format(len(xs[0]), len(nonzero_sorted_idx)))
diff --git a/deepspeech/io/sampler.py b/deepspeech/io/sampler.py
index 3b2ef757d..763a3781e 100644
--- a/deepspeech/io/sampler.py
+++ b/deepspeech/io/sampler.py
@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
     """
     rng = np.random.RandomState(epoch)
     shift_len = rng.randint(0, batch_size - 1)
-    batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
+    batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
     rng.shuffle(batch_indices)
     batch_indices = [item for batch in batch_indices for item in batch]
     assert clipped is False
diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py
index 7ed16c9d2..c1a35560a 100644
--- a/deepspeech/models/u2.py
+++ b/deepspeech/models/u2.py
@@ -612,32 +612,32 @@ class U2BaseModel(nn.Layer):
                 best_index = i
         return hyps[best_index][0]
 
-    #@jit.export
+    #@jit.to_static
     def subsampling_rate(self) -> int:
         """ Export interface for c++ call, return subsampling_rate of the
            model
         """
         return self.encoder.embed.subsampling_rate
 
-    #@jit.export
+    #@jit.to_static
     def right_context(self) -> int:
         """ Export interface for c++ call, return right_context of the model
         """
         return self.encoder.embed.right_context
 
-    #@jit.export
+    #@jit.to_static
    def sos_symbol(self) -> int:
         """ Export interface for c++ call, return sos symbol id of the model
         """
         return self.sos
 
-    #@jit.export
+    #@jit.to_static
     def eos_symbol(self) -> int:
         """ Export interface for c++ call, return eos symbol id of the model
         """
         return self.eos
 
-    @jit.export
+    @jit.to_static
     def forward_encoder_chunk(
             self,
             xs: paddle.Tensor,
@@ -667,7 +667,7 @@ class U2BaseModel(nn.Layer):
             xs, offset, required_cache_size, subsampling_cache,
             elayers_output_cache, conformer_cnn_cache)
 
-    # @jit.export([
+    # @jit.to_static([
     #     paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'),  # audio feat, [B,T,D]
     # ])
     def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
@@ -680,7 +680,7 @@ class U2BaseModel(nn.Layer):
         """
         return self.ctc.log_softmax(xs)
 
-    @jit.export
+    @jit.to_static
     def forward_attention_decoder(
             self,
             hyps: paddle.Tensor,
diff --git a/deepspeech/modules/activation.py b/deepspeech/modules/activation.py
index 0fe66b739..30132775e 100644
--- a/deepspeech/modules/activation.py
+++ b/deepspeech/modules/activation.py
@@ -69,7 +69,7 @@ class ConvGLUBlock(nn.Layer):
                 dim=0)
             self.dropout_residual = nn.Dropout(p=dropout)
 
-        self.pad_left = ConstantPad2d((0, 0, kernel_size - 1, 0), 0)
+        self.pad_left = nn.Pad2D((0, 0, kernel_size - 1, 0), value=0.)
 
         layers = OrderedDict()
         if bottlececk_dim == 0:
diff --git a/deepspeech/training/optimizer.py b/deepspeech/training/optimizer.py
index f7933f8d4..db7069c98 100644
--- a/deepspeech/training/optimizer.py
+++ b/deepspeech/training/optimizer.py
@@ -15,6 +15,7 @@ from typing import Any
 from typing import Dict
 from typing import Text
 
+import paddle
 from paddle.optimizer import Optimizer
 from paddle.regularizer import L2Decay
 
@@ -43,6 +44,40 @@ def register_optimizer(cls):
     return cls
 
 
+@register_optimizer
+class Noam(paddle.optimizer.Adam):
+    """See: espnet/nets/pytorch_backend/transformer/optimizer.py"""
+
+    def __init__(self,
+                 learning_rate=0,
+                 beta1=0.9,
+                 beta2=0.98,
+                 epsilon=1e-9,
+                 parameters=None,
+                 weight_decay=None,
+                 grad_clip=None,
+                 lazy_mode=False,
+                 multi_precision=False,
+                 name=None):
+        super().__init__(
+            learning_rate=learning_rate,
+            beta1=beta1,
+            beta2=beta2,
+            epsilon=epsilon,
+            parameters=parameters,
+            weight_decay=weight_decay,
+            grad_clip=grad_clip,
+            lazy_mode=lazy_mode,
+            multi_precision=multi_precision,
+            name=name)
+
+    def __repr__(self):
+        echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
+        echo += f"learning_rate: {self._learning_rate}, "
+        echo += f"(beta1: {self._beta1} beta2: {self._beta2}), "
+        echo += f"epsilon: {self._epsilon}"
+        return echo
+
+
 def dynamic_import_optimizer(module):
     """Import Optimizer class dynamically.
 
@@ -69,15 +104,18 @@ class OptimizerFactory():
             args['grad_clip']) if "grad_clip" in args else None
         weight_decay = L2Decay(
             args['weight_decay']) if "weight_decay" in args else None
-        module_class = dynamic_import_optimizer(name.lower())
-        if weight_decay:
-            logger.info(f'WeightDecay: {weight_decay}')
+        logger.info(f'<WeightDecay: {weight_decay}>')
         if grad_clip:
-            logger.info(f'GradClip: {grad_clip}')
-        logger.info(
-            f"Optimizer: {module_class.__name__} {args['learning_rate']}")
+            logger.info(f'<GradClip: {grad_clip}>')
+        module_class = dynamic_import_optimizer(name.lower())
         args.update({"grad_clip": grad_clip, "weight_decay": weight_decay})
-
-        return instance_class(module_class, args)
+        opt = instance_class(module_class, args)
+        if "__repr__" in vars(type(opt)):
+            logger.info(f"{opt}")
+        else:
+            logger.info(
+                f"<Optimizer {module_class.__name__}> LR: {args['learning_rate']}")
+        return opt
diff --git a/deepspeech/training/scheduler.py b/deepspeech/training/scheduler.py
index b8f3ece7c..bb53281a8 100644
--- a/deepspeech/training/scheduler.py
+++ b/deepspeech/training/scheduler.py
@@ -41,22 +41,6 @@ def register_scheduler(cls):
     return cls
 
 
-def dynamic_import_scheduler(module):
-    """Import Scheduler class dynamically.
-
-    Args:
-        module (str): module_name:class_name or alias in `SCHEDULER_DICT`
-
-    Returns:
-        type: Scheduler class
-
-    """
-    module_class = dynamic_import(module, SCHEDULER_DICT)
-    assert issubclass(module_class,
-                      LRScheduler), f"{module} does not implement LRScheduler"
-    return module_class
-
-
 @register_scheduler
 class WarmupLR(LRScheduler):
     """The WarmupLR scheduler
@@ -102,6 +86,41 @@ class WarmupLR(LRScheduler):
         self.step(epoch=step)
 
 
+@register_scheduler
+class ConstantLR(LRScheduler):
+    """
+    Args:
+        learning_rate (float): The initial learning rate. It is a python float number.
+        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
+
+    Returns:
+        ``ConstantLR`` instance to schedule learning rate.
+    """
+
+    def __init__(self, learning_rate, last_epoch=-1, verbose=False):
+        super().__init__(learning_rate, last_epoch, verbose)
+
+    def get_lr(self):
+        return self.base_lr
+
+
+def dynamic_import_scheduler(module):
+    """Import Scheduler class dynamically.
+ + Args: + module (str): module_name:class_name or alias in `SCHEDULER_DICT` + + Returns: + type: Scheduler class + + """ + module_class = dynamic_import(module, SCHEDULER_DICT) + assert issubclass(module_class, + LRScheduler), f"{module} does not implement LRScheduler" + return module_class + + class LRSchedulerFactory(): @classmethod def from_args(cls, name: str, args: Dict[Text, Any]): diff --git a/examples/librispeech/s2/conf/transformer.yaml b/examples/librispeech/s2/conf/transformer.yaml index 8a769dca4..7710d7064 100644 --- a/examples/librispeech/s2/conf/transformer.yaml +++ b/examples/librispeech/s2/conf/transformer.yaml @@ -19,7 +19,7 @@ collator: batch_size: 64 raw_wav: True # use raw_wav or kaldi feature specgram_type: fbank #linear, mfcc, fbank - feat_dim: 80 + feat_dim: 83 delta_delta: False dither: 1.0 target_sample_rate: 16000 @@ -38,7 +38,7 @@ collator: # network architecture model: - cmvn_file: "data/mean_std.json" + cmvn_file: cmvn_file_type: "json" # encoder related encoder: transformer @@ -74,20 +74,20 @@ model: training: n_epoch: 120 accum_grad: 2 - global_grad_clip: 5.0 - optim: adam - optim_conf: - lr: 0.004 - weight_decay: 1e-06 - scheduler: warmuplr # pytorch v1.1.0+ required - scheduler_conf: - warmup_steps: 25000 - lr_decay: 1.0 log_interval: 100 checkpoint: kbest_n: 50 latest_n: 5 +optim: adam +optim_conf: + global_grad_clip: 5.0 + weight_decay: 1.0e-06 +scheduler: warmuplr # pytorch v1.1.0+ required +scheduler_conf: + lr: 0.004 + warmup_steps: 25000 + lr_decay: 1.0 decoding: batch_size: 64 diff --git a/examples/librispeech/s2/local/train.sh b/examples/librispeech/s2/local/train.sh index f3eb98daf..c8bb9aafc 100755 --- a/examples/librispeech/s2/local/train.sh +++ b/examples/librispeech/s2/local/train.sh @@ -20,6 +20,7 @@ echo "using ${device}..." 
mkdir -p exp python3 -u ${BIN_DIR}/train.py \ +--model-name u2_kaldi \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ diff --git a/examples/librispeech/s2/path.sh b/examples/librispeech/s2/path.sh index 457f7e548..c90e27821 100644 --- a/examples/librispeech/s2/path.sh +++ b/examples/librispeech/s2/path.sh @@ -10,5 +10,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ -MODEL=u2 +MODEL=u2_kaldi export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin diff --git a/speechnn/core/transformers/README.md b/speechnn/core/transformers/README.md index 879a88db7..edbcb9cc3 100644 --- a/speechnn/core/transformers/README.md +++ b/speechnn/core/transformers/README.md @@ -7,4 +7,3 @@ * https://github.com/NVIDIA/FasterTransformer.git * https://github.com/idiap/fast-transformers - From 0ab299a8423463c17c10aca204c03b974708b100 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 19 Aug 2021 02:54:09 +0000 Subject: [PATCH 02/10] test bin --- deepspeech/exps/u2_kaldi/bin/test.py | 23 +++++++++++++++-------- deepspeech/exps/u2_kaldi/bin/train.py | 4 ++-- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/deepspeech/exps/u2_kaldi/bin/test.py b/deepspeech/exps/u2_kaldi/bin/test.py index 48244a545..065048274 100644 --- a/deepspeech/exps/u2_kaldi/bin/test.py +++ b/deepspeech/exps/u2_kaldi/bin/test.py @@ -14,16 +14,19 @@ """Evaluation for U2 model.""" import cProfile -from deepspeech.exps.u2.model import get_cfg_defaults -from deepspeech.exps.u2.model import U2Tester as Tester from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.dynamic_import import dynamic_import from deepspeech.utils.utility import print_arguments -# TODO(hui zhang): dynamic load +model_alias = { + "u2": "deepspeech.exps.u2.model:U2Tester", + "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester", +} def main_sp(config, args): - exp = Tester(config, args) + class_obj = dynamic_import(args.model_name, model_alias) + exp = class_obj(config, args) exp.setup() exp.run_test() @@ -34,13 +37,17 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + parser.add_argument( + '--model-name', + type=str, + default='u2_kaldi', + help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st') args = parser.parse_args() print_arguments(args, globals()) - # https://yaml.org/type/float.html - config = get_cfg_defaults() - if args.config: - config.merge_from_file(args.config) + config = CfgNode() + config.set_new_allowed(True) + config.merge_from_file(args.config) if args.opts: config.merge_from_list(args.opts) config.freeze() diff --git a/deepspeech/exps/u2_kaldi/bin/train.py b/deepspeech/exps/u2_kaldi/bin/train.py index 45ad3dbac..3a240b80f 100644 --- a/deepspeech/exps/u2_kaldi/bin/train.py +++ b/deepspeech/exps/u2_kaldi/bin/train.py @@ -29,8 +29,8 @@ model_alias = { def main_sp(config, args): - trainer_cls = dynamic_import(args.model_name, model_alias) - exp = trainer_cls(config, args) + class_obj = dynamic_import(args.model_name, model_alias) + exp = class_obj(config, args) exp.setup() exp.run() From 9dace625818a4ca282d82dcdbf2e7a18f6a8779d Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 19 Aug 2021 03:12:35 +0000 Subject: [PATCH 03/10] fix augmentation --- deepspeech/exps/u2_kaldi/bin/alignment.py | 54 ------------------- deepspeech/exps/u2_kaldi/bin/export.py | 48 ----------------- deepspeech/exps/u2_kaldi/bin/test.py | 17 ++++-- deepspeech/exps/u2_kaldi/bin/train.py | 4 +- 
deepspeech/frontend/augmentor/augmentation.py | 14 ++--- examples/librispeech/s2/local/align.sh | 3 +- examples/librispeech/s2/local/export.sh | 3 +- examples/librispeech/s2/local/test.sh | 2 + 8 files changed, 29 insertions(+), 116 deletions(-) delete mode 100644 deepspeech/exps/u2_kaldi/bin/alignment.py delete mode 100644 deepspeech/exps/u2_kaldi/bin/export.py diff --git a/deepspeech/exps/u2_kaldi/bin/alignment.py b/deepspeech/exps/u2_kaldi/bin/alignment.py deleted file mode 100644 index 3bc70f5ad..000000000 --- a/deepspeech/exps/u2_kaldi/bin/alignment.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Alignment for U2 model.""" -from deepspeech.exps.u2.model import get_cfg_defaults -from deepspeech.exps.u2.model import U2Tester as Tester -from deepspeech.training.cli import default_argument_parser -from deepspeech.utils.dynamic_import import dynamic_import -from deepspeech.utils.utility import print_arguments - - -def main_sp(config, args): - exp = Tester(config, args) - exp.setup() - exp.run_align() - - -def main(config, args): - main_sp(config, args) - - -if __name__ == "__main__": - parser = default_argument_parser() - parser.add_arguments( - '--model-name', - type=str, - default='u2', - help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st') - args = parser.parse_args() - print_arguments(args, globals()) - - # https://yaml.org/type/float.html - config = get_cfg_defaults() - if args.config: - config.merge_from_file(args.config) - if args.opts: - config.merge_from_list(args.opts) - config.freeze() - print(config) - if args.dump_config: - with open(args.dump_config, 'w') as f: - print(config, file=f) - - main(config, args) diff --git a/deepspeech/exps/u2_kaldi/bin/export.py b/deepspeech/exps/u2_kaldi/bin/export.py deleted file mode 100644 index 91967627f..000000000 --- a/deepspeech/exps/u2_kaldi/bin/export.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Export for U2 model.""" -from deepspeech.exps.u2.model import get_cfg_defaults -from deepspeech.exps.u2.model import U2Tester as Tester -from deepspeech.training.cli import default_argument_parser -from deepspeech.utils.utility import print_arguments - - -def main_sp(config, args): - exp = Tester(config, args) - exp.setup() - exp.run_export() - - -def main(config, args): - main_sp(config, args) - - -if __name__ == "__main__": - parser = default_argument_parser() - args = parser.parse_args() - print_arguments(args, globals()) - - # https://yaml.org/type/float.html - config = get_cfg_defaults() - if args.config: - config.merge_from_file(args.config) - if args.opts: - config.merge_from_list(args.opts) - config.freeze() - print(config) - if args.dump_config: - with open(args.dump_config, 'w') as f: - print(config, file=f) - - main(config, args) diff --git a/deepspeech/exps/u2_kaldi/bin/test.py b/deepspeech/exps/u2_kaldi/bin/test.py index 065048274..457672c03 100644 --- a/deepspeech/exps/u2_kaldi/bin/test.py +++ b/deepspeech/exps/u2_kaldi/bin/test.py @@ -18,17 +18,23 @@ from deepspeech.training.cli import default_argument_parser from deepspeech.utils.dynamic_import import dynamic_import from deepspeech.utils.utility import print_arguments -model_alias = { +model_test_alias = { "u2": "deepspeech.exps.u2.model:U2Tester", "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester", } def main_sp(config, args): - class_obj = dynamic_import(args.model_name, model_alias) + class_obj = dynamic_import(args.model_name, model_test_alias) exp = class_obj(config, args) exp.setup() - exp.run_test() + + if args.run_mode == 'test': + exp.run_test() + elif args.run_mode == 'export': + exp.run_export() + elif args.run_mode == 'align': + exp.run_align() def main(config, args): @@ -42,6 +48,11 @@ if __name__ == "__main__": type=str, default='u2_kaldi', help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st') + parser.add_argument( + '--run-mode', + type=str, + default='test', + help='run mode, e.g. test, align, export') args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/exps/u2_kaldi/bin/train.py b/deepspeech/exps/u2_kaldi/bin/train.py index 3a240b80f..1dcd154d3 100644 --- a/deepspeech/exps/u2_kaldi/bin/train.py +++ b/deepspeech/exps/u2_kaldi/bin/train.py @@ -22,14 +22,14 @@ from deepspeech.training.cli import default_argument_parser from deepspeech.utils.dynamic_import import dynamic_import from deepspeech.utils.utility import print_arguments -model_alias = { +model_train_alias = { "u2": "deepspeech.exps.u2.model:U2Trainer", "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Trainer", } def main_sp(config, args): - class_obj = dynamic_import(args.model_name, model_alias) + class_obj = dynamic_import(args.model_name, model_train_alias) exp = class_obj(config, args) exp.setup() exp.run() diff --git a/deepspeech/frontend/augmentor/augmentation.py b/deepspeech/frontend/augmentor/augmentation.py index 7b43988e4..c479958fd 100644 --- a/deepspeech/frontend/augmentor/augmentation.py +++ b/deepspeech/frontend/augmentor/augmentation.py @@ -97,14 +97,14 @@ class AugmentationPipeline(): ValueError: If the augmentation json config is in incorrect format". 
""" + SPEC_TYPES = ('specaug') + def __init__(self, augmentation_config: str, random_seed: int=0): self._rng = np.random.RandomState(random_seed) - self._spec_types = ('specaug') - - if augmentation_config is None: - self.conf = {} - else: - self.conf = json.loads(augmentation_config) + self.conf = {'mode': 'sequential', 'process': []} + if augmentation_config: + process = json.loads(augmentation_config) + self.conf['process'] += process self._augmentors, self._rates = self._parse_pipeline_from('all') self._audio_augmentors, self._audio_rates = self._parse_pipeline_from( @@ -188,7 +188,7 @@ class AugmentationPipeline(): all_confs = [] for config in self.conf: all_confs.append(config) - if config["type"] in self._spec_types: + if config["type"] in self.SPEC_TYPES: feature_confs.append(config) else: audio_confs.append(config) diff --git a/examples/librispeech/s2/local/align.sh b/examples/librispeech/s2/local/align.sh index ad6c84bc8..94146ccff 100755 --- a/examples/librispeech/s2/local/align.sh +++ b/examples/librispeech/s2/local/align.sh @@ -21,7 +21,8 @@ mkdir -p ${output_dir} # align dump in `result_file` # .tier, .TextGrid dump in `dir of result_file` -python3 -u ${BIN_DIR}/alignment.py \ +python3 -u ${BIN_DIR}/test.py \ +--run_mode 'align' \ --device ${device} \ --nproc 1 \ --config ${config_path} \ diff --git a/examples/librispeech/s2/local/export.sh b/examples/librispeech/s2/local/export.sh index f99a15bad..7e42e0114 100755 --- a/examples/librispeech/s2/local/export.sh +++ b/examples/librispeech/s2/local/export.sh @@ -17,7 +17,8 @@ if [ ${ngpu} == 0 ];then device=cpu fi -python3 -u ${BIN_DIR}/export.py \ +python3 -u ${BIN_DIR}/test.py \ +--run_mode 'export' \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ diff --git a/examples/librispeech/s2/local/test.sh b/examples/librispeech/s2/local/test.sh index 3bd3f0bba..762211c23 100755 --- a/examples/librispeech/s2/local/test.sh +++ b/examples/librispeech/s2/local/test.sh @@ -38,6 +38,7 @@ for type in attention ctc_greedy_search; do batch_size=64 fi python3 -u ${BIN_DIR}/test.py \ + --run_mode test \ --device ${device} \ --nproc 1 \ --config ${config_path} \ @@ -55,6 +56,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do echo "decoding ${type}" batch_size=1 python3 -u ${BIN_DIR}/test.py \ + --run_mode test \ --device ${device} \ --nproc 1 \ --config ${config_path} \ From 9de034380730bdc027ca7ec349e7daf93292d850 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 19 Aug 2021 03:16:07 +0000 Subject: [PATCH 04/10] fix augment --- deepspeech/frontend/augmentor/augmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeech/frontend/augmentor/augmentation.py b/deepspeech/frontend/augmentor/augmentation.py index c479958fd..50eeea991 100644 --- a/deepspeech/frontend/augmentor/augmentation.py +++ b/deepspeech/frontend/augmentor/augmentation.py @@ -97,7 +97,7 @@ class AugmentationPipeline(): ValueError: If the augmentation json config is in incorrect format". 
""" - SPEC_TYPES = ('specaug') + SPEC_TYPES = {'specaug'} def __init__(self, augmentation_config: str, random_seed: int=0): self._rng = np.random.RandomState(random_seed) From 4725bace4e4b76f0d8c2c995c317ad631ce46966 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 19 Aug 2021 03:19:04 +0000 Subject: [PATCH 05/10] fix --- deepspeech/frontend/augmentor/augmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeech/frontend/augmentor/augmentation.py b/deepspeech/frontend/augmentor/augmentation.py index 50eeea991..17abcf605 100644 --- a/deepspeech/frontend/augmentor/augmentation.py +++ b/deepspeech/frontend/augmentor/augmentation.py @@ -186,7 +186,7 @@ class AugmentationPipeline(): audio_confs = [] feature_confs = [] all_confs = [] - for config in self.conf: + for config in self.conf['process']: all_confs.append(config) if config["type"] in self.SPEC_TYPES: feature_confs.append(config) From c09b0e894019d7de78bbc0bece1b90b44b7aff28 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 19 Aug 2021 03:35:07 +0000 Subject: [PATCH 06/10] fix specaug --- README.md | 5 +++-- README_cn.md | 5 +++-- deepspeech/frontend/augmentor/base.py | 6 +++--- deepspeech/frontend/augmentor/spec_augment.py | 7 +++++-- examples/librispeech/s2/conf/augmentation.json | 17 ----------------- 5 files changed, 14 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index f7d1e0882..d10fd5d59 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ [中文版](README_cn.md) -# PaddlePaddle ASR toolkit +# PaddlePaddle Speech to Any toolkit ![License](https://img.shields.io/badge/license-Apache%202-red.svg) ![python version](https://img.shields.io/badge/python-3.7+-orange.svg) ![support os](https://img.shields.io/badge/os-linux-yellow.svg) -*PaddleASR* is an open-source implementation of end-to-end Automatic Speech Recognition (ASR) engine, with [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform. Our vision is to empower both industrial application and academic research on speech recognition, via an easy-to-use, efficient, samller and scalable implementation, including training, inference & testing module, and deployment. +*DeepSpeech* is an open-source implementation of end-to-end Automatic Speech Recognition engine, with [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform. Our vision is to empower both industrial application and academic research on speech recognition, via an easy-to-use, efficient, samller and scalable implementation, including training, inference & testing module, and deployment. 
 ## Features
 
@@ -15,6 +15,7 @@
 
 ## Setup
 
+* Ubuntu 16.04
 * python>=3.7
 * paddlepaddle>=2.1.2
 
diff --git a/README_cn.md b/README_cn.md
index 019b38c15..90a65c440 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -1,12 +1,12 @@
 [English](README.md)
 
-# PaddlePaddle ASR toolkit
+# PaddlePaddle Speech to Any toolkit
 
 ![License](https://img.shields.io/badge/license-Apache%202-red.svg)
 ![python version](https://img.shields.io/badge/python-3.7+-orange.svg)
 ![support os](https://img.shields.io/badge/os-linux-yellow.svg)
 
-*PaddleASR*是一个采用[PaddlePaddle](https://github.com/PaddlePaddle/Paddle)平台的端到端自动语音识别(ASR)引擎的开源项目,
+*DeepSpeech*是一个采用[PaddlePaddle](https://github.com/PaddlePaddle/Paddle)平台的端到端自动语音识别引擎的开源项目,
 我们的愿景是为语音识别在工业应用和学术研究上,提供易于使用、高效、小型化和可扩展的工具,包括训练,推理,以及
 部署。
 
 ## 特性
 
@@ -16,6 +16,7 @@
 
 ## 安装
 
+* Ubuntu 16.04
 * python>=3.7
 * paddlepaddle>=2.1.2
 
diff --git a/deepspeech/frontend/augmentor/base.py b/deepspeech/frontend/augmentor/base.py
index 87cb4ef72..18d003c0b 100644
--- a/deepspeech/frontend/augmentor/base.py
+++ b/deepspeech/frontend/augmentor/base.py
@@ -30,7 +30,7 @@ class AugmentorBase():
 
     @abstractmethod
     def __call__(self, xs):
-        raise NotImplementedError
+        raise NotImplementedError("AugmentorBase: __call__ not implemented")
 
     @abstractmethod
     def transform_audio(self, audio_segment):
@@ -44,7 +44,7 @@ class AugmentorBase():
         :param audio_segment: Audio segment to add effects to.
         :type audio_segment: AudioSegment|SpeechSegment
         """
-        raise NotImplementedError
+        raise NotImplementedError("AugmentorBase: transform_audio not implemented")
 
     @abstractmethod
     def transform_feature(self, spec_segment):
@@ -56,4 +56,4 @@ class AugmentorBase():
         Args:
             spec_segment (Spectrogram): Spectrogram segment to add effects to.
         """
-        raise NotImplementedError
+        raise NotImplementedError("AugmentorBase: transform_feature not implemented")
diff --git a/deepspeech/frontend/augmentor/spec_augment.py b/deepspeech/frontend/augmentor/spec_augment.py
index 94d23bf46..1786099c8 100644
--- a/deepspeech/frontend/augmentor/spec_augment.py
+++ b/deepspeech/frontend/augmentor/spec_augment.py
@@ -64,7 +64,7 @@ class SpecAugmentor(AugmentorBase):
         self.n_freq_masks = n_freq_masks
         self.n_time_masks = n_time_masks
         self.p = p
-        #logger.info(f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}")
+
 
         # adaptive SpecAugment
         self.adaptive_number_ratio = adaptive_number_ratio
@@ -120,6 +120,9 @@ class SpecAugmentor(AugmentorBase):
     @property
     def time_mask(self):
         return self._time_mask
+
+    def __repr__(self):
+        return f"specaug: F-{self.F}, T-{self.T}, F-n-{self.n_freq_masks}, T-n-{self.n_time_masks}"
 
     def time_warp(xs, W=40):
         raise NotImplementedError
@@ -160,7 +163,7 @@ class SpecAugmentor(AugmentorBase):
     def __call__(self, x, train=True):
         if not train:
             return
-        self.transform_audio(x)
+        self.transform_feature(x)
 
     def transform_feature(self, xs: np.ndarray):
         """
diff --git a/examples/librispeech/s2/conf/augmentation.json b/examples/librispeech/s2/conf/augmentation.json
index c1078393d..49fe333ec 100644
--- a/examples/librispeech/s2/conf/augmentation.json
+++ b/examples/librispeech/s2/conf/augmentation.json
@@ -1,21 +1,4 @@
 [
-    {
-        "type": "shift",
-        "params": {
-            "min_shift_ms": -5,
-            "max_shift_ms": 5
-        },
-        "prob": 1.0
-    },
-    {
-        "type": "speed",
-        "params": {
-            "min_speed_rate": 0.9,
-            "max_speed_rate": 1.1,
-            "num_rates": 3
-        },
-        "prob": 0.0
-    },
     {
         "type": "specaug",
         "params": {

From c81743403ad2b3bb62c42c008221914cf2de0897 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Thu, 19 Aug 2021 03:44:54 +0000
Subject: [PATCH 07/10] fix

---
 README.md    | 1 +
 README_cn.md | 2 ++
 deepspeech/exps/u2_kaldi/model.py | 2 +-
 deepspeech/io/converter.py        | 2 ++
 4 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index d10fd5d59..de24abe2f 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,7 @@
 
 ## Setup
 
+All tested under:
 * Ubuntu 16.04
 * python>=3.7
 * paddlepaddle>=2.1.2
diff --git a/README_cn.md b/README_cn.md
index 90a65c440..29aadbdf6 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -16,6 +16,8 @@
 
 ## 安装
 
+在以下环境测试验证过:
+
 * Ubuntu 16.04
 * python>=3.7
 * paddlepaddle>=2.1.2
diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py
index 60f070a3b..a2f062a18 100644
--- a/deepspeech/exps/u2_kaldi/model.py
+++ b/deepspeech/exps/u2_kaldi/model.py
@@ -233,7 +233,7 @@ class U2Trainer(Trainer):
             batch_frames_inout=0,
             preprocess_conf=config.collator.augmentation_config,
             n_iter_processes=config.collator.num_workers,
-            subsampling_factor=1,
+            subsampling_factor=0,
             num_encs=1)
 
         self.valid_loader = BatchDataLoader(
diff --git a/deepspeech/io/converter.py b/deepspeech/io/converter.py
index a02e06acb..e591a7935 100644
--- a/deepspeech/io/converter.py
+++ b/deepspeech/io/converter.py
@@ -55,6 +55,8 @@ class CustomConverter():
         xs = [x[::self.subsampling_factor, :] for x in xs]
 
         # get batch of lengths of input sequences
+        print(xs)
+        print(ys)
         ilens = np.array([x.shape[0] for x in xs])
 
         # perform padding and convert to tensor

From a3e86dd8b57d4295394e626512432410943d3c3a Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Thu, 19 Aug 2021 03:51:55 +0000
Subject: [PATCH 08/10] fix call

---
 deepspeech/frontend/augmentor/impulse_response.py              | 1 +
 deepspeech/frontend/augmentor/noise_perturb.py                 | 1 +
 deepspeech/frontend/augmentor/online_bayesian_normalization.py | 1 +
 deepspeech/frontend/augmentor/resample.py                      | 1 +
 deepspeech/frontend/augmentor/shift_perturb.py                 | 1 +
 deepspeech/frontend/augmentor/spec_augment.py                  | 2 +-
 deepspeech/frontend/augmentor/speed_perturb.py                 | 1 +
 deepspeech/frontend/augmentor/volume_perturb.py                | 1 +
 8 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/deepspeech/frontend/augmentor/impulse_response.py b/deepspeech/frontend/augmentor/impulse_response.py
index 01421fc65..b1a732ad8 100644
--- a/deepspeech/frontend/augmentor/impulse_response.py
+++ b/deepspeech/frontend/augmentor/impulse_response.py
@@ -34,6 +34,7 @@ class ImpulseResponseAugmentor(AugmentorBase):
         if not train:
             return
         self.transform_audio(x)
+        return x
 
     def transform_audio(self, audio_segment):
         """Add impulse response effect.
diff --git a/deepspeech/frontend/augmentor/noise_perturb.py b/deepspeech/frontend/augmentor/noise_perturb.py
index 11f5ed105..8be5931bc 100644
--- a/deepspeech/frontend/augmentor/noise_perturb.py
+++ b/deepspeech/frontend/augmentor/noise_perturb.py
@@ -40,6 +40,7 @@ class NoisePerturbAugmentor(AugmentorBase):
         if not train:
             return
         self.transform_audio(x)
+        return x
 
     def transform_audio(self, audio_segment):
         """Add background noise audio.
diff --git a/deepspeech/frontend/augmentor/online_bayesian_normalization.py b/deepspeech/frontend/augmentor/online_bayesian_normalization.py
index dc32a1808..4b5e2301e 100644
--- a/deepspeech/frontend/augmentor/online_bayesian_normalization.py
+++ b/deepspeech/frontend/augmentor/online_bayesian_normalization.py
@@ -48,6 +48,7 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase):
         if not train:
             return
         self.transform_audio(x)
+        return x
 
     def transform_audio(self, audio_segment):
         """Normalizes the input audio using the online Bayesian approach.
diff --git a/deepspeech/frontend/augmentor/resample.py b/deepspeech/frontend/augmentor/resample.py
index a862b184e..a8c0c6628 100644
--- a/deepspeech/frontend/augmentor/resample.py
+++ b/deepspeech/frontend/augmentor/resample.py
@@ -35,6 +35,7 @@ class ResampleAugmentor(AugmentorBase):
         if not train:
             return
         self.transform_audio(x)
+        return x
 
     def transform_audio(self, audio_segment):
         """Resamples the input audio to a target sample rate.
diff --git a/deepspeech/frontend/augmentor/shift_perturb.py b/deepspeech/frontend/augmentor/shift_perturb.py
index 6c78c528e..a76fb51c6 100644
--- a/deepspeech/frontend/augmentor/shift_perturb.py
+++ b/deepspeech/frontend/augmentor/shift_perturb.py
@@ -35,6 +35,7 @@ class ShiftPerturbAugmentor(AugmentorBase):
         if not train:
             return
         self.transform_audio(x)
+        return x
 
     def transform_audio(self, audio_segment):
         """Shift audio.
diff --git a/deepspeech/frontend/augmentor/spec_augment.py b/deepspeech/frontend/augmentor/spec_augment.py
index 1786099c8..ed593da4a 100644
--- a/deepspeech/frontend/augmentor/spec_augment.py
+++ b/deepspeech/frontend/augmentor/spec_augment.py
@@ -163,7 +163,7 @@ class SpecAugmentor(AugmentorBase):
     def __call__(self, x, train=True):
         if not train:
             return
-        self.transform_feature(x)
+        return self.transform_feature(x)
 
     def transform_feature(self, xs: np.ndarray):
         """
diff --git a/deepspeech/frontend/augmentor/speed_perturb.py b/deepspeech/frontend/augmentor/speed_perturb.py
index 838c5cc29..eec2e5511 100644
--- a/deepspeech/frontend/augmentor/speed_perturb.py
+++ b/deepspeech/frontend/augmentor/speed_perturb.py
@@ -83,6 +83,7 @@ class SpeedPerturbAugmentor(AugmentorBase):
         if not train:
             return
         self.transform_audio(x)
+        return x
 
     def transform_audio(self, audio_segment):
         """Sample a new speed rate from the given range and
diff --git a/deepspeech/frontend/augmentor/volume_perturb.py b/deepspeech/frontend/augmentor/volume_perturb.py
index ffae1693e..d08f75c36 100644
--- a/deepspeech/frontend/augmentor/volume_perturb.py
+++ b/deepspeech/frontend/augmentor/volume_perturb.py
@@ -41,6 +41,7 @@ class VolumePerturbAugmentor(AugmentorBase):
         if not train:
             return
         self.transform_audio(x)
+        return x
 
     def transform_audio(self, audio_segment):
         """Change audio loudness.

From c484d537c25296c9c5f0e7b0fd969bd23af50f33 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Thu, 19 Aug 2021 04:01:27 +0000
Subject: [PATCH 09/10] add assert

---
 deepspeech/io/converter.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/deepspeech/io/converter.py b/deepspeech/io/converter.py
index e591a7935..3bfcc1b1e 100644
--- a/deepspeech/io/converter.py
+++ b/deepspeech/io/converter.py
@@ -49,14 +49,13 @@ class CustomConverter():
         # batch should be located in list
         assert len(batch) == 1
         (xs, ys), utts = batch[0]
+        assert xs[0] is not None, "please check Reader and Augmentation impl."
         # perform subsampling
         if self.subsampling_factor > 1:
             xs = [x[::self.subsampling_factor, :] for x in xs]
 
         # get batch of lengths of input sequences
-        print(xs)
-        print(ys)
         ilens = np.array([x.shape[0] for x in xs])
 
         # perform padding and convert to tensor

From d64cdc7838468771e4037ca2cfaa6f1089a1ee80 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Thu, 19 Aug 2021 04:40:06 +0000
Subject: [PATCH 10/10] fix

---
 README_cn.md                                  | 2 +-
 deepspeech/exps/u2_kaldi/model.py             | 2 +-
 deepspeech/frontend/augmentor/spec_augment.py | 3 +--
 3 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/README_cn.md b/README_cn.md
index 29aadbdf6..4b9273625 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -16,7 +16,7 @@
 
 ## 安装
 
-在以下环境测试验证过:
+在以下环境测试验证过:
 
 * Ubuntu 16.04
 * python>=3.7
diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py
index a2f062a18..60f070a3b 100644
--- a/deepspeech/exps/u2_kaldi/model.py
+++ b/deepspeech/exps/u2_kaldi/model.py
@@ -233,7 +233,7 @@ class U2Trainer(Trainer):
             batch_frames_inout=0,
             preprocess_conf=config.collator.augmentation_config,
             n_iter_processes=config.collator.num_workers,
-            subsampling_factor=0,
+            subsampling_factor=1,
             num_encs=1)
 
         self.valid_loader = BatchDataLoader(
diff --git a/deepspeech/frontend/augmentor/spec_augment.py b/deepspeech/frontend/augmentor/spec_augment.py
index ed593da4a..bfa8300af 100644
--- a/deepspeech/frontend/augmentor/spec_augment.py
+++ b/deepspeech/frontend/augmentor/spec_augment.py
@@ -64,7 +64,6 @@ class SpecAugmentor(AugmentorBase):
         self.n_freq_masks = n_freq_masks
         self.n_time_masks = n_time_masks
         self.p = p
-
         # adaptive SpecAugment
         self.adaptive_number_ratio = adaptive_number_ratio
@@ -120,7 +119,7 @@ class SpecAugmentor(AugmentorBase):
     @property
     def time_mask(self):
         return self._time_mask
-    
+
     def __repr__(self):
         return f"specaug: F-{self.F}, T-{self.T}, F-n-{self.n_freq_masks}, T-n-{self.n_time_masks}"
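
Background notes on this series

Patches 04 and 05 fix a Python gotcha that is easy to miss in review: ('specaug')
is not a one-element tuple but the plain string 'specaug' (the trailing comma,
not the parentheses, is what makes a tuple), and the `in` operator on a string
performs substring matching. A minimal sketch of the difference (the variable
names here are illustrative, not from the patch):

    # why patch 04 replaces ('specaug') with {'specaug'}
    SPEC_TYPES_STR = ('specaug')    # parentheses only: still the str 'specaug'
    SPEC_TYPES_TUP = ('specaug', )  # trailing comma: a real 1-element tuple
    SPEC_TYPES_SET = {'specaug'}    # set literal, as patch 04 uses

    assert isinstance(SPEC_TYPES_STR, str)
    assert 'aug' in SPEC_TYPES_STR       # substring match: silently True
    assert 'aug' not in SPEC_TYPES_TUP   # element match: False, as intended
    assert 'aug' not in SPEC_TYPES_SET   # element match: False, as intended
    assert 'specaug' in SPEC_TYPES_SET

With the string form, any augmentor whose "type" value happens to be a substring
of 'specaug' would be silently routed to the feature-augmentor branch; the set
gives exact membership tests and O(1) lookup.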
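The __init__ change at the top of this series normalizes the user-supplied JSON
list into a conf dict of the form {'mode': 'sequential', 'process': [...]}, and
patch 05 updates the iteration to match: looping over the new dict itself would
yield its keys ('mode', 'process') rather than the augmentor configs. A
self-contained sketch of the intended flow, with helper names invented for
illustration:

    import json

    SPEC_TYPES = {'specaug'}

    def load_conf(augmentation_config: str) -> dict:
        # build the sequential-pipeline shape, then splice in
        # the user's JSON list if one was given
        conf = {'mode': 'sequential', 'process': []}
        if augmentation_config:
            conf['process'] += json.loads(augmentation_config)
        return conf

    def split_confs(conf: dict):
        # route each config to the audio or feature pipeline by its "type"
        audio_confs, feature_confs = [], []
        for config in conf['process']:
            if config["type"] in SPEC_TYPES:
                feature_confs.append(config)
            else:
                audio_confs.append(config)
        return audio_confs, feature_confs

    conf = load_conf('[{"type": "specaug", "params": {}, "prob": 1.0}]')
    audio_confs, feature_confs = split_confs(conf)
    assert audio_confs == [] and len(feature_confs) == 1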
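Patch 08's `return x` additions and patch 09's assert work as a pair: before
patch 08, every augmentor's __call__ returned None, so a pipeline that chained
augmentor outputs would hand None on to the converter, which is exactly the
failure the new `assert xs[0] is not None` surfaces early. A sketch of the
contract the fixes establish (ToyAugmentor and apply_pipeline are illustrative
names, not from the patch):

    import numpy as np

    class ToyAugmentor:
        """Follows the post-patch-08 contract: __call__ returns a value."""

        def __call__(self, x, train=True):
            if not train:
                return x  # pass through unchanged outside training
            return self.transform_feature(x)

        def transform_feature(self, xs: np.ndarray) -> np.ndarray:
            return xs + 1.0  # stand-in for a real spectrogram effect

    def apply_pipeline(augmentors, x, train=True):
        # chaining only works because each stage hands its result onward
        for aug in augmentors:
            x = aug(x, train=train)
        return x

    out = apply_pipeline([ToyAugmentor(), ToyAugmentor()], np.zeros((5, 3)))
    assert np.allclose(out, 2.0)

One wrinkle the series leaves open: the patched augmentors still fall through
to a bare `return` (that is, None) when train=False, so the sketch's `return x`
in that branch is a deliberate deviation, not the patched behavior.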
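Patch 07 flips subsampling_factor from 1 to 0 and patch 10 reverts it. In the
converter the factor is used as a slice stride behind a `> 1` guard, so 1 is the
natural identity value; 0 is never a meaningful stride, and a zero step would
fail loudly anywhere the guard is absent (other consumers of the loader option
presumably also treat the factor as a stride or divisor, which is likely why the
revert mattered). A small sketch of the guard, with an illustrative array shape:

    import numpy as np

    x = np.zeros((10, 83))  # (frames, feature_dim)
    subsampling_factor = 1

    # mirror of the converter.py logic: a factor of 1 keeps every frame
    if subsampling_factor > 1:
        x = x[::subsampling_factor, :]
    assert x.shape[0] == 10

    # an unguarded zero stride raises immediately
    try:
        _ = x[::0, :]
    except ValueError as err:
        assert 'slice step cannot be zero' in str(err)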