diff --git a/README.md b/README.md
index 0eeab7706..931e6331c 100644
--- a/README.md
+++ b/README.md
@@ -43,7 +43,7 @@ You are welcome to submit questions in [Github Discussions](https://github.com/P
 
 ## License
 
-DeepASR is provided under the [Apache-2.0 License](./LICENSE).
+DeepSpeech is provided under the [Apache-2.0 License](./LICENSE).
 
 ## Acknowledgement
 
diff --git a/README_cn.md b/README_cn.md
index 3ff668956..cc993f8bf 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -42,7 +42,7 @@
 
 ## License
 
-DeepASR 遵循[Apache-2.0开源协议](./LICENSE)。
+DeepSpeech 遵循[Apache-2.0开源协议](./LICENSE)。
 
 ## 感谢
 
diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index d85a3dde7..5505ecbf0 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -80,23 +80,23 @@ def convert_dtype_to_string(tensor_dtype):
 
 
 if not hasattr(paddle, 'softmax'):
-    logger.warn("register user softmax to paddle, remove this when fixed!")
+    logger.debug("register user softmax to paddle, remove this when fixed!")
     setattr(paddle, 'softmax', paddle.nn.functional.softmax)
 
 if not hasattr(paddle, 'log_softmax'):
-    logger.warn("register user log_softmax to paddle, remove this when fixed!")
+    logger.debug("register user log_softmax to paddle, remove this when fixed!")
     setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax)
 
 if not hasattr(paddle, 'sigmoid'):
-    logger.warn("register user sigmoid to paddle, remove this when fixed!")
+    logger.debug("register user sigmoid to paddle, remove this when fixed!")
     setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid)
 
 if not hasattr(paddle, 'log_sigmoid'):
-    logger.warn("register user log_sigmoid to paddle, remove this when fixed!")
+    logger.debug("register user log_sigmoid to paddle, remove this when fixed!")
     setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid)
 
 if not hasattr(paddle, 'relu'):
-    logger.warn("register user relu to paddle, remove this when fixed!")
+    logger.debug("register user relu to paddle, remove this when fixed!")
     setattr(paddle, 'relu', paddle.nn.functional.relu)
 
 
@@ -105,7 +105,7 @@ def cat(xs, dim=0):
 
 
 if not hasattr(paddle, 'cat'):
-    logger.warn(
+    logger.debug(
         "override cat of paddle if exists or register, remove this when fixed!")
     paddle.cat = cat
 
@@ -116,7 +116,7 @@ def item(x: paddle.Tensor):
 
 
 if not hasattr(paddle.Tensor, 'item'):
-    logger.warn(
+    logger.debug(
         "override item of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.item = item
 
@@ -127,13 +127,13 @@ def func_long(x: paddle.Tensor):
 
 
 if not hasattr(paddle.Tensor, 'long'):
-    logger.warn(
+    logger.debug(
         "override long of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.long = func_long
 
 if not hasattr(paddle.Tensor, 'numel'):
-    logger.warn(
+    logger.debug(
         "override numel of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.numel = paddle.numel
 
@@ -147,7 +147,7 @@ def new_full(x: paddle.Tensor,
 
 
 if not hasattr(paddle.Tensor, 'new_full'):
-    logger.warn(
+    logger.debug(
         "override new_full of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.new_full = new_full
 
@@ -162,13 +162,13 @@ def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
 
 
 if not hasattr(paddle.Tensor, 'eq'):
-    logger.warn(
+    logger.debug(
         "override eq of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.eq = eq
 
 if not hasattr(paddle, 'eq'):
-    logger.warn(
+    logger.debug(
         "override eq of paddle if exists or register, remove this when fixed!")
     paddle.eq = eq
 
@@ -178,7 +178,7 @@ def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
 
 
 if not hasattr(paddle.Tensor, 'contiguous'):
-    logger.warn(
+    logger.debug(
         "override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.contiguous = contiguous
 
@@ -195,7 +195,7 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
 
 
 #`to_static` do not process `size` property, maybe some `paddle` api dependent on it.
-logger.warn(
+logger.debug(
     "override size of paddle.Tensor "
     "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!"
 )
@@ -207,7 +207,7 @@ def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
 
 
 if not hasattr(paddle.Tensor, 'view'):
-    logger.warn("register user view to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user view to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.view = view
 
@@ -216,7 +216,7 @@ def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
 
 
 if not hasattr(paddle.Tensor, 'view_as'):
-    logger.warn(
+    logger.debug(
         "register user view_as to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.view_as = view_as
 
@@ -242,7 +242,7 @@ def masked_fill(xs: paddle.Tensor,
 
 
 if not hasattr(paddle.Tensor, 'masked_fill'):
-    logger.warn(
+    logger.debug(
         "register user masked_fill to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.masked_fill = masked_fill
 
@@ -260,7 +260,7 @@ def masked_fill_(xs: paddle.Tensor,
 
 
 if not hasattr(paddle.Tensor, 'masked_fill_'):
-    logger.warn(
+    logger.debug(
         "register user masked_fill_ to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.masked_fill_ = masked_fill_
 
@@ -272,7 +272,8 @@ def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor:
 
 
 if not hasattr(paddle.Tensor, 'fill_'):
-    logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!")
+    logger.debug(
+        "register user fill_ to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.fill_ = fill_
 
@@ -281,22 +282,22 @@ def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
 
 
 if not hasattr(paddle.Tensor, 'repeat'):
-    logger.warn(
+    logger.debug(
         "register user repeat to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.repeat = repeat
 
 if not hasattr(paddle.Tensor, 'softmax'):
-    logger.warn(
+    logger.debug(
         "register user softmax to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax)
 
 if not hasattr(paddle.Tensor, 'sigmoid'):
-    logger.warn(
+    logger.debug(
         "register user sigmoid to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid)
 
 if not hasattr(paddle.Tensor, 'relu'):
-    logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user relu to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu)
 
 
@@ -305,7 +306,7 @@ def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor:
 
 
 if not hasattr(paddle.Tensor, 'type_as'):
-    logger.warn(
+    logger.debug(
         "register user type_as to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'type_as', type_as)
 
@@ -321,7 +322,7 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
 
 
 if not hasattr(paddle.Tensor, 'to'):
-    logger.warn("register user to to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user to to paddle.Tensor, remove this when fixed!")
logger.debug("register user to to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'to', to) @@ -330,7 +331,8 @@ def func_float(x: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'float'): - logger.warn("register user float to paddle.Tensor, remove this when fixed!") + logger.debug( + "register user float to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'float', func_float) @@ -339,7 +341,7 @@ def func_int(x: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'int'): - logger.warn("register user int to paddle.Tensor, remove this when fixed!") + logger.debug("register user int to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'int', func_int) @@ -348,23 +350,6 @@ def tolist(x: paddle.Tensor) -> List[Any]: if not hasattr(paddle.Tensor, 'tolist'): - logger.warn( + logger.debug( "register user tolist to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'tolist', tolist) - - -########### hcak paddle.nn ############# -class GLU(nn.Layer): - """Gated Linear Units (GLU) Layer""" - - def __init__(self, dim: int=-1): - super().__init__() - self.dim = dim - - def forward(self, xs): - return F.glu(xs, axis=self.dim) - - -if not hasattr(paddle.nn, 'GLU'): - logger.warn("register user GLU to paddle.nn, remove this when fixed!") - setattr(paddle.nn, 'GLU', GLU) diff --git a/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp b/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp index 4dcc7c899..fcb1f7642 100644 --- a/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp +++ b/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp @@ -35,7 +35,8 @@ std::vector> ctc_beam_search_decoder( size_t beam_size, double cutoff_prob, size_t cutoff_top_n, - Scorer *ext_scorer) { + Scorer *ext_scorer, + size_t blank_id) { // dimension check size_t num_time_steps = probs_seq.size(); for (size_t i = 0; i < num_time_steps; ++i) { @@ -48,7 +49,7 @@ std::vector> ctc_beam_search_decoder( // assign blank id // size_t blank_id = vocabulary.size(); - size_t blank_id = 0; + // size_t blank_id = 0; // assign space id auto it = std::find(vocabulary.begin(), vocabulary.end(), " "); @@ -57,7 +58,6 @@ std::vector> ctc_beam_search_decoder( if ((size_t)space_id >= vocabulary.size()) { space_id = -2; } - // init prefixes' root PathTrie root; root.score = root.log_prob_b_prev = 0.0; @@ -218,7 +218,8 @@ ctc_beam_search_decoder_batch( size_t num_processes, double cutoff_prob, size_t cutoff_top_n, - Scorer *ext_scorer) { + Scorer *ext_scorer, + size_t blank_id) { VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); // thread pool ThreadPool pool(num_processes); @@ -234,7 +235,8 @@ ctc_beam_search_decoder_batch( beam_size, cutoff_prob, cutoff_top_n, - ext_scorer)); + ext_scorer, + blank_id)); } // get decoding results diff --git a/deepspeech/decoders/swig/ctc_beam_search_decoder.h b/deepspeech/decoders/swig/ctc_beam_search_decoder.h index c31510da3..eaba9da8c 100644 --- a/deepspeech/decoders/swig/ctc_beam_search_decoder.h +++ b/deepspeech/decoders/swig/ctc_beam_search_decoder.h @@ -43,7 +43,8 @@ std::vector> ctc_beam_search_decoder( size_t beam_size, double cutoff_prob = 1.0, size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr); + Scorer *ext_scorer = nullptr, + size_t blank_id = 0); /* CTC Beam Search Decoder for batch data @@ -70,6 +71,7 @@ ctc_beam_search_decoder_batch( size_t num_processes, double cutoff_prob = 1.0, size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr); + Scorer *ext_scorer = nullptr, + size_t blank_id = 
 
 #endif  // CTC_BEAM_SEARCH_DECODER_H_
diff --git a/deepspeech/decoders/swig/ctc_greedy_decoder.cpp b/deepspeech/decoders/swig/ctc_greedy_decoder.cpp
index 1c735c424..18008cced 100644
--- a/deepspeech/decoders/swig/ctc_greedy_decoder.cpp
+++ b/deepspeech/decoders/swig/ctc_greedy_decoder.cpp
@@ -17,17 +17,18 @@
 
 std::string ctc_greedy_decoder(
     const std::vector<std::vector<double>> &probs_seq,
-    const std::vector<std::string> &vocabulary) {
+    const std::vector<std::string> &vocabulary,
+    size_t blank_id) {
   // dimension check
   size_t num_time_steps = probs_seq.size();
   for (size_t i = 0; i < num_time_steps; ++i) {
     VALID_CHECK_EQ(probs_seq[i].size(),
-                   vocabulary.size() + 1,
+                   vocabulary.size(),
                    "The shape of probs_seq does not match with "
                    "the shape of the vocabulary");
   }
 
-  size_t blank_id = vocabulary.size();
+  // size_t blank_id = vocabulary.size();
 
   std::vector<size_t> max_idx_vec(num_time_steps, 0);
   std::vector<size_t> idx_vec;
diff --git a/deepspeech/decoders/swig/ctc_greedy_decoder.h b/deepspeech/decoders/swig/ctc_greedy_decoder.h
index 5e8c5c251..dd1b33315 100644
--- a/deepspeech/decoders/swig/ctc_greedy_decoder.h
+++ b/deepspeech/decoders/swig/ctc_greedy_decoder.h
@@ -29,6 +29,7 @@
  */
 std::string ctc_greedy_decoder(
     const std::vector<std::vector<double>>& probs_seq,
-    const std::vector<std::string>& vocabulary);
+    const std::vector<std::string>& vocabulary,
+    size_t blank_id);
 
 #endif  // CTC_GREEDY_DECODER_H
diff --git a/deepspeech/decoders/swig/setup.py b/deepspeech/decoders/swig/setup.py
index 8fb792962..c089f96cd 100644
--- a/deepspeech/decoders/swig/setup.py
+++ b/deepspeech/decoders/swig/setup.py
@@ -85,9 +85,8 @@ FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
 
 # yapf: disable
 FILES = [
-    fn for fn in FILES
-    if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith(
-        'unittest.cc'))
+    fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')
+                               or fn.endswith('unittest.cc'))
 ]
 # yapf: enable
 
diff --git a/deepspeech/decoders/swig_wrapper.py b/deepspeech/decoders/swig_wrapper.py
index 3ffdb9c74..d883d430c 100644
--- a/deepspeech/decoders/swig_wrapper.py
+++ b/deepspeech/decoders/swig_wrapper.py
@@ -32,7 +32,7 @@ class Scorer(swig_decoders.Scorer):
         swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
 
 
-def ctc_greedy_decoder(probs_seq, vocabulary):
+def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
     """Wrapper for ctc best path decoder in swig.
 
     :param probs_seq: 2-D list of probability distributions over each time
@@ -44,7 +44,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
     :return: Decoding result string.
     :rtype: str
     """
-    result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary)
+    result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary,
+                                              blank_id)
     return result
 
 
@@ -53,7 +54,8 @@ def ctc_beam_search_decoder(probs_seq,
                             beam_size,
                             cutoff_prob=1.0,
                             cutoff_top_n=40,
-                            ext_scoring_func=None):
+                            ext_scoring_func=None,
+                            blank_id=0):
     """Wrapper for the CTC Beam Search Decoder.
 
     :param probs_seq: 2-D list of probability distributions over each time
@@ -81,7 +83,7 @@ def ctc_beam_search_decoder(probs_seq,
     """
     beam_results = swig_decoders.ctc_beam_search_decoder(
         probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
-        ext_scoring_func)
+        ext_scoring_func, blank_id)
     beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
     return beam_results
 
 
@@ -92,7 +94,8 @@ def ctc_beam_search_decoder_batch(probs_split,
                                   num_processes,
                                   cutoff_prob=1.0,
                                   cutoff_top_n=40,
-                                  ext_scoring_func=None):
+                                  ext_scoring_func=None,
+                                  blank_id=0):
     """Wrapper for the batched CTC beam search decoder.
 
     :param probs_seq: 3-D list with each element as an instance of 2-D list
@@ -125,7 +128,7 @@ def ctc_beam_search_decoder_batch(probs_split,
 
     batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch(
         probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
-        cutoff_top_n, ext_scoring_func)
+        cutoff_top_n, ext_scoring_func, blank_id)
     batch_beam_results = [[(res[0], res[1]) for res in beam_results]
                           for beam_results in batch_beam_results]
     return batch_beam_results
diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index f3e3fcadf..fbc357ca0 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -15,6 +15,7 @@
 import os
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
 
@@ -65,29 +66,51 @@ class DeepSpeech2Trainer(Trainer):
         super().__init__(config, args)
 
     def train_batch(self, batch_index, batch_data, msg):
+        train_conf = self.config.training
         start = time.time()
+
+        # forward
         utt, audio, audio_len, text, text_len = batch_data
         loss = self.model(audio, audio_len, text, text_len)
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
-        self.optimizer.step()
-        self.optimizer.clear_grad()
-        iteration_time = time.time() - start
-
         losses_np = {
             'train_loss': float(loss),
         }
+
+        # loss backward
+        if (batch_index + 1) % train_conf.accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # optimizer step
+        if (batch_index + 1) % train_conf.accum_grad == 0:
+            self.optimizer.step()
+            self.optimizer.clear_grad()
+            self.iteration += 1
+
+        iteration_time = time.time() - start
+
         msg += "train time: {:>.3f}s, ".format(iteration_time)
         msg += "batch size: {}, ".format(self.config.collator.batch_size)
+        msg += "accum: {}, ".format(train_conf.accum_grad)
         msg += ', '.join('{}: {:>.6f}'.format(k, v)
                          for k, v in losses_np.items())
         logger.info(msg)
 
         if dist.get_rank() == 0 and self.visualizer:
             for k, v in losses_np.items():
+                # `step -1` since we update `step` after optimizer.step().
                 self.visualizer.add_scalar("train/{}".format(k), v,
-                                           self.iteration)
-        self.iteration += 1
+                                           self.iteration - 1)
 
     @paddle.no_grad()
     def valid(self):
diff --git a/deepspeech/exps/u2/bin/train.py b/deepspeech/exps/u2/bin/train.py
index 9dd0041dd..fef615ce3 100644
--- a/deepspeech/exps/u2/bin/train.py
+++ b/deepspeech/exps/u2/bin/train.py
@@ -21,6 +21,7 @@ from deepspeech.exps.u2.config import get_cfg_defaults
 from deepspeech.exps.u2.model import U2Trainer as Trainer
 from deepspeech.training.cli import default_argument_parser
 from deepspeech.utils.utility import print_arguments
+# from deepspeech.exps.u2.trainer import U2Trainer as Trainer
 
 
 def main_sp(config, args):
diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index 0662e38d9..2b6e24330 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -17,6 +17,7 @@ import os
 import sys
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
 
@@ -33,6 +34,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.u2 import U2Model
 from deepspeech.training.optimizer import OptimizerFactory
 from deepspeech.training.scheduler import LRSchedulerFactory
+from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
@@ -79,21 +81,35 @@ class U2Trainer(Trainer):
 
     def train_batch(self, batch_index, batch_data, msg):
         train_conf = self.config.training
         start = time.time()
-        utt, audio, audio_len, text, text_len = batch_data
 
+        # forward
+        utt, audio, audio_len, text, text_len = batch_data
         loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                     text_len)
+        # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
-
         losses_np = {'loss': float(loss) * train_conf.accum_grad}
         if attention_loss:
             losses_np['att_loss'] = float(attention_loss)
         if ctc_loss:
             losses_np['ctc_loss'] = float(ctc_loss)
 
+        # loss backward
+        if (batch_index + 1) % train_conf.accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
             self.optimizer.step()
             self.optimizer.clear_grad()
@@ -169,40 +185,42 @@ class U2Trainer(Trainer):
                 self.save(tag='init')
             self.lr_scheduler.step(self.iteration)
 
-        if self.parallel:
+        if self.parallel and hasattr(self.train_loader, 'batch_sampler'):
             self.train_loader.batch_sampler.set_epoch(self.epoch)
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
-            self.model.train()
-            try:
-                data_start_time = time.time()
-                for batch_index, batch in enumerate(self.train_loader):
-                    dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
+            with Timer("Epoch-Train Time Cost: {}"):
+                self.model.train()
+                try:
                     data_start_time = time.time()
-            except Exception as e:
-                logger.error(e)
-                raise e
-
-            total_loss, num_seen_utts = self.valid()
-            if dist.get_world_size() > 1:
-                num_seen_utts = paddle.to_tensor(num_seen_utts)
-                # the default operator in all_reduce function is sum.
-                dist.all_reduce(num_seen_utts)
-                total_loss = paddle.to_tensor(total_loss)
-                dist.all_reduce(total_loss)
-                cv_loss = total_loss / num_seen_utts
-                cv_loss = float(cv_loss)
-            else:
-                cv_loss = total_loss / num_seen_utts
+                    for batch_index, batch in enumerate(self.train_loader):
+                        dataload_time = time.time() - data_start_time
+                        msg = "Train: Rank: {}, ".format(dist.get_rank())
+                        msg += "epoch: {}, ".format(self.epoch)
+                        msg += "step: {}, ".format(self.iteration)
+                        msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                        len(self.train_loader))
+                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                        msg += "data time: {:>.3f}s, ".format(dataload_time)
+                        self.train_batch(batch_index, batch, msg)
+                        data_start_time = time.time()
+                except Exception as e:
+                    logger.error(e)
+                    raise e
+
+            with Timer("Eval Time Cost: {}"):
+                total_loss, num_seen_utts = self.valid()
+                if dist.get_world_size() > 1:
+                    num_seen_utts = paddle.to_tensor(num_seen_utts)
+                    # the default operator in all_reduce function is sum.
+                    dist.all_reduce(num_seen_utts)
+                    total_loss = paddle.to_tensor(total_loss)
+                    dist.all_reduce(total_loss)
+                    cv_loss = total_loss / num_seen_utts
+                    cv_loss = float(cv_loss)
+                else:
+                    cv_loss = total_loss / num_seen_utts
 
             logger.info(
                 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
diff --git a/deepspeech/exps/u2/trainer.py b/deepspeech/exps/u2/trainer.py
new file mode 100644
index 000000000..fa3e6d9d7
--- /dev/null
+++ b/deepspeech/exps/u2/trainer.py
@@ -0,0 +1,219 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains U2 model."""
+import paddle
+from paddle import distributed as dist
+from paddle.io import DataLoader
+
+from deepspeech.io.collator import SpeechCollator
+from deepspeech.io.dataset import ManifestDataset
+from deepspeech.io.sampler import SortagradBatchSampler
+from deepspeech.io.sampler import SortagradDistributedBatchSampler
+from deepspeech.models.u2 import U2Evaluator
+from deepspeech.models.u2 import U2Model
+from deepspeech.models.u2 import U2Updater
+from deepspeech.training.extensions.snapshot import Snapshot
+from deepspeech.training.extensions.visualizer import VisualDL
+from deepspeech.training.optimizer import OptimizerFactory
+from deepspeech.training.scheduler import LRSchedulerFactory
+from deepspeech.training.timer import Timer
+from deepspeech.training.trainer import Trainer
+from deepspeech.training.updaters.trainer import Trainer as NewTrainer
+from deepspeech.utils import layer_tools
+from deepspeech.utils.log import Log
+
+logger = Log(__name__).getlog()
+
+
+class U2Trainer(Trainer):
+    def __init__(self, config, args):
+        super().__init__(config, args)
+
+    def setup_dataloader(self):
+        config = self.config.clone()
+        config.defrost()
+        config.collator.keep_transcription_text = False
+
+        # train/valid dataset, return token ids
+        config.data.manifest = config.data.train_manifest
+        train_dataset = ManifestDataset.from_config(config)
+
+        config.data.manifest = config.data.dev_manifest
+        dev_dataset = ManifestDataset.from_config(config)
+
+        collate_fn_train = SpeechCollator.from_config(config)
+
+        config.collator.augmentation_config = ""
+        collate_fn_dev = SpeechCollator.from_config(config)
+
+        if self.parallel:
+            batch_sampler = SortagradDistributedBatchSampler(
+                train_dataset,
+                batch_size=config.collator.batch_size,
+                num_replicas=None,
+                rank=None,
+                shuffle=True,
+                drop_last=True,
+                sortagrad=config.collator.sortagrad,
+                shuffle_method=config.collator.shuffle_method)
+        else:
+            batch_sampler = SortagradBatchSampler(
+                train_dataset,
+                shuffle=True,
+                batch_size=config.collator.batch_size,
+                drop_last=True,
+                sortagrad=config.collator.sortagrad,
+                shuffle_method=config.collator.shuffle_method)
+        self.train_loader = DataLoader(
+            train_dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=collate_fn_train,
+            num_workers=config.collator.num_workers, )
+        self.valid_loader = DataLoader(
+            dev_dataset,
+            batch_size=config.collator.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=collate_fn_dev)
+
+        # test dataset, return raw text
+        config.data.manifest = config.data.test_manifest
+        # filtering test examples yields fewer examples but no mismatch with training,
+        # and allows a large batch size to save training time, so filter test egs now.
+        config.data.min_input_len = 0.0  # second
+        config.data.max_input_len = float('inf')  # second
+        config.data.min_output_len = 0.0  # tokens
+        config.data.max_output_len = float('inf')  # tokens
+        config.data.min_output_input_ratio = 0.00
+        config.data.max_output_input_ratio = float('inf')
+
+        test_dataset = ManifestDataset.from_config(config)
+        # return text ord id
+        config.collator.keep_transcription_text = True
+        config.collator.augmentation_config = ""
+        self.test_loader = DataLoader(
+            test_dataset,
+            batch_size=config.decoding.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=SpeechCollator.from_config(config))
+        # return text token id
+        config.collator.keep_transcription_text = False
+        self.align_loader = DataLoader(
+            test_dataset,
+            batch_size=config.decoding.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=SpeechCollator.from_config(config))
+        logger.info("Setup train/valid/test/align Dataloader!")
+
+    def setup_model(self):
+        config = self.config
+        model_conf = config.model
+        model_conf.defrost()
+        model_conf.input_dim = self.train_loader.collate_fn.feature_size
+        model_conf.output_dim = self.train_loader.collate_fn.vocab_size
+        model_conf.freeze()
+        model = U2Model.from_config(model_conf)
+
+        if self.parallel:
+            model = paddle.DataParallel(model)
+
+        model.train()
+        logger.info(f"{model}")
+        layer_tools.print_params(model, logger.info)
+
+        train_config = config.training
+        optim_type = train_config.optim
+        optim_conf = train_config.optim_conf
+        scheduler_type = train_config.scheduler
+        scheduler_conf = train_config.scheduler_conf
+
+        scheduler_args = {
+            "learning_rate": optim_conf.lr,
+            "verbose": False,
+            "warmup_steps": scheduler_conf.warmup_steps,
+            "gamma": scheduler_conf.lr_decay,
+            "d_model": model_conf.encoder_conf.output_size,
+        }
+        lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
+                                                    scheduler_args)
+
+        def optimizer_args(
+                config,
+                parameters,
+                lr_scheduler=None, ):
+            train_config = config.training
+            optim_type = train_config.optim
+            optim_conf = train_config.optim_conf
+            scheduler_type = train_config.scheduler
+            scheduler_conf = train_config.scheduler_conf
+            return {
+                "grad_clip": train_config.global_grad_clip,
+                "weight_decay": optim_conf.weight_decay,
+                "learning_rate": lr_scheduler
+                if lr_scheduler else optim_conf.lr,
+                "parameters": parameters,
+                "epsilon": 1e-9 if optim_type == 'noam' else None,
+                "beta1": 0.9 if optim_type == 'noam' else None,
+                "beta2": 0.98 if optim_type == 'noam' else None,
+            }
+
+        optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
+        optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)
+
+        self.model = model
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        logger.info("Setup model/optimizer/lr_scheduler!")
+
+    def setup_updater(self):
+        output_dir = self.output_dir
+        config = self.config.training
+
+        updater = U2Updater(
+            model=self.model,
+            optimizer=self.optimizer,
+            scheduler=self.lr_scheduler,
+            dataloader=self.train_loader,
+            output_dir=output_dir,
+            accum_grad=config.accum_grad)
+
+        trainer = NewTrainer(updater, (config.n_epoch, 'epoch'), output_dir)
+
+        evaluator = U2Evaluator(self.model, self.valid_loader)
+
+        trainer.extend(evaluator, trigger=(1, "epoch"))
+
+        if dist.get_rank() == 0:
+            trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
+            num_snapshots = config.checkpoint.kbest_n
+            trainer.extend(
+                Snapshot(
+                    mode='kbest',
+                    max_size=num_snapshots,
+                    indicator='VALID/LOSS',
+                    less_better=True),
+                trigger=(1, 'epoch'))
+        # print(trainer.extensions)
+        # trainer.run()
+        self.trainer = trainer
+
+    def run(self):
+        """The routine of the experiment after setup. This method is intended
+        to be used by the user.
+        """
+        self.setup_updater()
+        with Timer("Training Done: {}"):
+            self.trainer.run()
diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py
index 6a932d751..095dfe34d 100644
--- a/deepspeech/exps/u2_kaldi/model.py
+++ b/deepspeech/exps/u2_kaldi/model.py
@@ -17,6 +17,7 @@ import os
 import sys
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
 
@@ -31,6 +32,7 @@ from deepspeech.io.dataloader import BatchDataLoader
 from deepspeech.models.u2 import U2Model
 from deepspeech.training.optimizer import OptimizerFactory
 from deepspeech.training.scheduler import LRSchedulerFactory
+from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
@@ -83,20 +85,34 @@ class U2Trainer(Trainer):
         train_conf = self.config.training
         start = time.time()
 
+        # forward
         utt, audio, audio_len, text, text_len = batch_data
         loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                     text_len)
+        # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
-
         losses_np = {'loss': float(loss) * train_conf.accum_grad}
         if attention_loss:
             losses_np['att_loss'] = float(attention_loss)
         if ctc_loss:
             losses_np['ctc_loss'] = float(ctc_loss)
 
+        # loss backward
+        if (batch_index + 1) % train_conf.accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
             self.optimizer.step()
             self.optimizer.clear_grad()
@@ -175,35 +191,37 @@ class U2Trainer(Trainer):
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
-            self.model.train()
-            try:
-                data_start_time = time.time()
-                for batch_index, batch in enumerate(self.train_loader):
-                    dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
+            with Timer("Epoch-Train Time Cost: {}"):
+                self.model.train()
+                try:
                     data_start_time = time.time()
-            except Exception as e:
-                logger.error(e)
-                raise e
-
-            total_loss, num_seen_utts = self.valid()
-            if dist.get_world_size() > 1:
-                num_seen_utts = paddle.to_tensor(num_seen_utts)
-                # the default operator in all_reduce function is sum.
-                dist.all_reduce(num_seen_utts)
-                total_loss = paddle.to_tensor(total_loss)
-                dist.all_reduce(total_loss)
-                cv_loss = total_loss / num_seen_utts
-                cv_loss = float(cv_loss)
-            else:
-                cv_loss = total_loss / num_seen_utts
+                    for batch_index, batch in enumerate(self.train_loader):
+                        dataload_time = time.time() - data_start_time
+                        msg = "Train: Rank: {}, ".format(dist.get_rank())
+                        msg += "epoch: {}, ".format(self.epoch)
+                        msg += "step: {}, ".format(self.iteration)
+                        msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                        len(self.train_loader))
+                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                        msg += "data time: {:>.3f}s, ".format(dataload_time)
+                        self.train_batch(batch_index, batch, msg)
+                        data_start_time = time.time()
+                except Exception as e:
+                    logger.error(e)
+                    raise e
+
+            with Timer("Eval Time Cost: {}"):
+                total_loss, num_seen_utts = self.valid()
+                if dist.get_world_size() > 1:
+                    num_seen_utts = paddle.to_tensor(num_seen_utts)
+                    # the default operator in all_reduce function is sum.
+                    dist.all_reduce(num_seen_utts)
+                    total_loss = paddle.to_tensor(total_loss)
+                    dist.all_reduce(total_loss)
+                    cv_loss = total_loss / num_seen_utts
+                    cv_loss = float(cv_loss)
+                else:
+                    cv_loss = total_loss / num_seen_utts
 
             logger.info(
                 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py
index 5734e15f5..8dca16540 100644
--- a/deepspeech/exps/u2_st/model.py
+++ b/deepspeech/exps/u2_st/model.py
@@ -17,6 +17,7 @@ import os
 import sys
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
 
@@ -37,6 +38,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.u2_st import U2STModel
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.training.scheduler import WarmupLR
+from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import bleu_score
 from deepspeech.utils import ctc_utils
@@ -83,6 +85,7 @@ class U2STTrainer(Trainer):
     def train_batch(self, batch_index, batch_data, msg):
         train_conf = self.config.training
         start = time.time()
+        # forward
         utt, audio, audio_len, text, text_len = batch_data
         if isinstance(text, list) and isinstance(text_len, list):
             # joint training with ASR. Two decoding texts [translation, transcription]
@@ -94,18 +97,30 @@ class U2STTrainer(Trainer):
         else:
             loss, st_loss, attention_loss, ctc_loss = self.model(
                 audio, audio_len, text, text_len)
+        # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
-
         losses_np = {'loss': float(loss) * train_conf.accum_grad}
-
         losses_np['st_loss'] = float(st_loss)
         if attention_loss:
             losses_np['att_loss'] = float(attention_loss)
         if ctc_loss:
             losses_np['ctc_loss'] = float(ctc_loss)
 
+        # loss backward
+        if (batch_index + 1) % train_conf.accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
             self.optimizer.step()
             self.optimizer.clear_grad()
@@ -193,35 +208,37 @@ class U2STTrainer(Trainer):
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
-            self.model.train()
-            try:
-                data_start_time = time.time()
-                for batch_index, batch in enumerate(self.train_loader):
-                    dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
+            with Timer("Epoch-Train Time Cost: {}"):
+                self.model.train()
+                try:
                     data_start_time = time.time()
-            except Exception as e:
-                logger.error(e)
-                raise e
-
-            total_loss, num_seen_utts = self.valid()
-            if dist.get_world_size() > 1:
-                num_seen_utts = paddle.to_tensor(num_seen_utts)
-                # the default operator in all_reduce function is sum.
-                dist.all_reduce(num_seen_utts)
-                total_loss = paddle.to_tensor(total_loss)
-                dist.all_reduce(total_loss)
-                cv_loss = total_loss / num_seen_utts
-                cv_loss = float(cv_loss)
-            else:
-                cv_loss = total_loss / num_seen_utts
+                    for batch_index, batch in enumerate(self.train_loader):
+                        dataload_time = time.time() - data_start_time
+                        msg = "Train: Rank: {}, ".format(dist.get_rank())
+                        msg += "epoch: {}, ".format(self.epoch)
+                        msg += "step: {}, ".format(self.iteration)
+                        msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                        len(self.train_loader))
+                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                        msg += "data time: {:>.3f}s, ".format(dataload_time)
+                        self.train_batch(batch_index, batch, msg)
+                        data_start_time = time.time()
+                except Exception as e:
+                    logger.error(e)
+                    raise e
+
+            with Timer("Eval Time Cost: {}"):
+                total_loss, num_seen_utts = self.valid()
+                if dist.get_world_size() > 1:
+                    num_seen_utts = paddle.to_tensor(num_seen_utts)
+                    # the default operator in all_reduce function is sum.
+                    dist.all_reduce(num_seen_utts)
+                    total_loss = paddle.to_tensor(total_loss)
+                    dist.all_reduce(total_loss)
+                    cv_loss = total_loss / num_seen_utts
+                    cv_loss = float(cv_loss)
+                else:
+                    cv_loss = total_loss / num_seen_utts
 
             logger.info(
                 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
diff --git a/deepspeech/io/dataloader.py b/deepspeech/io/dataloader.py
index a35a0bc09..310f5f581 100644
--- a/deepspeech/io/dataloader.py
+++ b/deepspeech/io/dataloader.py
@@ -44,7 +44,7 @@ def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
 
 
 def batch_collate(x):
-    """de-tuple.
+    """de-minibatch, since the user has already composed the batch.
 
     Args:
         x (List[Tuple]): [(utts, xs, ilens, ys, olens)]
diff --git a/deepspeech/models/ds2/conv.py b/deepspeech/models/ds2/conv.py
index ce962a445..9548af0a2 100644
--- a/deepspeech/models/ds2/conv.py
+++ b/deepspeech/models/ds2/conv.py
@@ -106,11 +106,9 @@ class ConvBn(nn.Layer):
         # reset padding part to 0
         masks = make_non_pad_mask(x_len)  #[B, T]
         masks = masks.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
-        # TODO(Hui Zhang): not support bool multiply
-        # masks = masks.type_as(x)
-        masks = masks.astype(x.dtype)
-        x = x.multiply(masks)
-
+        # https://github.com/PaddlePaddle/Paddle/pull/29265
+        # rhs will type promote to lhs
+        x = x * masks
         return x, x_len
diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py
index 5f8f32557..dda26358b 100644
--- a/deepspeech/models/ds2/deepspeech2.py
+++ b/deepspeech/models/ds2/deepspeech2.py
@@ -128,8 +128,8 @@ class DeepSpeech2Model(nn.Layer):
                 num_rnn_layers=3,  #Number of stacking RNN layers.
                 rnn_layer_size=1024,  #RNN layer size (number of RNN cells).
                 use_gru=True,  #Use gru if set True. Use simple rnn if set False.
-                share_rnn_weights=True  #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
-            ))
+                share_rnn_weights=True,  #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
+                ctc_grad_norm_type='instance', ))
         if config is not None:
             config.merge_from_other_cfg(default)
         return default
@@ -141,7 +141,9 @@ class DeepSpeech2Model(nn.Layer):
                  num_rnn_layers=3,
                  rnn_size=1024,
                  use_gru=False,
-                 share_rnn_weights=True):
+                 share_rnn_weights=True,
+                 blank_id=0,
+                 ctc_grad_norm_type='instance'):
         super().__init__()
         self.encoder = CRNNEncoder(
             feat_size=feat_size,
@@ -156,10 +158,11 @@ class DeepSpeech2Model(nn.Layer):
         self.decoder = CTCDecoder(
             odim=dict_size,  # <blank> is in vocab
             enc_n_units=self.encoder.output_size,
-            blank_id=0,  # first token is <blank>
+            blank_id=blank_id,
             dropout_rate=0.0,
             reduction=True,  # sum
-            batch_average=True)  # sum / batch_size
+            batch_average=True,  # sum / batch_size
+            grad_norm_type=ctc_grad_norm_type)
 
     def forward(self, audio, audio_len, text, text_len):
         """Compute Model loss
@@ -221,7 +224,8 @@ class DeepSpeech2Model(nn.Layer):
             num_rnn_layers=config.model.num_rnn_layers,
             rnn_size=config.model.rnn_layer_size,
             use_gru=config.model.use_gru,
-            share_rnn_weights=config.model.share_rnn_weights)
+            share_rnn_weights=config.model.share_rnn_weights,
+            blank_id=config.model.blank_id)
         infos = Checkpoint().load_parameters(
             model, checkpoint_path=checkpoint_path)
         logger.info(f"checkpoint info: {infos}")
@@ -246,7 +250,8 @@ class DeepSpeech2Model(nn.Layer):
             num_rnn_layers=config.num_rnn_layers,
             rnn_size=config.rnn_layer_size,
             use_gru=config.use_gru,
-            share_rnn_weights=config.share_rnn_weights)
+            share_rnn_weights=config.share_rnn_weights,
+            blank_id=config.blank_id)
         return model
 
 
@@ -258,7 +263,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
                  num_rnn_layers=3,
                  rnn_size=1024,
                  use_gru=False,
-                 share_rnn_weights=True):
+                 share_rnn_weights=True,
+                 blank_id=0):
         super().__init__(
             feat_size=feat_size,
             dict_size=dict_size,
@@ -266,7 +272,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
             num_rnn_layers=num_rnn_layers,
             rnn_size=rnn_size,
             use_gru=use_gru,
-            share_rnn_weights=share_rnn_weights)
+            share_rnn_weights=share_rnn_weights,
+            blank_id=blank_id)
 
     def forward(self, audio, audio_len):
         """export model function
diff --git a/deepspeech/models/ds2/rnn.py b/deepspeech/models/ds2/rnn.py
index 3ff91d0af..3fc52a378 100644
--- a/deepspeech/models/ds2/rnn.py
+++ b/deepspeech/models/ds2/rnn.py
@@ -308,7 +308,8 @@ class RNNStack(nn.Layer):
             x, x_len = rnn(x, x_len)
             masks = make_non_pad_mask(x_len)  #[B, T]
             masks = masks.unsqueeze(-1)  # [B, T, 1]
-            # TODO(Hui Zhang): not support bool multiply
-            masks = masks.astype(x.dtype)
-            x = x.multiply(masks)
+            # https://github.com/PaddlePaddle/Paddle/pull/29265
+            # rhs will type promote to lhs
+            x = x * masks
+
         return x, x_len
diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py
index f597a5783..29d207c44 100644
--- a/deepspeech/models/ds2_online/deepspeech2.py
+++ b/deepspeech/models/ds2_online/deepspeech2.py
@@ -254,6 +254,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
                 num_fc_layers=2,
                 fc_layers_size_list=[512, 256],
                 use_gru=True,  #Use gru if set True. Use simple rnn if set False.
+                blank_id=0,  # index of blank in vocab.txt
             ))
         if config is not None:
             config.merge_from_other_cfg(default)
@@ -268,7 +269,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
                  rnn_direction='forward',
                  num_fc_layers=2,
                  fc_layers_size_list=[512, 256],
-                 use_gru=False):
+                 use_gru=False,
+                 blank_id=0):
         super().__init__()
         self.encoder = CRNNEncoder(
             feat_size=feat_size,
@@ -284,10 +286,11 @@ class DeepSpeech2ModelOnline(nn.Layer):
         self.decoder = CTCDecoder(
             odim=dict_size,  # <blank> is in vocab
             enc_n_units=self.encoder.output_size,
-            blank_id=0,  # first token is <blank>
+            blank_id=blank_id,
             dropout_rate=0.0,
             reduction=True,  # sum
-            batch_average=True)  # sum / batch_size
+            batch_average=True,  # sum / batch_size
+            grad_norm_type='instance')
 
     def forward(self, audio, audio_len, text, text_len):
         """Compute Model loss
@@ -353,7 +356,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
             rnn_direction=config.model.rnn_direction,
             num_fc_layers=config.model.num_fc_layers,
             fc_layers_size_list=config.model.fc_layers_size_list,
-            use_gru=config.model.use_gru)
+            use_gru=config.model.use_gru,
+            blank_id=config.model.blank_id)
         infos = Checkpoint().load_parameters(
             model, checkpoint_path=checkpoint_path)
         logger.info(f"checkpoint info: {infos}")
@@ -380,7 +384,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
             rnn_direction=config.rnn_direction,
             num_fc_layers=config.num_fc_layers,
             fc_layers_size_list=config.fc_layers_size_list,
-            use_gru=config.use_gru)
+            use_gru=config.use_gru,
+            blank_id=config.blank_id)
         return model
 
 
@@ -394,7 +399,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
                  rnn_direction='forward',
                  num_fc_layers=2,
                  fc_layers_size_list=[512, 256],
-                 use_gru=False):
+                 use_gru=False,
+                 blank_id=0):
         super().__init__(
             feat_size=feat_size,
             dict_size=dict_size,
@@ -404,7 +410,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
             rnn_direction=rnn_direction,
             num_fc_layers=num_fc_layers,
             fc_layers_size_list=fc_layers_size_list,
-            use_gru=use_gru)
+            use_gru=use_gru,
+            blank_id=blank_id)
 
     def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
                 chunk_state_c_box):
diff --git a/deepspeech/models/u2/__init__.py b/deepspeech/models/u2/__init__.py
new file mode 100644
index 000000000..a9010f1d0
--- /dev/null
+++ b/deepspeech/models/u2/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .u2 import U2InferModel
+from .u2 import U2Model
+from .updater import U2Evaluator
+from .updater import U2Updater
+
+__all__ = ["U2Model", "U2InferModel", "U2Evaluator", "U2Updater"]
diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2/u2.py
similarity index 98%
rename from deepspeech/models/u2.py
rename to deepspeech/models/u2/u2.py
index 1ca6a4feb..fd8f15471 100644
--- a/deepspeech/models/u2.py
+++ b/deepspeech/models/u2/u2.py
@@ -115,7 +115,8 @@ class U2BaseModel(nn.Layer):
                  ctc_weight: float=0.5,
                  ignore_id: int=IGNORE_ID,
                  lsm_weight: float=0.0,
-                 length_normalized_loss: bool=False):
+                 length_normalized_loss: bool=False,
+                 **kwargs):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
 
         super().__init__()
@@ -661,9 +662,7 @@ class U2BaseModel(nn.Layer):
             xs, offset, required_cache_size, subsampling_cache,
             elayers_output_cache, conformer_cnn_cache)
 
-    # @jit.to_static([
-    #     paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'),  # audio feat, [B,T,D]
-    # ])
+    # @jit.to_static
     def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
         """ Export interface for c++ call, apply linear transform
         and log softmax before ctc
@@ -830,6 +829,7 @@ class U2Model(U2BaseModel):
         Returns:
             int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
         """
+        # cmvn
         if configs['cmvn_file'] is not None:
             mean, istd = load_cmvn(configs['cmvn_file'],
                                    configs['cmvn_file_type'])
@@ -839,11 +839,13 @@ class U2Model(U2BaseModel):
         else:
             global_cmvn = None
 
+        # input & output dim
         input_dim = configs['input_dim']
         vocab_size = configs['output_dim']
         assert input_dim != 0, input_dim
         assert vocab_size != 0, vocab_size
 
+        # encoder
         encoder_type = configs.get('encoder', 'transformer')
         logger.info(f"U2 Encoder type: {encoder_type}")
         if encoder_type == 'transformer':
@@ -855,16 +857,21 @@ class U2Model(U2BaseModel):
         else:
             raise ValueError(f"not support encoder type:{encoder_type}")
 
+        # decoder
         decoder = TransformerDecoder(vocab_size,
                                      encoder.output_size(),
                                      **configs['decoder_conf'])
+
+        # ctc decoder and ctc loss
+        model_conf = configs['model_conf']
         ctc = CTCDecoder(
             odim=vocab_size,
             enc_n_units=encoder.output_size(),
             blank_id=0,
-            dropout_rate=0.0,
+            dropout_rate=model_conf['ctc_dropoutrate'],
             reduction=True,  # sum
-            batch_average=True)  # sum / batch_size
+            batch_average=True,  # sum / batch_size
+            grad_norm_type=model_conf['ctc_grad_norm_type'])
 
         return vocab_size, encoder, decoder, ctc
diff --git a/deepspeech/models/u2/updater.py b/deepspeech/models/u2/updater.py
new file mode 100644
index 000000000..7b70ca047
--- /dev/null
+++ b/deepspeech/models/u2/updater.py
@@ -0,0 +1,149 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import nullcontext
+
+import paddle
+from paddle import distributed as dist
+
+from deepspeech.training.extensions.evaluator import StandardEvaluator
+from deepspeech.training.reporter import report
+from deepspeech.training.timer import Timer
+from deepspeech.training.updaters.standard_updater import StandardUpdater
+from deepspeech.utils import layer_tools
+from deepspeech.utils.log import Log
+
+logger = Log(__name__).getlog()
+
+
+class U2Evaluator(StandardEvaluator):
+    def __init__(self, model, dataloader):
+        super().__init__(model, dataloader)
+        self.msg = ""
+        self.num_seen_utts = 0
+        self.total_loss = 0.0
+
+    def evaluate_core(self, batch):
+        self.msg = "Valid: Rank: {}, ".format(dist.get_rank())
+        losses_dict = {}
+
+        loss, attention_loss, ctc_loss = self.model(*batch[1:])
+        if paddle.isfinite(loss):
+            num_utts = batch[1].shape[0]
+            self.num_seen_utts += num_utts
+            self.total_loss += float(loss) * num_utts
+
+            losses_dict['loss'] = float(loss)
+            if attention_loss:
+                losses_dict['att_loss'] = float(attention_loss)
+            if ctc_loss:
+                losses_dict['ctc_loss'] = float(ctc_loss)
+
+            for k, v in losses_dict.items():
+                report("eval/" + k, v)
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+        logger.info(self.msg)
+        return self.total_loss, self.num_seen_utts
+
+
+class U2Updater(StandardUpdater):
+    def __init__(self,
+                 model,
+                 optimizer,
+                 scheduler,
+                 dataloader,
+                 init_state=None,
+                 accum_grad=1,
+                 **kwargs):
+        super().__init__(
+            model, optimizer, scheduler, dataloader, init_state=init_state)
+        self.accum_grad = accum_grad
+        self.forward_count = 0
+        self.msg = ""
+
+    def update_core(self, batch):
+        """One Step
+
+        Args:
+            batch (List[Object]): utts, xs, xlens, ys, ylens
+        """
+        losses_dict = {}
+        self.msg = "Rank: {}, ".format(dist.get_rank())
+
+        # forward
+        batch_size = batch[1].shape[0]
+        loss, attention_loss, ctc_loss = self.model(*batch[1:])
+        # loss div by `batch_size * accum_grad`
+        loss /= self.accum_grad
+
+        # loss backward
+        if (self.forward_count + 1) != self.accum_grad:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # loss info
+        losses_dict['loss'] = float(loss) * self.accum_grad
+        if attention_loss:
+            losses_dict['att_loss'] = float(attention_loss)
+        if ctc_loss:
+            losses_dict['ctc_loss'] = float(ctc_loss)
+        # report loss
+        for k, v in losses_dict.items():
+            report("train/" + k, v)
+        # loss msg
+        self.msg += "batch size: {}, ".format(batch_size)
+        self.msg += "accum: {}, ".format(self.accum_grad)
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+
+        # Truncate the graph
+        loss.detach()
+
+        # update parameters
+        self.forward_count += 1
+        if self.forward_count != self.accum_grad:
+            return
+        self.forward_count = 0
+
+        self.optimizer.step()
+        self.optimizer.clear_grad()
+        self.scheduler.step()
+
+    def update(self):
+        # model is default in train mode
+
+        # training for a step is implemented here
+        with Timer("data time cost:{}"):
+            batch = self.read_batch()
+        with Timer("step time cost:{}"):
+            self.update_core(batch)
+
+        # #iterations with accum_grad > 1
+        # Ref.: https://github.com/espnet/espnet/issues/777
+        if self.forward_count == 0:
+            self.state.iteration += 1
+            if self.updates_per_epoch is not None:
+                if self.state.iteration % self.updates_per_epoch == 0:
+                    self.state.epoch += 1
diff --git a/deepspeech/models/u2_st.py b/deepspeech/models/u2_st.py
index 531fafd0d..6737a549d 100644
--- a/deepspeech/models/u2_st.py
+++ b/deepspeech/models/u2_st.py
@@ -413,26 +413,26 @@ class U2STBaseModel(nn.Layer):
         best_hyps = best_hyps[:, 1:]
         return best_hyps
 
-    @jit.to_static
+    # @jit.to_static
     def subsampling_rate(self) -> int:
         """ Export interface for c++ call, return subsampling_rate of the
         model
         """
         return self.encoder.embed.subsampling_rate
 
-    @jit.to_static
+    # @jit.to_static
    def right_context(self) -> int:
         """ Export interface for c++ call, return right_context of the model
         """
         return self.encoder.embed.right_context
 
-    @jit.to_static
+    # @jit.to_static
     def sos_symbol(self) -> int:
         """ Export interface for c++ call, return sos symbol id of the model
         """
         return self.sos
 
-    @jit.to_static
+    # @jit.to_static
     def eos_symbol(self) -> int:
         """ Export interface for c++ call, return eos symbol id of the model
         """
         return self.eos
@@ -468,7 +468,7 @@ class U2STBaseModel(nn.Layer):
             xs, offset, required_cache_size, subsampling_cache,
             elayers_output_cache, conformer_cnn_cache)
 
-    @jit.to_static
+    # @jit.to_static
     def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
         """ Export interface for c++ call, apply linear transform
         and log softmax before ctc
@@ -643,13 +643,16 @@ class U2STModel(U2STBaseModel):
         decoder = TransformerDecoder(vocab_size,
                                      encoder.output_size(),
                                      **configs['decoder_conf'])
+        # ctc decoder and ctc loss
+        model_conf = configs['model_conf']
         ctc = CTCDecoder(
             odim=vocab_size,
             enc_n_units=encoder.output_size(),
             blank_id=0,
-            dropout_rate=0.0,
+            dropout_rate=model_conf['ctc_dropout_rate'],
             reduction=True,  # sum
-            batch_average=True)  # sum / batch_size
+            batch_average=True,  # sum / batch_size
+            grad_norm_type=model_conf['ctc_grad_norm_type'])
 
         return vocab_size, encoder, (st_decoder, decoder, ctc)
     else:
diff --git a/deepspeech/modules/activation.py b/deepspeech/modules/activation.py
index 30132775e..3cb8729e1 100644
--- a/deepspeech/modules/activation.py
+++ b/deepspeech/modules/activation.py
@@ -15,12 +15,13 @@ from collections import OrderedDict
 
 import paddle
 from paddle import nn
+from paddle.nn import functional as F
 
 from deepspeech.utils.log import Log
 
 logger = Log(__name__).getlog()
 
-__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock"]
+__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock", "GLU"]
 
 
 def brelu(x, t_min=0.0, t_max=24.0, name=None):
@@ -30,6 +31,17 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
     return x.maximum(t_min).minimum(t_max)
 
 
+class GLU(nn.Layer):
+    """Gated Linear Units (GLU) Layer"""
+
+    def __init__(self, dim: int=-1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, xs):
+        return F.glu(xs, axis=self.dim)
+
+
 class LinearGLUBlock(nn.Layer):
     """A linear Gated Linear Units (GLU) block."""
 
@@ -133,13 +145,18 @@ def get_activation(act):
     """Return activation function."""
     # Lazy load to avoid unused import
     activation_funcs = {
+        "hardshrink": paddle.nn.Hardshrink,
+        "hardswish": paddle.nn.Hardswish,
         "hardtanh": paddle.nn.Hardtanh,
         "tanh": paddle.nn.Tanh,
         "relu": paddle.nn.ReLU,
+        "relu6": paddle.nn.ReLU6,
+        "leakyrelu": paddle.nn.LeakyReLU,
         "selu": paddle.nn.SELU,
         "swish": paddle.nn.Swish,
         "gelu": paddle.nn.GELU,
-        "brelu": brelu,
+        "glu": GLU,
+        "elu": paddle.nn.ELU,
     }
 
     return activation_funcs[act]()
diff --git a/deepspeech/modules/conv.py b/deepspeech/modules/conv.py
index 8bf48b2c8..22a168800 100644
--- a/deepspeech/modules/conv.py
+++ b/deepspeech/modules/conv.py
@@ -113,11 +113,9 @@ class ConvBn(nn.Layer):
         # reset padding part to 0
         masks = make_non_pad_mask(x_len)  #[B, T]
         masks = masks.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
-        # TODO(Hui Zhang): not support bool multiply
-        # masks = masks.type_as(x)
-        masks = masks.astype(x.dtype)
-        x = x.multiply(masks)
-
+        # https://github.com/PaddlePaddle/Paddle/pull/29265
+        # rhs will type promote to lhs
+        x = x * masks
         return x, x_len
diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py
index 10c046383..b3ca28279 100644
--- a/deepspeech/modules/ctc.py
+++ b/deepspeech/modules/ctc.py
@@ -39,7 +39,8 @@ class CTCDecoder(nn.Layer):
                  blank_id=0,
                  dropout_rate: float=0.0,
                  reduction: bool=True,
-                 batch_average: bool=True):
+                 batch_average: bool=True,
+                 grad_norm_type: str="instance"):
         """CTC decoder
 
         Args:
@@ -48,6 +49,7 @@ class CTCDecoder(nn.Layer):
             dropout_rate (float): dropout rate (0.0 ~ 1.0)
             reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none'
             batch_average (bool): do batch dim wise average.
+            grad_norm_type (str): one of 'instance', 'batch', 'frame', None.
""" assert check_argument_types() super().__init__() @@ -60,7 +62,8 @@ class CTCDecoder(nn.Layer): self.criterion = CTCLoss( blank=self.blank_id, reduction=reduction_type, - batch_average=batch_average) + batch_average=batch_average, + grad_norm_type=grad_norm_type) # CTCDecoder LM Score handle self._ext_scorer = None @@ -136,7 +139,7 @@ class CTCDecoder(nn.Layer): results = [] for i, probs in enumerate(probs_split): output_transcription = ctc_greedy_decoder( - probs_seq=probs, vocabulary=vocab_list) + probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id) results.append(output_transcription) return results @@ -216,7 +219,8 @@ class CTCDecoder(nn.Layer): num_processes=num_processes, ext_scoring_func=self._ext_scorer, cutoff_prob=cutoff_prob, - cutoff_top_n=cutoff_top_n) + cutoff_top_n=cutoff_top_n, + blank_id=self.blank_id) results = [result[0][1] for result in beam_search_results] return results diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py index f692a8186..2c58be7e3 100644 --- a/deepspeech/modules/loss.py +++ b/deepspeech/modules/loss.py @@ -23,11 +23,32 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"] class CTCLoss(nn.Layer): - def __init__(self, blank=0, reduction='sum', batch_average=False): + def __init__(self, + blank=0, + reduction='sum', + batch_average=False, + grad_norm_type=None): super().__init__() # last token id as blank id self.loss = nn.CTCLoss(blank=blank, reduction=reduction) self.batch_average = batch_average + logger.info( + f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}") + + # instance for norm_by_times + # batch for norm_by_batchsize + # frame for norm_by_total_logits_len + assert grad_norm_type in ('instance', 'batch', 'frame', None) + self.norm_by_times = False + self.norm_by_batchsize = False + self.norm_by_total_logits_len = False + logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}") + if grad_norm_type == 'instance': + self.norm_by_times = True + if grad_norm_type == 'batch': + self.norm_by_batchsize = True + if grad_norm_type == 'frame': + self.norm_by_total_logits_len = True def forward(self, logits, ys_pad, hlens, ys_lens): """Compute CTC loss. 
@@ -46,10 +67,15 @@ class CTCLoss(nn.Layer):
         # warp-ctc need activation with shape [T, B, V + 1]
         # logits: (B, L, D) -> (L, B, D)
         logits = logits.transpose([1, 0, 2])
-        # (TODO:Hui Zhang) ctc loss does not support int64 labels
         ys_pad = ys_pad.astype(paddle.int32)
         loss = self.loss(
-            logits, ys_pad, hlens, ys_lens, norm_by_times=self.batch_average)
+            logits,
+            ys_pad,
+            hlens,
+            ys_lens,
+            norm_by_times=self.norm_by_times,
+            norm_by_batchsize=self.norm_by_batchsize,
+            norm_by_total_logits_len=self.norm_by_total_logits_len)
         if self.batch_average:
             # Batch-size average
             loss = loss / B
diff --git a/deepspeech/modules/rnn.py b/deepspeech/modules/rnn.py
index 0d8c9fd2c..8f8b2a18d 100644
--- a/deepspeech/modules/rnn.py
+++ b/deepspeech/modules/rnn.py
@@ -308,7 +308,7 @@ class RNNStack(nn.Layer):
             x, x_len = rnn(x, x_len)
             masks = make_non_pad_mask(x_len)  #[B, T]
             masks = masks.unsqueeze(-1)  # [B, T, 1]
-            # TODO(Hui Zhang): not support bool multiply
-            masks = masks.astype(x.dtype)
-            x = x.multiply(masks)
+            # https://github.com/PaddlePaddle/Paddle/pull/29265
+            # rhs will type promote to lhs
+            x = x * masks
         return x, x_len
diff --git a/deepspeech/training/extensions/evaluator.py b/deepspeech/training/extensions/evaluator.py
index 96ff967f5..d5b359829 100644
--- a/deepspeech/training/extensions/evaluator.py
+++ b/deepspeech/training/extensions/evaluator.py
@@ -13,14 +13,18 @@
 # limitations under the License.
 from typing import Dict

-import extension
 import paddle
+from paddle import distributed as dist
 from paddle.io import DataLoader
 from paddle.nn import Layer

+from . import extension
 from ..reporter import DictSummary
 from ..reporter import report
 from ..reporter import scope
+from ..timer import Timer
+from deepspeech.utils.log import Log
+logger = Log(__name__).getlog()


 class StandardEvaluator(extension.Extension):
@@ -43,6 +47,27 @@ class StandardEvaluator(extension.Extension):
     def evaluate_core(self, batch):
         # compute
         self.model(batch)  # you may report here
+        return
+
+    def evaluate_sync(self, data):
+        # dist sync `evaluate_core` outputs
+        if data is None:
+            return
+
+        numerator, denominator = data
+        if dist.get_world_size() > 1:
+            numerator = paddle.to_tensor(numerator)
+            denominator = paddle.to_tensor(denominator)
+            # the default operator in all_reduce function is sum.
+            dist.all_reduce(numerator)
+            dist.all_reduce(denominator)
+            value = numerator / denominator
+            value = float(value)
+        else:
+            value = numerator / denominator
+        # used for `snapshot` to do kbest save.
+        report("VALID/LOSS", value)
+        logger.info(f"Valid: all-reduce loss {value}")

     def evaluate(self):
         # switch to eval mode
@@ -56,9 +81,13 @@ class StandardEvaluator(extension.Extension):
             with scope(observation):
                 # main evaluation computation here.
                 with paddle.no_grad():
-                    self.evaluate_core(batch)
+                    self.evaluate_sync(self.evaluate_core(batch))
             summary.add(observation)
         summary = summary.compute_mean()
+
+        # switch to train mode
+        for model in self.models.values():
+            model.train()
         return summary

     def __call__(self, trainer=None):
@@ -66,6 +95,7 @@ class StandardEvaluator(extension.Extension):
         # if it is used to extend a trainer, the metrics is reported to
         # to observation of the trainer
         # or otherwise, you can use your own observation
-        summary = self.evaluate()
+        with Timer("Eval Time Cost: {}"):
+            summary = self.evaluate()
         for k, v in summary.items():
             report(k, v)
diff --git a/deepspeech/training/extensions/snapshot.py b/deepspeech/training/extensions/snapshot.py
index cb4e6dfbf..1d3fe70cb 100644
--- a/deepspeech/training/extensions/snapshot.py
+++ b/deepspeech/training/extensions/snapshot.py
@@ -20,8 +20,9 @@ from typing import List

 import jsonlines

-from deepspeech.training.extensions import extension
-from deepspeech.training.updaters.trainer import Trainer
+from . import extension
+from ..reporter import get_observations
+from ..updaters.trainer import Trainer
 from deepspeech.utils.log import Log
 from deepspeech.utils.mp_tools import rank_zero_only

@@ -52,8 +53,19 @@ class Snapshot(extension.Extension):
     priority = -100
     default_name = "snapshot"

-    def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
+    def __init__(self,
+                 mode='latest',
+                 max_size: int=5,
+                 indicator=None,
+                 less_better=True,
+                 snapshot_on_error: bool=False):
         self.records: List[Dict[str, Any]] = []
+        assert mode in ('latest', 'kbest'), mode
+        if mode == 'kbest':
+            assert indicator is not None
+        self.mode = mode
+        self.indicator = indicator
+        self.less_is_better = less_better
         self.max_size = max_size
         self._snapshot_on_error = snapshot_on_error
         self._save_all = (max_size == -1)
@@ -66,16 +78,17 @@ class Snapshot(extension.Extension):
         # load existing records
         record_path: Path = self.checkpoint_dir / "records.jsonl"
         if record_path.exists():
-            logger.debug("Loading from an existing checkpoint dir")
             self.records = load_records(record_path)
-            trainer.updater.load(self.records[-1]['path'])
+            ckpt_path = self.records[-1]['path']
+            logger.info(f"Loading from an existing checkpoint {ckpt_path}")
+            trainer.updater.load(ckpt_path)

     def on_error(self, trainer, exc, tb):
         if self._snapshot_on_error:
-            self.save_checkpoint_and_update(trainer)
+            self.save_checkpoint_and_update(trainer, 'latest')

     def __call__(self, trainer: Trainer):
-        self.save_checkpoint_and_update(trainer)
+        self.save_checkpoint_and_update(trainer, self.mode)

     def full(self):
         """Whether the number of snapshots it keeps track of is greater
@@ -83,7 +96,7 @@
         return (not self._save_all) and len(self.records) > self.max_size

     @rank_zero_only
-    def save_checkpoint_and_update(self, trainer: Trainer):
+    def save_checkpoint_and_update(self, trainer: Trainer, mode: str):
         """Saving new snapshot and remove the oldest snapshot if needed."""
         iteration = trainer.updater.state.iteration
         epoch = trainer.updater.state.epoch
@@ -97,11 +110,17 @@
             'path': str(path.resolve()),  # use absolute path
             'iteration': iteration,
             'epoch': epoch,
+            'indicator': get_observations()[self.indicator] if self.indicator else None
         }
         self.records.append(record)

         # remove the earist
         if self.full():
+            if mode == 'kbest':
+                self.records = sorted(
+                    self.records,
+                    key=lambda record: record['indicator'],
+                    reverse=not self.less_is_better)
             eariest_record = self.records[0]
             os.remove(eariest_record["path"])
             self.records.pop(0)
diff --git a/deepspeech/training/extensions/visualizer.py b/deepspeech/training/extensions/visualizer.py
index b69e94aaf..e5f456cac 100644
--- a/deepspeech/training/extensions/visualizer.py
+++ b/deepspeech/training/extensions/visualizer.py
@@ -11,8 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from deepspeech.training.extensions import extension
-from deepspeech.training.updaters.trainer import Trainer
+from visualdl import LogWriter
+
+from . import extension
+from ..updaters.trainer import Trainer


 class VisualDL(extension.Extension):
@@ -26,8 +28,8 @@ class VisualDL(extension.Extension):
     default_name = 'visualdl'
     priority = extension.PRIORITY_READER

-    def __init__(self, writer):
-        self.writer = writer
+    def __init__(self, output_dir):
+        self.writer = LogWriter(str(output_dir))

     def __call__(self, trainer: Trainer):
         for k, v in trainer.observation.items():
diff --git a/deepspeech/training/gradclip.py b/deepspeech/training/gradclip.py
index f46814eb0..87b36acae 100644
--- a/deepspeech/training/gradclip.py
+++ b/deepspeech/training/gradclip.py
@@ -47,7 +47,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
             sum_square = layers.reduce_sum(square)
             sum_square_list.append(sum_square)

-            # debug log
+            # debug log; don't dump all grads, since that slows down training
             if i < 10:
                 logger.debug(
                     f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }")
@@ -76,7 +76,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
             new_grad = layers.elementwise_mul(x=g, y=clip_var)
             params_and_grads.append((p, new_grad))

-            # debug log
+            # debug log; don't dump all grads, since that slows down training
             if i < 10:
                 logger.debug(
                     f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}"
diff --git a/deepspeech/training/timer.py b/deepspeech/training/timer.py
new file mode 100644
index 000000000..2ca9d6386
--- /dev/null
+++ b/deepspeech/training/timer.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import datetime
+import time
+
+from deepspeech.utils.log import Log
+
+__all__ = ["Timer"]
+
+logger = Log(__name__).getlog()
+
+
+class Timer():
+    """To be used like this:
+        with Timer("Elapsed: {}"):
+            do_something()
+    """
+
+    def __init__(self, message=None):
+        self.message = message
+
+    def duration(self) -> str:
+        elapsed_time = time.time() - self.start
+        time_str = str(datetime.timedelta(seconds=elapsed_time))
+        return time_str
+
+    def __enter__(self):
+        self.start = time.time()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        if self.message:
+            logger.info(self.message.format(self.duration()))
+
+    def __call__(self) -> float:
+        return time.time() - self.start
+
+    def __str__(self):
+        return self.duration()
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index 3a922c6f4..7959b41b8 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -18,6 +18,7 @@ import paddle
 from paddle import distributed as dist
 from tensorboardX import SummaryWriter

+from deepspeech.training.timer import Timer
 from deepspeech.utils import mp_tools
 from deepspeech.utils.checkpoint import Checkpoint
 from deepspeech.utils.log import Log
@@ -170,7 +171,7 @@ class Trainer():
             self.iteration = 0
             self.epoch = 0
             scratch = True
-
+        logger.info("Restore/Init checkpoint!")
         return scratch

     def new_epoch(self):
@@ -194,35 +195,37 @@ class Trainer():
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
-            self.model.train()
-            try:
-                data_start_time = time.time()
-                for batch_index, batch in enumerate(self.train_loader):
-                    dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
-                    data_start_time = time.time()
-            except Exception as e:
-                logger.error(e)
-                raise e
-
-            total_loss, num_seen_utts = self.valid()
-            if dist.get_world_size() > 1:
-                num_seen_utts = paddle.to_tensor(num_seen_utts)
-                # the default operator in all_reduce function is sum.
-                dist.all_reduce(num_seen_utts)
-                total_loss = paddle.to_tensor(total_loss)
-                dist.all_reduce(total_loss)
-                cv_loss = total_loss / num_seen_utts
-                cv_loss = float(cv_loss)
-            else:
-                cv_loss = total_loss / num_seen_utts
+            with Timer("Epoch-Train Time Cost: {}"):
+                self.model.train()
+                try:
+                    data_start_time = time.time()
+                    for batch_index, batch in enumerate(self.train_loader):
+                        dataload_time = time.time() - data_start_time
+                        msg = "Train: Rank: {}, ".format(dist.get_rank())
+                        msg += "epoch: {}, ".format(self.epoch)
+                        msg += "step: {}, ".format(self.iteration)
+                        msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                        len(self.train_loader))
+                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                        msg += "data time: {:>.3f}s, ".format(dataload_time)
+                        self.train_batch(batch_index, batch, msg)
+                        data_start_time = time.time()
+                except Exception as e:
+                    logger.error(e)
+                    raise e
+
+            with Timer("Eval Time Cost: {}"):
+                total_loss, num_seen_utts = self.valid()
+                if dist.get_world_size() > 1:
+                    num_seen_utts = paddle.to_tensor(num_seen_utts)
+                    # the default operator in all_reduce function is sum.
+                    dist.all_reduce(num_seen_utts)
+                    total_loss = paddle.to_tensor(total_loss)
+                    dist.all_reduce(total_loss)
+                    cv_loss = total_loss / num_seen_utts
+                    cv_loss = float(cv_loss)
+                else:
+                    cv_loss = total_loss / num_seen_utts

             logger.info(
                 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
@@ -240,14 +243,14 @@ class Trainer():
         """The routine of the experiment after setup. This method is intended
         to be used by the user.
         """
-        try:
-            self.train()
-        except KeyboardInterrupt:
-            self.save()
-            exit(-1)
-        finally:
-            self.destory()
-            logger.info("Training Done.")
+        with Timer("Training Done: {}"):
+            try:
+                self.train()
+            except KeyboardInterrupt:
+                self.save()
+                exit(-1)
+            finally:
+                self.destory()

     def setup_output_dir(self):
         """Create a directory used for output.
diff --git a/deepspeech/training/updaters/standard_updater.py b/deepspeech/training/updaters/standard_updater.py
index fc758e93e..10c99e7fc 100644
--- a/deepspeech/training/updaters/standard_updater.py
+++ b/deepspeech/training/updaters/standard_updater.py
@@ -14,12 +14,13 @@
 from typing import Dict
 from typing import Optional

-from paddle import Tensor
+import paddle
 from paddle.io import DataLoader
 from paddle.io import DistributedBatchSampler
 from paddle.nn import Layer
 from paddle.optimizer import Optimizer
-from timer import timer
+from paddle.optimizer.lr import LRScheduler

 from deepspeech.training.reporter import report
+from deepspeech.training.timer import Timer
 from deepspeech.training.updaters.updater import UpdaterBase
@@ -39,8 +39,10 @@ class StandardUpdater(UpdaterBase):
     def __init__(self,
                  model: Layer,
                  optimizer: Optimizer,
+                 scheduler: LRScheduler,
                  dataloader: DataLoader,
                  init_state: Optional[UpdaterState]=None):
+        super().__init__(init_state)
         # it is designed to hold multiple models
         models = {"main": model}
         self.models: Dict[str, Layer] = models
@@ -51,15 +53,14 @@ class StandardUpdater(UpdaterBase):
         self.optimizer = optimizer
         self.optimizers: Dict[str, Optimizer] = optimizers

+        # it is designed to hold multiple schedulers
+        schedulers = {"main": scheduler}
+        self.scheduler = scheduler
+        self.schedulers: Dict[str, LRScheduler] = schedulers
+
         # dataloaders
         self.dataloader = dataloader

-        # init state
-        if init_state is None:
-            self.state = UpdaterState()
-        else:
-            self.state = init_state
-
         self.train_iterator = iter(dataloader)

     def update(self):
@@ -103,8 +104,10 @@ class StandardUpdater(UpdaterBase):
             model.train()

         # training for a step is implemented here
-        batch = self.read_batch()
-        self.update_core(batch)
+        with Timer("data time cost:{}"):
+            batch = self.read_batch()
+        with Timer("step time cost:{}"):
+            self.update_core(batch)

         self.state.iteration += 1
         if self.updates_per_epoch is not None:
@@ -115,13 +118,14 @@ class StandardUpdater(UpdaterBase):
         """A simple case for a training step. Basic assumptions are:
         Single model;
         Single optimizer;
+        Single scheduler, and update learning rate each step;
         A batch from the dataloader is just the input of the model;
         The model return a single loss, or a dict containing serval losses.
         Parameters updates at every batch, no gradient accumulation.
""" loss = self.model(*batch) - if isinstance(loss, Tensor): + if isinstance(loss, paddle.Tensor): loss_dict = {"main": loss} else: # Dict[str, Tensor] @@ -135,14 +139,15 @@ class StandardUpdater(UpdaterBase): for name, loss_item in loss_dict.items(): report(name, float(loss_item)) - self.optimizer.clear_gradient() + self.optimizer.clear_grad() loss_dict["main"].backward() - self.optimizer.update() + self.optimizer.step() + self.scheduler.step() @property def updates_per_epoch(self): - """Number of updater per epoch, determined by the length of the - dataloader.""" + """Number of steps per epoch, + determined by the length of the dataloader.""" length_of_dataloader = None try: length_of_dataloader = len(self.dataloader) @@ -163,18 +168,16 @@ class StandardUpdater(UpdaterBase): def read_batch(self): """Read a batch from the data loader, auto renew when data is exhausted.""" - with timer() as t: - try: - batch = next(self.train_iterator) - except StopIteration: - self.new_epoch() - batch = next(self.train_iterator) - logger.debug( - f"Read a batch takes {t.elapse}s.") # replace it with logger + try: + batch = next(self.train_iterator) + except StopIteration: + self.new_epoch() + batch = next(self.train_iterator) return batch def state_dict(self): - """State dict of a Updater, model, optimizer and updater state are included.""" + """State dict of a Updater, model, optimizers/schedulers + and updater state are included.""" state_dict = super().state_dict() for name, model in self.models.items(): state_dict[f"{name}_params"] = model.state_dict() @@ -184,7 +187,7 @@ class StandardUpdater(UpdaterBase): def set_state_dict(self, state_dict): """Set state dict for a Updater. Parameters of models, states for - optimizers and UpdaterState are restored.""" + optimizers/schedulers and UpdaterState are restored.""" for name, model in self.models.items(): model.set_state_dict(state_dict[f"{name}_params"]) for name, optim in self.optimizers.items(): diff --git a/deepspeech/training/updaters/trainer.py b/deepspeech/training/updaters/trainer.py index 954ce2604..a52fb9eb3 100644 --- a/deepspeech/training/updaters/trainer.py +++ b/deepspeech/training/updaters/trainer.py @@ -140,8 +140,8 @@ class Trainer(): try: while not stop_trigger(self): self.observation = {} - # set observation as the report target - # you can use report freely in Updater.update() + # set observation as the `report` target + # you can use `report` freely in Updater.update() # updating parameters and state with scope(self.observation): diff --git a/deepspeech/training/updaters/updater.py b/deepspeech/training/updaters/updater.py index 66fdc2bbc..e5dd65563 100644 --- a/deepspeech/training/updaters/updater.py +++ b/deepspeech/training/updaters/updater.py @@ -52,6 +52,7 @@ class UpdaterBase(): """ def __init__(self, init_state=None): + # init state if init_state is None: self.state = UpdaterState() else: diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index a59f8be79..8e31edfae 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -114,13 +114,13 @@ class Checkpoint(): params_path = checkpoint_path + ".pdparams" model_dict = paddle.load(params_path) model.set_state_dict(model_dict) - logger.info("Rank {}: loaded model from {}".format(rank, params_path)) + logger.info("Rank {}: Restore model from {}".format(rank, params_path)) optimizer_path = checkpoint_path + ".pdopt" if optimizer and os.path.isfile(optimizer_path): optimizer_dict = paddle.load(optimizer_path) 
optimizer.set_state_dict(optimizer_dict) - logger.info("Rank {}: loaded optimizer state from {}".format( + logger.info("Rank {}: Restore optimizer state from {}".format( rank, optimizer_path)) info_path = re.sub('.pdparams$', '.json', params_path) diff --git a/deepspeech/utils/log.py b/deepspeech/utils/log.py index 3fd7d2480..7e8de600a 100644 --- a/deepspeech/utils/log.py +++ b/deepspeech/utils/log.py @@ -12,19 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import getpass -import logging import os import socket import sys +from loguru import logger from paddle import inference -FORMAT_STR = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' -DATE_FMT_STR = '%Y/%m/%d %H:%M:%S' - -logging.basicConfig( - level=logging.DEBUG, format=FORMAT_STR, datefmt=DATE_FMT_STR) - def find_log_dir(log_dir=None): """Returns the most suitable directory to put log files into. @@ -98,59 +92,28 @@ def find_log_dir_and_names(program_name=None, log_dir=None): class Log(): - - log_name = None - - def __init__(self, logger=None): - self.logger = logging.getLogger(logger) - self.logger.setLevel(logging.DEBUG) - - file_dir = os.getcwd() + '/log' - if not os.path.exists(file_dir): - os.mkdir(file_dir) - self.log_dir = file_dir - - actual_log_dir, file_prefix, symlink_prefix = find_log_dir_and_names( - program_name=None, log_dir=self.log_dir) - - basename = '%s.DEBUG.%d' % (file_prefix, os.getpid()) - filename = os.path.join(actual_log_dir, basename) - if Log.log_name is None: - Log.log_name = filename - - # Create a symlink to the log file with a canonical name. - symlink = os.path.join(actual_log_dir, symlink_prefix + '.DEBUG') - try: - if os.path.islink(symlink): - os.unlink(symlink) - os.symlink(os.path.basename(Log.log_name), symlink) - except EnvironmentError: - # If it fails, we're sad but it's no error. 
Commonly, this - # fails because the symlink was created by another user and so - # we can't modify it - pass - - if not self.logger.hasHandlers(): - formatter = logging.Formatter(fmt=FORMAT_STR, datefmt=DATE_FMT_STR) - fh = logging.FileHandler(Log.log_name) - fh.setLevel(logging.DEBUG) - fh.setFormatter(formatter) - self.logger.addHandler(fh) - - ch = logging.StreamHandler() - ch.setLevel(logging.INFO) - ch.setFormatter(formatter) - self.logger.addHandler(ch) - - # stop propagate for propagating may print - # log multiple times - self.logger.propagate = False + """Default Logger for all.""" + logger.remove() + logger.add( + sys.stdout, + level='INFO', + enqueue=True, + filter=lambda record: record['level'].no >= 20) + _, file_prefix, _ = find_log_dir_and_names() + sink_prefix = os.path.join("exp/log", file_prefix) + sink_path = sink_prefix[:-3] + "{time}.log" + logger.add(sink_path, level='DEBUG', enqueue=True, rotation="500 MB") + + def __init__(self, name=None): + pass def getlog(self): - return self.logger + return logger class Autolog: + """Just used by fullchain project""" + def __init__(self, batch_size, model_name="DeepSpeech", diff --git a/doc/images/multi_gpu_speedup.png b/doc/images/multi_gpu_speedup.png deleted file mode 100755 index 286de5151..000000000 Binary files a/doc/images/multi_gpu_speedup.png and /dev/null differ diff --git a/doc/images/tuning_error_surface.png b/doc/images/tuning_error_surface.png deleted file mode 100644 index 2204cee2f..000000000 Binary files a/doc/images/tuning_error_surface.png and /dev/null differ diff --git a/doc/src/benchmark.md b/doc/src/benchmark.md deleted file mode 100644 index 9c1c86fd7..000000000 --- a/doc/src/benchmark.md +++ /dev/null @@ -1,16 +0,0 @@ -# Benchmarks - -## Acceleration with Multi-GPUs - -We compare the training time with 1, 2, 4, 8 Tesla V100 GPUs (with a subset of LibriSpeech samples whose audio durations are between 6.0 and 7.0 seconds). And it shows that a **near-linear** acceleration with multiple GPUs has been achieved. In the following figure, the time (in seconds) cost for training is printed on the blue bars. - - - -| # of GPU | Acceleration Rate | -| -------- | --------------: | -| 1 | 1.00 X | -| 2 | 1.98 X | -| 4 | 3.73 X | -| 8 | 6.95 X | - -`utils/profile.sh` provides such a demo profiling tool, you can change it as need. diff --git a/doc/images/ds2offlineModel.png b/docs/images/ds2offlineModel.png similarity index 100% rename from doc/images/ds2offlineModel.png rename to docs/images/ds2offlineModel.png diff --git a/doc/images/ds2onlineModel.png b/docs/images/ds2onlineModel.png similarity index 100% rename from doc/images/ds2onlineModel.png rename to docs/images/ds2onlineModel.png diff --git a/doc/src/augmentation.md b/docs/src/augmentation.md similarity index 100% rename from doc/src/augmentation.md rename to docs/src/augmentation.md diff --git a/doc/src/data_preparation.md b/docs/src/data_preparation.md similarity index 100% rename from doc/src/data_preparation.md rename to docs/src/data_preparation.md diff --git a/doc/src/deepspeech_architecture.md b/docs/src/deepspeech_architecture.md similarity index 95% rename from doc/src/deepspeech_architecture.md rename to docs/src/deepspeech_architecture.md index 04c7bee79..6c4951897 100644 --- a/doc/src/deepspeech_architecture.md +++ b/docs/src/deepspeech_architecture.md @@ -1,8 +1,8 @@ # Deepspeech2 ## Streaming -The implemented arcitecure of Deepspeech2 online model is based on [Deepspeech2 model](https://arxiv.org/pdf/1512.02595.pdf) with some changes. 
-The model is mainly composed of 2D convolution subsampling layer and stacked single direction rnn layers.
+The implemented architecture of the Deepspeech2 online model is based on the [Deepspeech2 model](https://arxiv.org/pdf/1512.02595.pdf) with some changes.
+The model is mainly composed of a 2D convolution subsampling layer and stacked unidirectional RNN layers.
 To illustrate the model implementation clearly, 3 parts are described in detail.
 - Data Preparation
@@ -11,10 +11,10 @@ To illustrate the model implementation clearly, 3 parts are described in detail.
 In addition, the training process and the testing process are also introduced.
-The arcitecture of the model is shown in Fig.1.
+The architecture of the model is shown in Fig.1.

- +
Fig.1 The Architecture of deepspeech2 online model

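To make the pictured stack concrete, here is a minimal sketch of the described encoder: 2D conv subsampling feeding a forward-only (streaming-friendly) GRU stack, with a projection head whose per-frame log-probabilities go to the CTC decoder. This is an illustration only; the class name `TinyDS2Online`, the layer sizes, and the vocabulary size are assumptions, not the code in `deepspeech/models/ds2_online/deepspeech2.py`.

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class TinyDS2Online(nn.Layer):
    """Illustrative sketch: conv subsampling -> forward-only GRU -> CTC logits."""

    def __init__(self, feat_dim=161, vocab_size=4301, hidden=1024, num_rnn=5):
        super().__init__()
        # 2D conv over (time, freq); stride 2 subsamples the time axis
        self.conv = nn.Conv2D(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=1)
        conv_out = 32 * ((feat_dim + 1) // 2)
        # unidirectional rnn stack, as used by the online (streaming) model
        self.rnn = nn.GRU(conv_out, hidden, num_layers=num_rnn, direction="forward")
        # dense projection to vocabulary size, consumed by the CTC decoder
        self.proj = nn.Linear(hidden, vocab_size)

    def forward(self, x):  # x: [B, T, feat_dim]
        x = x.unsqueeze(1)                    # [B, 1, T, D]
        x = F.relu(self.conv(x))              # [B, C, T', D']
        b, c, t, d = x.shape
        x = x.transpose([0, 2, 1, 3]).reshape([b, t, c * d])
        x, _ = self.rnn(x)                    # [B, T', hidden]
        return F.log_softmax(self.proj(x), axis=-1)  # per-frame char log-probs


logits = TinyDS2Online()(paddle.randn([4, 100, 161]))
print(logits.shape)  # [4, 50, 4301]
```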
@@ -28,17 +28,17 @@ For English data, the vocabulary dictionary is composed of 26 English characters
 --unit_type="char" \
 --count_threshold=0 \
 --vocab_path="data/vocab.txt" \
- --manifest_paths "data/manifest.train.raw" "data/manifest.dev.raw"
-
+ --manifest_paths "data/manifest.train.raw" "data/manifest.dev.raw"
+
 # vocabulary for aishell dataset (Mandarin)
 vi examples/aishell/s0/data/vocab.txt
-
+
 # vocabulary for librispeech dataset (English)
 vi examples/librispeech/s0/data/vocab.txt
 ```

 #### CMVN
-For CMVN, a subset or the full of traininig set is chosed and be used to compute the feature mean and std.
+For CMVN, a subset of or the full training set is chosen and used to compute the feature mean and std.
 ```
 # The code to compute the feature mean and std
 cd examples/aishell/s0
@@ -52,16 +52,16 @@ python3 ../../../utils/compute_mean_std.py \
 --use_dB_normalization=True \
 --num_samples=2000 \
 --num_workers=10 \
- --output_path="data/mean_std.json"
+ --output_path="data/mean_std.json"
 ```
-
+
 #### Feature Extraction
 For feature extraction, three methods are implemented, which are linear (FFT without using filter bank), fbank and mfcc.
 Currently, the released deepspeech2 online model use the linear feature extraction method.
 ```
 The code for feature extraction
- vi deepspeech/frontend/featurizer/audio_featurizer.py
+ vi deepspeech/frontend/featurizer/audio_featurizer.py
 ```

 ### Encoder
 The code of Encoder is in:
 ```
 vi deepspeech/models/ds2_online/deepspeech2.py
 ```
-
+
 ### Decoder
 To got the character possibilities of each frame, the feature represention of each frame output from the encoder are input into a projection layer which is implemented as a dense layer to do feature projection. The output dim of the projection layer is same with the vocabulary size. After projection layer, the softmax function is used to transform the frame-level feature representation be the possibilities of characters. While making model inference, the character possibilities of each frame are input into the CTC decoder to get the final speech recognition results. The code of Decoder is in:
@@ -80,7 +80,7 @@ vi deepspeech/models/ds2_online/deepspeech2.py
 # The code of CTC Decoder
 vi deepspeech/modules/ctc.py
 ```
-
+
 ## Training Process
 Using the command below, you can train the deepspeech2 online model.
 ```
@@ -121,8 +121,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     avg.sh exp/${ckpt}/checkpoints ${avg_num}
 fi
 ```
+
 By using the command above, the training process can be started. There are 5 stages in "run.sh", and the first 3 stages are used for training process. The stage 0 is used for data preparation, in which the dataset will be downloaded, and the manifest files of the datasets, vocabulary dictionary and CMVN file will be generated in "./data/". The stage 1 is used for training the model, the log files and model checkpoint is saved in "exp/deepspeech2_online/". The stage 2 is used to generated final model for predicting by averaging the top-k model parameters based on validation loss.
-
+
 ## Testing Process
 Using the command below, you can test the deepspeech2 online model.
 ```
@@ -131,7 +132,7 @@ Using the command below, you can test the deepspeech2 online model.
 The detail commands are:
 ```
 conf_path=conf/deepspeech2_online.yaml
-avg_num=1
+avg_num=1
 model_type=online

 avg_ckpt=avg_${avg_num}
@@ -152,29 +153,29 @@ fi
 ```
 After the training process, we use stage 3,4,5 for testing process.
The stage 3 is for testing the model generated in the stage 2 and provided the CER index of the test set. The stage 4 is for transforming the model from dynamic graph to static graph by using "paddle.jit" library. The stage 5 is for testing the model in static graph. - + ## Non-Streaming The deepspeech2 offline model is similarity to the deepspeech2 online model. The main difference between them is the offline model use the stacked bi-directional rnn layers while the online model use the single direction rnn layers and the fc layer is not used. For the stacked bi-directional rnn layers in the offline model, the rnn cell and gru cell are provided to use. The arcitecture of the model is shown in Fig.2.

- +
Fig.2 The Architecture of deepspeech2 offline model

- + For data preparation and decoder, the deepspeech2 offline model is same with the deepspeech2 online model. The code of encoder and decoder for deepspeech2 offline model is in: ``` vi deepspeech/models/ds2/deepspeech2.py ``` - + The training process and testing process of deepspeech2 offline model is very similary to deepspeech2 online model. Only some changes should be noticed. -For training and testing, the "model_type" and the "conf_path" must be set. +For training and testing, the "model_type" and the "conf_path" must be set. ``` # Training offline cd examples/aishell/s0 @@ -185,5 +186,3 @@ bash run.sh --stage 0 --stop_stage 2 --model_type offline --conf_path conf/deeps cd examples/aishell/s0 bash run.sh --stage 3 --stop_stage 5 --model_type offline --conf_path conf/deepspeech2.yaml ``` - - diff --git a/doc/src/feature_list.md b/docs/src/feature_list.md similarity index 100% rename from doc/src/feature_list.md rename to docs/src/feature_list.md diff --git a/doc/src/getting_started.md b/docs/src/getting_started.md similarity index 100% rename from doc/src/getting_started.md rename to docs/src/getting_started.md diff --git a/doc/src/install.md b/docs/src/install.md similarity index 100% rename from doc/src/install.md rename to docs/src/install.md diff --git a/doc/src/ngram_lm.md b/docs/src/ngram_lm.md similarity index 100% rename from doc/src/ngram_lm.md rename to docs/src/ngram_lm.md diff --git a/doc/src/reference.md b/docs/src/reference.md similarity index 100% rename from doc/src/reference.md rename to docs/src/reference.md diff --git a/doc/src/released_model.md b/docs/src/released_model.md similarity index 100% rename from doc/src/released_model.md rename to docs/src/released_model.md diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 7f0a1462f..9560930ac 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -40,9 +40,12 @@ model: rnn_layer_size: 1024 use_gru: True share_rnn_weights: False + blank_id: 0 + ctc_grad_norm_type: instance training: n_epoch: 80 + accum_grad: 1 lr: 2e-3 lr_decay: 0.83 weight_decay: 1e-06 diff --git a/examples/aishell/s0/conf/deepspeech2_online.yaml b/examples/aishell/s0/conf/deepspeech2_online.yaml index fdc3a5365..7e87594cc 100644 --- a/examples/aishell/s0/conf/deepspeech2_online.yaml +++ b/examples/aishell/s0/conf/deepspeech2_online.yaml @@ -36,17 +36,20 @@ collator: model: num_conv_layers: 2 - num_rnn_layers: 3 + num_rnn_layers: 5 rnn_layer_size: 1024 rnn_direction: forward # [forward, bidirect] - num_fc_layers: 1 - fc_layers_size_list: 512, + num_fc_layers: 0 + fc_layers_size_list: -1, use_gru: False - + blank_id: 0 + ctc_grad_norm_type: instance + training: n_epoch: 50 + accum_grad: 1 lr: 2e-3 - lr_decay: 0.91 # 0.83 + lr_decay: 0.9 # 0.83 weight_decay: 1e-06 global_grad_clip: 3.0 log_interval: 100 @@ -59,7 +62,7 @@ decoding: error_rate_type: cer decoding_method: ctc_beam_search lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm - alpha: 1.9 + alpha: 2.2 #1.9 beta: 5.0 beam_size: 300 cutoff_prob: 0.99 diff --git a/examples/aishell/s0/local/client.sh b/examples/aishell/s0/local/client.sh deleted file mode 100755 index 3b59ad3df..000000000 --- a/examples/aishell/s0/local/client.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -source path.sh - -# run on MacOS -# brew install portaudio -# pip install pyaudio -# pip install keyboard - -# start demo client -python3 -u ${BIN_DIR}/deploy/client.py \ ---host_ip="localhost" \ ---host_port=8086 \ 
- -if [ $? -ne 0 ]; then - echo "Failed in starting demo client!" - exit 1 -fi - -exit 0 diff --git a/examples/aishell/s0/local/server.sh b/examples/aishell/s0/local/server.sh deleted file mode 100755 index 2b8810993..000000000 --- a/examples/aishell/s0/local/server.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/bin/bash -# TODO: replace the model with a mandarin model - -if [[ $# != 1 ]];then - echo "usage: $1 checkpoint_path" - exit -1 -fi - -source path.sh - -# download language model -bash local/download_lm_ch.sh -if [ $? -ne 0 ]; then - exit 1 -fi - -# download well-trained model -#bash local/download_model.sh -#if [ $? -ne 0 ]; then -# exit 1 -#fi - -# start demo server -CUDA_VISIBLE_DEVICES=0 \ -python3 -u ${BIN_DIR}/deploy/server.py \ ---device 'gpu' \ ---nproc 1 \ ---config conf/deepspeech2.yaml \ ---host_ip="localhost" \ ---host_port=8086 \ ---speech_save_dir="demo_cache" \ ---checkpoint_path ${1} - -if [ $? -ne 0 ]; then - echo "Failed in starting demo server!" - exit 1 -fi - - -exit 0 diff --git a/examples/aishell/s0/local/train.sh b/examples/aishell/s0/local/train.sh index 3438a7357..85d1d42c3 100755 --- a/examples/aishell/s0/local/train.sh +++ b/examples/aishell/s0/local/train.sh @@ -20,7 +20,7 @@ fi mkdir -p exp seed=10086 -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi @@ -32,7 +32,7 @@ python3 -u ${BIN_DIR}/train.py \ --model_type ${model_type} \ --seed ${seed} -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/aishell/s0/local/tune.sh b/examples/aishell/s0/local/tune.sh deleted file mode 100755 index 59406cd5b..000000000 --- a/examples/aishell/s0/local/tune.sh +++ /dev/null @@ -1,28 +0,0 @@ -#!/bin/bash - -# grid-search for hyper-parameters in language model -python3 -u ${BIN_DIR}/tune.py \ ---device 'gpu' \ ---nproc 1 \ ---config conf/deepspeech2.yaml \ ---num_batches=10 \ ---batch_size=128 \ ---beam_size=300 \ ---num_proc_bsearch=8 \ ---num_alphas=10 \ ---num_betas=10 \ ---alpha_from=0.0 \ ---alpha_to=5.0 \ ---beta_from=-6 \ ---beta_to=6 \ ---cutoff_prob=1.0 \ ---cutoff_top_n=40 \ ---checkpoint_path ${1} - -if [ $? -ne 0 ]; then - echo "Failed in tuning!" 
-    exit 1
-fi
-
-
-exit 0
diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh
index e5ab12a59..71191c3ac 100755
--- a/examples/aishell/s0/run.sh
+++ b/examples/aishell/s0/run.sh
@@ -27,7 +27,7 @@ fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     # avg n best model
-    avg.sh exp/${ckpt}/checkpoints ${avg_num}
+    avg.sh best exp/${ckpt}/checkpoints ${avg_num}
 fi

 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
diff --git a/examples/aishell/s1/conf/chunk_conformer.yaml b/examples/aishell/s1/conf/chunk_conformer.yaml
index 3e606788e..6f8ae135f 100644
--- a/examples/aishell/s1/conf/chunk_conformer.yaml
+++ b/examples/aishell/s1/conf/chunk_conformer.yaml
@@ -76,6 +76,8 @@ model:
 # hybrid CTC/attention
 model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
diff --git a/examples/aishell/s1/conf/conformer.yaml b/examples/aishell/s1/conf/conformer.yaml
index 4b1430c58..a4248459c 100644
--- a/examples/aishell/s1/conf/conformer.yaml
+++ b/examples/aishell/s1/conf/conformer.yaml
@@ -71,6 +71,8 @@ model:
 # hybrid CTC/attention
 model_conf:
     ctc_weight: 0.3
+    ctc_dropoutrate: 0.0
+    ctc_grad_norm_type: instance
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
diff --git a/examples/aishell/s1/local/train.sh b/examples/aishell/s1/local/train.sh
index ec17054ab..2861e11ec 100755
--- a/examples/aishell/s1/local/train.sh
+++ b/examples/aishell/s1/local/train.sh
@@ -19,8 +19,8 @@ echo "using ${device}..."

 mkdir -p exp

-seed=1024
-if [ ${seed} ]; then
+seed=10086
+if [ ${seed} != 0 ]; then
     export FLAGS_cudnn_deterministic=True
 fi

@@ -31,7 +31,7 @@ python3 -u ${BIN_DIR}/train.py \
 --output exp/${ckpt_name} \
 --seed ${seed}

-if [ ${seed} ]; then
+if [ ${seed} != 0 ]; then
     unset FLAGS_cudnn_deterministic
 fi
diff --git a/examples/aishell/s1/run.sh b/examples/aishell/s1/run.sh
index d55d47ea6..e3c008234 100644
--- a/examples/aishell/s1/run.sh
+++ b/examples/aishell/s1/run.sh
@@ -25,7 +25,7 @@ fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     # avg n best model
-    avg.sh exp/${ckpt}/checkpoints ${avg_num}
+    avg.sh best exp/${ckpt}/checkpoints ${avg_num}
 fi

 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
diff --git a/examples/callcenter/s1/local/train.sh b/examples/callcenter/s1/local/train.sh
index 928c6492c..6e63df83a 100755
--- a/examples/callcenter/s1/local/train.sh
+++ b/examples/callcenter/s1/local/train.sh
@@ -19,8 +19,8 @@ echo "using ${device}..."

 mkdir -p exp

-seed=1024
-if [ ${seed} ]; then
+seed=10086
+if [ ${seed} != 0 ]; then
     export FLAGS_cudnn_deterministic=True
 fi

@@ -31,7 +31,7 @@ python3 -u ${BIN_DIR}/train.py \
 --output exp/${ckpt_name} \
 --seed ${seed}

-if [ ${seed} ]; then
+if [ ${seed} != 0 ]; then
     unset FLAGS_cudnn_deterministic
 fi
diff --git a/examples/callcenter/s1/run.sh b/examples/callcenter/s1/run.sh
index 52dd44eca..305021f19 100644
--- a/examples/callcenter/s1/run.sh
+++ b/examples/callcenter/s1/run.sh
@@ -25,7 +25,7 @@ fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     # avg n best model
-    avg.sh exp/${ckpt}/checkpoints ${avg_num}
+    avg.sh best exp/${ckpt}/checkpoints ${avg_num}
 fi

 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
diff --git a/examples/cc-cedict/README.md b/examples/cc-cedict/README.md
index e69de29bb..513fca533 100644
--- a/examples/cc-cedict/README.md
+++ b/examples/cc-cedict/README.md
@@ -0,0 +1,58 @@
+# [CC-CEDICT](https://cc-cedict.org/wiki/)
+
+What is CC-CEDICT?
+CC-CEDICT is a continuation of the CEDICT project. +The objective of the CEDICT project was to create an online, downloadable (as opposed to searchable-only) public-domain Chinese-English dictionary. +CEDICT was started by Paul Andrew Denisowski in October 1997. +For the most part, the project is modeled on Jim Breen's highly successful EDICT (Japanese-English dictionary) project and is intended to be a collaborative effort, +with users providing entries and corrections to the main file. + + +## Parse CC-CEDICT to Json format + +1. Parse to Json + +``` +run.sh +``` + +2. Result + +``` +exp/ +|-- cedict +`-- cedict.json + +0 directories, 2 files +``` + +``` +4c4bffc84e24467fe1b2ea9ba37ed6b6 exp/cedict +3adf504dacd13886f88cc9fe3b37c75d exp/cedict.json +``` + +``` +==> exp/cedict <== +# CC-CEDICT +# Community maintained free Chinese-English dictionary. +# +# Published by MDBG +# +# License: +# Creative Commons Attribution-ShareAlike 4.0 International License +# https://creativecommons.org/licenses/by-sa/4.0/ +# +# Referenced works: + +==> exp/cedict.json <== +{"traditional": "2019\u51a0\u72c0\u75c5\u6bd2\u75c5", "simplified": "2019\u51a0\u72b6\u75c5\u6bd2\u75c5", "pinyin": "er4 ling2 yi1 jiu3 guan1 zhuang4 bing4 du2 bing4", "english": "COVID-19, the coronavirus disease identified in 2019"} +{"traditional": "21\u4e09\u9ad4\u7d9c\u5408\u75c7", "simplified": "21\u4e09\u4f53\u7efc\u5408\u75c7", "pinyin": "er4 shi2 yi1 san1 ti3 zong1 he2 zheng4", "english": "trisomy"} +{"traditional": "3C", "simplified": "3C", "pinyin": "san1 C", "english": "abbr. for computers, communications, and consumer electronics"} +{"traditional": "3P", "simplified": "3P", "pinyin": "san1 P", "english": "(slang) threesome"} +{"traditional": "3Q", "simplified": "3Q", "pinyin": "san1 Q", "english": "(Internet slang) thank you (loanword)"} +{"traditional": "421", "simplified": "421", "pinyin": "si4 er4 yi1", "english": "four grandparents, two parents and an only child"} +{"traditional": "502\u81a0", "simplified": "502\u80f6", "pinyin": "wu3 ling2 er4 jiao1", "english": "cyanoacrylate glue"} +{"traditional": "88", "simplified": "88", "pinyin": "ba1 ba1", "english": "(Internet slang) bye-bye (alternative for \u62dc\u62dc[bai2 bai2])"} +{"traditional": "996", "simplified": "996", "pinyin": "jiu3 jiu3 liu4", "english": "9am-9pm, six days a week (work schedule)"} +{"traditional": "A", "simplified": "A", "pinyin": "A", "english": "(slang) (Tw) to steal"} +``` diff --git a/examples/chinese_g2p/README.md b/examples/chinese_g2p/README.md deleted file mode 100644 index e3fdfe684..000000000 --- a/examples/chinese_g2p/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Download Baker dataset - -Baker dataset has to be downloaded mannually and moved to 'data/', because you will have to pass the CATTCHA from a browswe to download the dataset. - -Download URL https://test.data-baker.com/#/data/index/source. 
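Returning to the CC-CEDICT conversion above: the `cedict.json` records follow directly from the standard CC-CEDICT line format, `Traditional Simplified [pin1 yin1] /gloss/.../`. The snippet below is a minimal, hypothetical parser sketch for that format; it is not the converter that `run.sh` invokes, and the helper name `parse_entry` is invented for illustration. The output field names mirror the JSON shown above.

```python
import json
import re

# Standard CC-CEDICT entry: "Traditional Simplified [pin1 yin1] /gloss1/gloss2/"
LINE_RE = re.compile(r'^(\S+)\s+(\S+)\s+\[([^\]]+)\]\s+/(.+)/$')


def parse_entry(line):
    if line.startswith('#'):
        return None  # header/comment lines
    m = LINE_RE.match(line.strip())
    if m is None:
        return None  # malformed line
    traditional, simplified, pinyin, english = m.groups()
    return {
        "traditional": traditional,
        "simplified": simplified,
        "pinyin": pinyin,
        "english": english.split('/')[0],  # first gloss, as in cedict.json above
    }


print(json.dumps(parse_entry("3Q 3Q [san1 Q] /(Internet slang) thank you (loanword)/")))
```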
diff --git a/examples/chinese_g2p/.gitignore b/examples/g2p/.gitignore
similarity index 100%
rename from examples/chinese_g2p/.gitignore
rename to examples/g2p/.gitignore
diff --git a/examples/g2p/README.md b/examples/g2p/README.md
new file mode 100644
index 000000000..4ec5922b3
--- /dev/null
+++ b/examples/g2p/README.md
@@ -0,0 +1,3 @@
+# G2P
+
+* zh - Chinese G2P
diff --git a/examples/g2p/zh/README.md b/examples/g2p/zh/README.md
new file mode 100644
index 000000000..de5573565
--- /dev/null
+++ b/examples/g2p/zh/README.md
@@ -0,0 +1,90 @@
+# G2P
+
+* WS: jieba
+* G2P: pypinyin
+* Tone sandhi: simple
+
+We recommend using [Parakeet](https://github.com/PaddlePaddle/Parakeet) [TextFrontEnd](https://github.com/PaddlePaddle/Parakeet/blob/develop/parakeet/frontend/__init__.py) to do G2P.
+The phoneme set should be changed; you can reference `examples/thchs30/a0/data/dict/syllable.lexicon`.
+
+## Download Baker dataset
+
+[Baker](https://test.data-baker.com/#/data/index/source) dataset has to be downloaded manually and moved to './data',
+because you will have to pass the `CAPTCHA` from a browser to download the dataset.
+
+
+## RUN
+
+```
+. path.sh
+./run.sh
+```
+
+## Result
+
+```
+exp/
+|-- 000001-010000.txt
+|-- ref.pinyin
+|-- trans.jieba.pinyin
+`-- trans.pinyin
+
+0 directories, 4 files
+```
+
+```
+4f5a368441eb16aaf43dc1972f8b63dd exp/000001-010000.txt
+01707896391c2de9b6fc4a39654be942 exp/ref.pinyin
+43380ef160f65a23a3a0544700aa49b8 exp/trans.jieba.pinyin
+8e6ff1fc22d8e8584082e804e8bcdeb7 exp/trans.pinyin
+```
+
+```
+==> exp/000001-010000.txt <==
+000001 卡尔普#2陪外孙#1玩滑梯#4。
+	ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1
+000002 假语村言#2别再#1拥抱我#4。
+	jia2 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
+000003 宝马#1配挂#1跛骡鞍#3,貂蝉#1怨枕#2董翁榻#4。
+	bao2 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
+000004 邓小平#2与#1撒切尔#2会晤#4。
+	deng4 xiao3 ping2 yu3 sa4 qie4 er3 hui4 wu4
+000005 老虎#1幼崽#2与#1宠物犬#1玩耍#4。
+	lao2 hu3 you4 zai3 yu2 chong3 wu4 quan3 wan2 shua3
+
+==> exp/ref.pinyin <==
+000001 ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1
+000002 jia2 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
+000003 bao2 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
+000004 deng4 xiao3 ping2 yu3 sa4 qie4 er3 hui4 wu4
+000005 lao2 hu3 you4 zai3 yu2 chong3 wu4 quan3 wan2 shua3
+000006 shen1 chang2 yue1 wu2 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4
+000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1
+000008 zhan2 pin3 sui1 you3 zhan3 yuan2 que4 tui2
+000009 yi2 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3
+000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1
+
+==> exp/trans.jieba.pinyin <==
+000001 ka3 er3 pu3 pei2 wai4 sun1 wan2 hua2 ti1
+000002 jia3 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
+000003 bao3 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
+000004 deng4 xiao3 ping2 yu3 sa1 qie4 er3 hui4 wu4
+000005 lao3 hu3 you4 zai3 yu3 chong3 wu4 quan3 wan2 shua3
+000006 shen1 chang2 yue1 wu3 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4
+000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1
+000008 zhan3 pin3 sui1 you3 zhan3 yuan2 que4 tui2
+000009 yi3 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3
+000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1
+
+==> exp/trans.pinyin <==
+000001 ka3 er3 pu3 pei2 wai4 sun1 wan2 hua2 ti1
+000002 jia3 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3
+000003 bao3 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4
+000004 deng4 xiao3 ping2 yu3 sa1 qie4 er3 hui4 wu4
+000005 lao3 hu3 you4 zai3 yu3 chong3 wu4 quan3 wan2 shua3
+000006 shen1 chang2 yue1 wu3 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4
+000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1
+000008 zhan3 pin3 sui1 you3 zhan3 yuan2 que4 tui2
+000009 yi3 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3
+000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1
+```
diff --git a/examples/chinese_g2p/local/convert_transcription.py b/examples/g2p/zh/local/convert_transcription.py
similarity index 100%
rename from examples/chinese_g2p/local/convert_transcription.py
rename to examples/g2p/zh/local/convert_transcription.py
diff --git a/examples/chinese_g2p/local/extract_pinyin_label.py b/examples/g2p/zh/local/extract_pinyin_label.py
similarity index 100%
rename from examples/chinese_g2p/local/extract_pinyin_label.py
rename to examples/g2p/zh/local/extract_pinyin_label.py
diff --git a/examples/chinese_g2p/local/ignore_sandhi.py b/examples/g2p/zh/local/ignore_sandhi.py
similarity index 100%
rename from examples/chinese_g2p/local/ignore_sandhi.py
rename to examples/g2p/zh/local/ignore_sandhi.py
diff --git a/examples/chinese_g2p/local/prepare_dataset.sh b/examples/g2p/zh/local/prepare_dataset.sh
similarity index 100%
rename from examples/chinese_g2p/local/prepare_dataset.sh
rename to examples/g2p/zh/local/prepare_dataset.sh
diff --git a/examples/chinese_g2p/path.sh b/examples/g2p/zh/path.sh
similarity index 82%
rename from examples/chinese_g2p/path.sh
rename to examples/g2p/zh/path.sh
index 482177dc6..f475ed833 100644
--- a/examples/chinese_g2p/path.sh
+++ b/examples/g2p/zh/path.sh
@@ -1,4 +1,4 @@
-export MAIN_ROOT=`realpath ${PWD}/../../`
+export MAIN_ROOT=`realpath ${PWD}/../../../`

 export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
 export LC_ALL=C
diff --git a/examples/chinese_g2p/requirements.txt b/examples/g2p/zh/requirements.txt
similarity index 100%
rename from examples/chinese_g2p/requirements.txt
rename to examples/g2p/zh/requirements.txt
diff --git a/examples/chinese_g2p/run.sh b/examples/g2p/zh/run.sh
similarity index 82%
rename from examples/chinese_g2p/run.sh
rename to examples/g2p/zh/run.sh
index 8197dce4b..25b713110 100755
--- a/examples/chinese_g2p/run.sh
+++ b/examples/g2p/zh/run.sh
@@ -6,16 +6,19 @@ stage=-1
 stop_stage=100

 exp_dir=exp
-data_dir=data
+data=data

 source ${MAIN_ROOT}/utils/parse_options.sh || exit -1

 mkdir -p ${exp_dir}

+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ];then
+    test -e ${data}/BZNSYP.rar || { echo "Please download BZNSYP.rar and put it in ${data}"; exit -1; }
+fi

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ];then
     echo "stage 0: Extracting Prosody Labeling"
-    bash local/prepare_dataset.sh --exp-dir ${exp_dir} --data-dir ${data_dir}
+    bash local/prepare_dataset.sh --exp-dir ${exp_dir} --data-dir ${data}
 fi

 # convert transcription in chinese into pinyin with pypinyin or jieba+pypinyin
diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md
index 5603d3c8a..11bcf5f65 100644
--- a/examples/librispeech/s0/README.md
+++ b/examples/librispeech/s0/README.md
@@ -1,10 +1,17 @@
 # LibriSpeech

+## Data
+| Data Subset | Duration in Seconds |
+| --- | --- |
+| data/manifest.train | 0.83s ~ 29.735s |
+| data/manifest.dev | 1.065s ~ 35.155s |
+| data/manifest.test-clean | 1.285s ~ 34.955s |
+
 ## Deepspeech2

 | Model | Params | release | Config | Test set | Loss | WER |
 | --- | --- | --- | --- | --- | --- | --- |
-| DeepSpeech2 | 42.96M | 2.2.0 | conf/deepspeech2.yaml + spec_aug | 14.49190807 | test-clean | 0.067283 |
-| DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | 15.184467315673828 | test-clean | 0.072154 |
-| DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | - | test-clean | 0.073973 |
+| DeepSpeech2 | 42.96M | 2.2.0 | conf/deepspeech2.yaml + spec_aug | test-clean | 14.49190807 | 0.067283 |
+| DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | test-clean | 15.184467315673828 | 0.072154 |
+| DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | test-clean | - | 0.073973 |
 | DeepSpeech2 | 42.96M | 1.8.5 | - | test-clean | - | 0.074939 |
diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml
index dab8d0462..d5b1ed919 100644
--- a/examples/librispeech/s0/conf/deepspeech2.yaml
+++ b/examples/librispeech/s0/conf/deepspeech2.yaml
@@ -4,14 +4,14 @@ data:
   dev_manifest: data/manifest.dev-clean
   test_manifest: data/manifest.test-clean
   min_input_len: 0.0
-  max_input_len: 27.0 # second
+  max_input_len: 30.0 # second
   min_output_len: 0.0
   max_output_len: .inf
   min_output_input_ratio: 0.00
   max_output_input_ratio: .inf

 collator:
-  batch_size: 20
+  batch_size: 15
   mean_std_filepath: data/mean_std.json
   unit_type: char
   vocab_filepath: data/vocab.txt
@@ -40,9 +40,12 @@ model:
     rnn_layer_size: 2048
     use_gru: False
     share_rnn_weights: True
+    blank_id: 0
+    ctc_grad_norm_type: instance

 training:
   n_epoch: 50
+  accum_grad: 4
   lr: 1e-3
   lr_decay: 0.83
   weight_decay: 1e-06
diff --git a/examples/librispeech/s0/conf/deepspeech2_online.yaml b/examples/librispeech/s0/conf/deepspeech2_online.yaml
index 2e4aed40a..180a6205f 100644
--- a/examples/librispeech/s0/conf/deepspeech2_online.yaml
+++ b/examples/librispeech/s0/conf/deepspeech2_online.yaml
@@ -4,14 +4,14 @@ data:
   dev_manifest: data/manifest.dev-clean
   test_manifest: data/manifest.test-clean
   min_input_len: 0.0
-  max_input_len: 27.0 # second
+  max_input_len: 30.0 # second
   min_output_len: 0.0
   max_output_len: .inf
   min_output_input_ratio: 0.00
   max_output_input_ratio: .inf

 collator:
-  batch_size: 20
+  batch_size: 15
   mean_std_filepath: data/mean_std.json
   unit_type: char
   vocab_filepath: data/vocab.txt
@@ -42,9 +42,12 @@ model:
     num_fc_layers: 2
     fc_layers_size_list: 512, 256
     use_gru: False
+    blank_id: 0
+    ctc_grad_norm_type: instance

 training:
   n_epoch: 50
+  accum_grad: 4
   lr: 1e-3
   lr_decay: 0.83
   weight_decay: 1e-06
diff --git a/examples/librispeech/s0/local/train.sh b/examples/librispeech/s0/local/train.sh
index dcd21df34..c95659acf 100755
--- a/examples/librispeech/s0/local/train.sh
+++ b/examples/librispeech/s0/local/train.sh
@@ -20,8 +20,8 @@ echo "using ${device}..."

 mkdir -p exp

-seed=1024
-if [ ${seed} ]; then
+seed=10086
+if [ ${seed} != 0 ]; then
     export FLAGS_cudnn_deterministic=True
 fi

@@ -33,7 +33,7 @@ python3 -u ${BIN_DIR}/train.py \
 --model_type ${model_type} \
 --seed ${seed}

-if [ ${seed} ]; then
+if [ ${seed} != 0 ]; then
     unset FLAGS_cudnn_deterministic
 fi
diff --git a/examples/librispeech/s0/local/tune.sh b/examples/librispeech/s0/local/tune.sh
deleted file mode 100755
index c344e77e5..000000000
--- a/examples/librispeech/s0/local/tune.sh
+++ /dev/null
@@ -1,33 +0,0 @@
-#!/bin/bash
-
-if [ $# != 1 ];then
-    echo "usage: tune ckpt_path"
-    exit 1
-fi
-
-# grid-search for hyper-parameters in language model
-python3 -u ${BIN_DIR}/tune.py \
---device 'gpu' \
---nproc 1 \
---config conf/deepspeech2.yaml \
---num_batches=-1 \
---batch_size=128 \
---beam_size=500 \
---num_proc_bsearch=12 \
---num_alphas=45 \
---num_betas=8 \
---alpha_from=1.0 \
---alpha_to=3.2 \
---beta_from=0.1 \
---beta_to=0.45 \
---cutoff_prob=1.0 \
---cutoff_top_n=40 \
---checkpoint_path ${1}
-
-if [ $? -ne 0 ]; then
-    echo "Failed in tuning!"
- exit 1 -fi - - -exit 0 diff --git a/examples/librispeech/s0/run.sh b/examples/librispeech/s0/run.sh index c7902a56a..af47fb9b8 100755 --- a/examples/librispeech/s0/run.sh +++ b/examples/librispeech/s0/run.sh @@ -25,7 +25,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/librispeech/s1/conf/chunk_conformer.yaml b/examples/librispeech/s1/conf/chunk_conformer.yaml index 0de1aefee..92db20f66 100644 --- a/examples/librispeech/s1/conf/chunk_conformer.yaml +++ b/examples/librispeech/s1/conf/chunk_conformer.yaml @@ -76,6 +76,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s1/conf/chunk_transformer.yaml b/examples/librispeech/s1/conf/chunk_transformer.yaml index f782a0373..e0bc3135e 100644 --- a/examples/librispeech/s1/conf/chunk_transformer.yaml +++ b/examples/librispeech/s1/conf/chunk_transformer.yaml @@ -69,6 +69,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s1/conf/conformer.yaml b/examples/librispeech/s1/conf/conformer.yaml index 6d825f05b..78be249cb 100644 --- a/examples/librispeech/s1/conf/conformer.yaml +++ b/examples/librispeech/s1/conf/conformer.yaml @@ -72,6 +72,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s1/conf/transformer.yaml b/examples/librispeech/s1/conf/transformer.yaml index bc2ec6061..4aa7b9158 100644 --- a/examples/librispeech/s1/conf/transformer.yaml +++ b/examples/librispeech/s1/conf/transformer.yaml @@ -33,7 +33,7 @@ collator: keep_transcription_text: False sortagrad: True shuffle_method: batch_shuffle - num_workers: 2 + num_workers: 0 # network architecture @@ -67,6 +67,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s1/local/train.sh b/examples/librispeech/s1/local/train.sh index ec17054ab..17a9e28df 100755 --- a/examples/librispeech/s1/local/train.sh +++ b/examples/librispeech/s1/local/train.sh @@ -19,8 +19,8 @@ echo "using ${device}..." 
mkdir -p exp -seed=1024 -if [ ${seed} ]; then +seed=10086 +if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi @@ -31,7 +31,7 @@ python3 -u ${BIN_DIR}/train.py \ --output exp/${ckpt_name} \ --seed ${seed} -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/librispeech/s1/run.sh b/examples/librispeech/s1/run.sh index def10ab05..aecd3f617 100755 --- a/examples/librispeech/s1/run.sh +++ b/examples/librispeech/s1/run.sh @@ -24,7 +24,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/librispeech/s2/conf/chunk_conformer.yaml b/examples/librispeech/s2/conf/chunk_conformer.yaml index 0de1aefee..92db20f66 100644 --- a/examples/librispeech/s2/conf/chunk_conformer.yaml +++ b/examples/librispeech/s2/conf/chunk_conformer.yaml @@ -76,6 +76,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s2/conf/chunk_transformer.yaml b/examples/librispeech/s2/conf/chunk_transformer.yaml index f782a0373..e0bc3135e 100644 --- a/examples/librispeech/s2/conf/chunk_transformer.yaml +++ b/examples/librispeech/s2/conf/chunk_transformer.yaml @@ -69,6 +69,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s2/conf/conformer.yaml b/examples/librispeech/s2/conf/conformer.yaml index 955b6108b..9a7274135 100644 --- a/examples/librispeech/s2/conf/conformer.yaml +++ b/examples/librispeech/s2/conf/conformer.yaml @@ -72,6 +72,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s2/conf/transformer.yaml b/examples/librispeech/s2/conf/transformer.yaml index f7c27d1f7..edf5b81dc 100644 --- a/examples/librispeech/s2/conf/transformer.yaml +++ b/examples/librispeech/s2/conf/transformer.yaml @@ -22,7 +22,7 @@ collator: batch_frames_out: 0 batch_frames_inout: 0 augmentation_config: conf/augmentation.json - num_workers: 2 + num_workers: 0 subsampling_factor: 1 num_encs: 1 @@ -58,6 +58,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/librispeech/s2/local/train.sh b/examples/librispeech/s2/local/train.sh index c75252594..a75e2bb26 100755 --- a/examples/librispeech/s2/local/train.sh +++ b/examples/librispeech/s2/local/train.sh @@ -19,8 +19,8 @@ echo "using ${device}..."
mkdir -p exp -seed=1024 -if [ ${seed} ]; then +seed=10086 +if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi @@ -32,7 +32,7 @@ python3 -u ${BIN_DIR}/train.py \ --output exp/${ckpt_name} \ --seed ${seed} -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/librispeech/s2/run.sh b/examples/librispeech/s2/run.sh index 26398dd14..46c8ea5d8 100755 --- a/examples/librispeech/s2/run.sh +++ b/examples/librispeech/s2/run.sh @@ -25,7 +25,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/ngram_lm/READEME.md b/examples/ngram_lm/READEME.md new file mode 100644 index 000000000..84e1380c3 --- /dev/null +++ b/examples/ngram_lm/READEME.md @@ -0,0 +1,3 @@ +# Ngram LM + +* s0 - kenlm ngram lm diff --git a/examples/ngram_lm/s0/.gitignore b/examples/ngram_lm/s0/.gitignore new file mode 100644 index 000000000..b20d93aa5 --- /dev/null +++ b/examples/ngram_lm/s0/.gitignore @@ -0,0 +1 @@ +data/lm diff --git a/examples/ngram_lm/s0/README.md b/examples/ngram_lm/s0/README.md index 698d7c290..65916ec54 100644 --- a/examples/ngram_lm/s0/README.md +++ b/examples/ngram_lm/s0/README.md @@ -2,6 +2,95 @@ Train chinese chararctor ngram lm by [kenlm](https://github.com/kpu/kenlm). +## Run ``` +. path.sh bash run.sh ``` + +## Results + +``` +exp/ +|-- text +|-- text.char.tn +|-- text.word.tn +|-- text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa +|-- text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa.klm.bin +|-- text_zh_word_o3_p0_0_0_a22_q8_b8.arpa +`-- text_zh_word_o3_p0_0_0_a22_q8_b8.arpa.klm.bin + +0 directories, 7 files +``` + +``` +3ae083627b9b6cef1a82d574d8483f97 exp/text +d97da252d2a63a662af22f98af30cb8c exp/text.char.tn +c18b03005bd094dbfd9b46442be361fd exp/text.word.tn +73dbf50097896eda33985e11e1ba9a3a exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa +01334e2044c474b99c4f2ffbed790626 exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa.klm.bin +36a42de548045b54662411ae7982c77f exp/text_zh_word_o3_p0_0_0_a22_q8_b8.arpa +332422803ffd73dd7ffd16cd2b0abcd5 exp/text_zh_word_o3_p0_0_0_a22_q8_b8.arpa.klm.bin +``` + +``` +==> exp/text <== +少先队员因该为老人让坐 +祛痘印可以吗?有效果吗? +不知这款牛奶口感怎样? 小孩子喝行吗! +是转基因油? +我家宝宝13斤用多大码的 +会起坨吗? +请问给送上楼吗? +亲是送赁上门吗 +送货时候有外包装没有还是直接发货过来 +会不会有坏的? + +==> exp/text.char.tn <== +少 先 队 员 因 该 为 老 人 让 坐 +祛 痘 印 可 以 吗 有 效 果 吗 +不 知 这 款 牛 奶 口 感 怎 样 小 孩 子 喝 行 吗 +是 转 基 因 油 +我 家 宝 宝 十 三 斤 用 多 大 码 的 +会 起 坨 吗 +请 问 给 送 上 楼 吗 +亲 是 送 赁 上 门 吗 +送 货 时 候 有 外 包 装 没 有 还 是 直 接 发 货 过 来 +会 不 会 有 坏 的 + +==> exp/text.word.tn <== +少先队员 因该 为 老人 让 坐 +祛痘 印 可以 吗 有 效果 吗 +不知 这 款 牛奶 口感 怎样 小孩子 喝行 吗 +是 转基因 油 +我家 宝宝 十三斤 用多大码 的 +会起 坨 吗 +请问 给 送 上楼 吗 +亲是 送赁 上门 吗 +送货 时候 有 外包装 没有 还是 直接 发货 过来 +会 不会 有坏 的 + +==> exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa <== +\data\ +ngram 1=587 +ngram 2=395 +ngram 3=100 +ngram 4=2 +ngram 5=0 + +\1-grams: +-3.272324 <unk> 0 +0 <s> -0.36706257 + +==> exp/text_zh_word_o3_p0_0_0_a22_q8_b8.arpa <== +\data\ +ngram 1=689 +ngram 2=1398 +ngram 3=1506 + +\1-grams: +-3.1755018 <unk> 0 +0 <s> -0.23069073 +-1.2318869 </s> 0 +-3.067262 少先队员 -0.051341705 +``` diff --git a/examples/punctuation_restoration/README.md b/examples/punctuation_restoration/README.md index f2ca76996..42ae0db3a 100644 --- a/examples/punctuation_restoration/README.md +++ b/examples/punctuation_restoration/README.md @@ -1,3 +1,3 @@ # Punctation Restoration -Please using `https://github.com/745165806/PaddleSpeechTask` to do this task.
+Please use [PaddleSpeechTask](https://github.com/745165806/PaddleSpeechTask) to do this task. diff --git a/examples/spm/README.md b/examples/spm/README.md index 3109d3ffb..fc4478ebb 100644 --- a/examples/spm/README.md +++ b/examples/spm/README.md @@ -1,7 +1,96 @@ # [SentencePiece Model](https://github.com/google/sentencepiece) +## Run Train a `spm` model for English tokenizer. ``` +. path.sh bash run.sh ``` + +## Results + +``` +data/ +└── lang_char + ├── input.bpe + ├── input.decode + ├── input.txt + ├── train_unigram100.model + ├── train_unigram100_units.txt + └── train_unigram100.vocab + +1 directory, 6 files +``` + +``` +b5a230c26c61db5c36f34e503102f936 data/lang_char/input.bpe +ec5a9b24acc35469229e41256ceaf77d data/lang_char/input.decode +ec5a9b24acc35469229e41256ceaf77d data/lang_char/input.txt +124bf3fe7ce3b73b1994234c15268577 data/lang_char/train_unigram100.model +0df2488cc8eaace95eb12713facb5cf0 data/lang_char/train_unigram100_units.txt +46360cac35c751310e8e8ffd3a034cb5 data/lang_char/train_unigram100.vocab +``` + +``` +==> data/lang_char/input.bpe <== +▁mi ster ▁quilter ▁ is ▁the ▁a p ost le ▁o f ▁the ▁mi d d le ▁c las s es ▁ and ▁we ▁ar e ▁g l a d ▁ to ▁we l c om e ▁h is ▁g o s pe l +▁ n or ▁ is ▁mi ster ▁quilter ' s ▁ma nne r ▁ l ess ▁in ter es t ing ▁tha n ▁h is ▁ma t ter +▁h e ▁ t e ll s ▁us ▁tha t ▁ at ▁ t h is ▁f es t ive ▁ s e ason ▁o f ▁the ▁ y e ar ▁w ith ▁ ch r is t m a s ▁ and ▁ro a s t ▁be e f ▁ l o om ing ▁be fore ▁us ▁ s i mile s ▁d r a w n ▁f r om ▁ e at ing ▁ and ▁it s ▁re s u l t s ▁o c c ur ▁m ost ▁re a di l y ▁ to ▁the ▁ mind +▁h e ▁ ha s ▁g r a v e ▁d o u b t s ▁w h e t h er ▁ s i r ▁f r e d er ic k ▁ l eig h to n ' s ▁w or k ▁ is ▁re all y ▁gre e k ▁a f ter ▁ all ▁ and ▁c a n ▁di s c o v er ▁in ▁it ▁b u t ▁li t t le ▁o f ▁ro ck y ▁it ha c a +▁li nne ll ' s ▁ p ic tur es ▁ar e ▁a ▁ s or t ▁o f ▁ u p ▁g u ar d s ▁ and ▁ at ▁ em ▁painting s ▁ and ▁m ason ' s ▁ e x q u is i t e ▁ i d y ll s ▁ar e ▁a s ▁ n at ion a l ▁a s ▁a ▁ j ing o ▁ p o em ▁mi ster ▁b i r k e t ▁f o ster ' s ▁ l and s c a pe s ▁ s mile ▁ at ▁on e ▁m u ch ▁in ▁the ▁ s a m e ▁w a y ▁tha t ▁mi ster ▁c ar k er ▁us e d ▁ to ▁f las h ▁h is ▁ t e e t h ▁ and ▁mi ster ▁ j o h n ▁c o ll i er ▁g ive s ▁h is ▁ s i t ter ▁a ▁ ch e er f u l ▁ s l a p ▁on ▁the ▁b a ck ▁be fore ▁h +e ▁ s a y s ▁li k e ▁a ▁ s ha m p o o er ▁in ▁a ▁ tur k is h ▁b at h ▁ n e x t ▁ma n +▁it ▁ is ▁o b v i o u s l y ▁ u nne c ess ar y ▁for ▁us ▁ to ▁ p o i n t ▁o u t ▁h o w ▁ l u m i n o u s ▁the s e ▁c rit ic is m s ▁ar e ▁h o w ▁d e l ic at e ▁in ▁ e x p r ess ion +▁on ▁the ▁g e n er a l ▁ p r i n c i p l es ▁o f ▁ar t ▁mi ster ▁quilter ▁w rit es ▁w ith ▁ e qual ▁ l u c i di t y +▁painting ▁h e ▁ t e ll s ▁us ▁ is ▁o f ▁a ▁di f f er e n t ▁ qual i t y ▁ to ▁ma t h em at ic s ▁ and ▁f i nish ▁in ▁ar t ▁ is ▁a d d ing ▁m or e ▁f a c t +▁a s ▁for ▁ e t ch ing s ▁the y ▁ar e ▁o f ▁ t w o ▁ k i n d s ▁b rit is h ▁ and ▁for eig n +▁h e ▁ l a ment s ▁m ost ▁b i t ter l y ▁the ▁di v or c e ▁tha t ▁ ha s ▁be e n ▁ma d e ▁be t w e e n ▁d e c or at ive ▁ar t ▁ and ▁w ha t ▁we ▁us u all y ▁c all ▁ p ic tur es ▁ma k es ▁the ▁c u s t om ar y ▁a p pe a l ▁ to ▁the ▁ las t ▁ j u d g ment ▁ and ▁re mind s ▁us ▁tha t ▁in ▁the ▁gre at ▁d a y s ▁o f ▁ar t ▁mi c ha e l ▁a n g e l o ▁w a s ▁the ▁f ur nish ing ▁ u p h o l ster er + +==> data/lang_char/input.decode <== +mister quilter is the apostle of the middle classes and we are glad to welcome his gospel +nor is mister quilter's manner less interesting than his matter +he tells us that at this
festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind +he has grave doubts whether sir frederick leighton's work is really greek after all and can discover in it but little of rocky ithaca +linnell's pictures are a sort of up guards and at em paintings and mason's exquisite idylls are as national as a jingo poem mister birket foster's landscapes smile at one much in the same way that mister carker used to flash his teeth and mister john collier gives his sitter a cheerful slap on the back before he says like a shampooer in a turkish bath next man +it is obviously unnecessary for us to point out how luminous these criticisms are how delicate in expression +on the general principles of art mister quilter writes with equal lucidity +painting he tells us is of a different quality to mathematics and finish in art is adding more fact +as for etchings they are of two kinds british and foreign +he laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes the customary appeal to the last judgment and reminds us that in the great days of art michael angelo was the furnishing upholsterer + +==> data/lang_char/input.txt <== +mister quilter is the apostle of the middle classes and we are glad to welcome his gospel +nor is mister quilter's manner less interesting than his matter +he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind +he has grave doubts whether sir frederick leighton's work is really greek after all and can discover in it but little of rocky ithaca +linnell's pictures are a sort of up guards and at em paintings and mason's exquisite idylls are as national as a jingo poem mister birket foster's landscapes smile at one much in the same way that mister carker used to flash his teeth and mister john collier gives his sitter a cheerful slap on the back before he says like a shampooer in a turkish bath next man +it is obviously unnecessary for us to point out how luminous these criticisms are how delicate in expression +on the general principles of art mister quilter writes with equal lucidity +painting he tells us is of a different quality to mathematics and finish in art is adding more fact +as for etchings they are of two kinds british and foreign +he laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes the customary appeal to the last judgment and reminds us that in the great days of art michael angelo was the furnishing upholsterer + +==> data/lang_char/train_unigram100_units.txt <== +<blank> 0 +<unk> 1 +' 2 +a 3 +all 4 +and 5 +ar 6 +ason 7 +at 8 +b 9 + +==> data/lang_char/train_unigram100.vocab <== +<unk> 0 +<s> 0 +</s> 0 +▁ -2.01742 +e -2.7203 +s -2.82989 +t -2.99689 +l -3.53267 +n -3.84935 +o -3.88229 +``` diff --git a/examples/ted_en_zh/t0/conf/transformer.yaml b/examples/ted_en_zh/t0/conf/transformer.yaml index 755e04461..1aad86d22 100644 --- a/examples/ted_en_zh/t0/conf/transformer.yaml +++ b/examples/ted_en_zh/t0/conf/transformer.yaml @@ -68,6 +68,8 @@ model: model_conf: asr_weight: 0.0 ctc_weight: 0.0 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/ted_en_zh/t0/conf/transformer_joint_noam.yaml b/examples/ted_en_zh/t0/conf/transformer_joint_noam.yaml index bc1f8890d..0144c40d4
100644 --- a/examples/ted_en_zh/t0/conf/transformer_joint_noam.yaml +++ b/examples/ted_en_zh/t0/conf/transformer_joint_noam.yaml @@ -68,6 +68,8 @@ model: model_conf: asr_weight: 0.5 ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/ted_en_zh/t0/local/train.sh b/examples/ted_en_zh/t0/local/train.sh index ec17054ab..928356f96 100755 --- a/examples/ted_en_zh/t0/local/train.sh +++ b/examples/ted_en_zh/t0/local/train.sh @@ -19,8 +19,8 @@ echo "using ${device}..." mkdir -p exp -seed=1024 -if [ ${seed} ]; then +seed=10086 +if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi @@ -31,7 +31,7 @@ python3 -u ${BIN_DIR}/train.py \ --output exp/${ckpt_name} \ --seed ${seed} -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/ted_en_zh/t0/run.sh b/examples/ted_en_zh/t0/run.sh index 26fadb608..7508f0e8a 100755 --- a/examples/ted_en_zh/t0/run.sh +++ b/examples/ted_en_zh/t0/run.sh @@ -26,7 +26,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ../../utils/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/text_normalization/README.md b/examples/text_normalization/README.md deleted file mode 100644 index dde0a5576..000000000 --- a/examples/text_normalization/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Regular expression based text normalization for Chinese - -For simplicity and ease of implementation, text normalization is basically done by rules and dictionaries. Here's an example. diff --git a/examples/timit/s1/conf/transformer.yaml b/examples/timit/s1/conf/transformer.yaml index eb191d0b2..c3b519968 100644 --- a/examples/timit/s1/conf/transformer.yaml +++ b/examples/timit/s1/conf/transformer.yaml @@ -66,6 +66,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/timit/s1/local/train.sh b/examples/timit/s1/local/train.sh index ec17054ab..3e2e4522d 100755 --- a/examples/timit/s1/local/train.sh +++ b/examples/timit/s1/local/train.sh @@ -19,8 +19,8 @@ echo "using ${device}..." 
mkdir -p exp -seed=1024 -if [ ${seed} ]; then +seed=10086 +if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi @@ -31,7 +31,7 @@ python3 -u ${BIN_DIR}/train.py \ --output exp/${ckpt_name} \ --seed ${seed} -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/timit/s1/run.sh b/examples/timit/s1/run.sh index 67ce78377..75a2e0c52 100755 --- a/examples/timit/s1/run.sh +++ b/examples/timit/s1/run.sh @@ -26,7 +26,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml index ab9a00d92..64598b4be 100644 --- a/examples/tiny/s0/conf/deepspeech2.yaml +++ b/examples/tiny/s0/conf/deepspeech2.yaml @@ -4,7 +4,7 @@ data: dev_manifest: data/manifest.tiny test_manifest: data/manifest.tiny min_input_len: 0.0 - max_input_len: 27.0 + max_input_len: 30.0 min_output_len: 0.0 max_output_len: 400.0 min_output_input_ratio: 0.05 @@ -41,9 +41,12 @@ model: rnn_layer_size: 2048 use_gru: False share_rnn_weights: True + blank_id: 0 + ctc_grad_norm_type: instance training: n_epoch: 10 + accum_grad: 1 lr: 1e-5 lr_decay: 1.0 weight_decay: 1e-06 diff --git a/examples/tiny/s0/conf/deepspeech2_online.yaml b/examples/tiny/s0/conf/deepspeech2_online.yaml index 333c2b9a9..0098a226c 100644 --- a/examples/tiny/s0/conf/deepspeech2_online.yaml +++ b/examples/tiny/s0/conf/deepspeech2_online.yaml @@ -4,7 +4,7 @@ data: dev_manifest: data/manifest.tiny test_manifest: data/manifest.tiny min_input_len: 0.0 - max_input_len: 27.0 + max_input_len: 30.0 min_output_len: 0.0 max_output_len: 400.0 min_output_input_ratio: 0.05 @@ -43,9 +43,12 @@ model: num_fc_layers: 2 fc_layers_size_list: 512, 256 use_gru: True + blank_id: 0 + ctc_grad_norm_type: instance training: n_epoch: 10 + accum_grad: 1 lr: 1e-5 lr_decay: 1.0 weight_decay: 1e-06 diff --git a/examples/tiny/s0/local/train.sh b/examples/tiny/s0/local/train.sh index d42e51fac..bf4766ee3 100755 --- a/examples/tiny/s0/local/train.sh +++ b/examples/tiny/s0/local/train.sh @@ -19,8 +19,8 @@ fi mkdir -p exp -seed=1024 -if [ ${seed} ]; then +seed=10086 +if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi @@ -32,7 +32,7 @@ python3 -u ${BIN_DIR}/train.py \ --model_type ${model_type} \ --seed ${seed} -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/tiny/s0/run.sh b/examples/tiny/s0/run.sh index 408b28fd0..f39fb3fa0 100755 --- a/examples/tiny/s0/run.sh +++ b/examples/tiny/s0/run.sh @@ -27,7 +27,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/tiny/s1/conf/chunk_confermer.yaml b/examples/tiny/s1/conf/chunk_confermer.yaml index 1b701aa26..be2e82f9e 100644 --- a/examples/tiny/s1/conf/chunk_confermer.yaml +++ b/examples/tiny/s1/conf/chunk_confermer.yaml @@ -4,7 +4,7 @@ data: dev_manifest: data/manifest.tiny test_manifest: data/manifest.tiny min_input_len: 0.5 # second - max_input_len: 20.0 # second + max_input_len: 30.0 # second min_output_len: 0.0 # tokens max_output_len: 400.0 # tokens min_output_input_ratio: 0.05 @@ -76,6 +76,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + 
ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/tiny/s1/conf/chunk_transformer.yaml b/examples/tiny/s1/conf/chunk_transformer.yaml index 1adb91c46..93439a857 100644 --- a/examples/tiny/s1/conf/chunk_transformer.yaml +++ b/examples/tiny/s1/conf/chunk_transformer.yaml @@ -69,6 +69,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml index b40e77e37..9bb67c44e 100644 --- a/examples/tiny/s1/conf/conformer.yaml +++ b/examples/tiny/s1/conf/conformer.yaml @@ -72,6 +72,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index fd5adbdee..fcbe1da4a 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -66,6 +66,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false @@ -84,7 +86,7 @@ training: lr_decay: 1.0 log_interval: 1 checkpoint: - kbest_n: 10 + kbest_n: 2 latest_n: 1 diff --git a/examples/tiny/s1/local/train.sh b/examples/tiny/s1/local/train.sh index 2fb3a95ab..48968f63c 100755 --- a/examples/tiny/s1/local/train.sh +++ b/examples/tiny/s1/local/train.sh @@ -18,8 +18,8 @@ fi mkdir -p exp -seed=1024 -if [ ${seed} ]; then +seed=10086 +if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi @@ -30,7 +30,7 @@ python3 -u ${BIN_DIR}/train.py \ --output exp/${ckpt_name} \ --seed ${seed} -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/tiny/s1/run.sh b/examples/tiny/s1/run.sh index 41f845b05..d288e31a4 100755 --- a/examples/tiny/s1/run.sh +++ b/examples/tiny/s1/run.sh @@ -25,7 +25,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/tn/.gitignore b/examples/tn/.gitignore new file mode 100644 index 000000000..0f2503386 --- /dev/null +++ b/examples/tn/.gitignore @@ -0,0 +1 @@ +exp diff --git a/examples/tn/README.md b/examples/tn/README.md new file mode 100644 index 000000000..ff7be2934 --- /dev/null +++ b/examples/tn/README.md @@ -0,0 +1,36 @@ +# Regular expression based text normalization for Chinese + +For simplicity and ease of implementation, text normalization is basically done by rules and dictionaries. Here's an example. + +## Run + +``` +. path.sh +bash run.sh +``` + +## Results + +``` +exp/ +`-- normalized.txt + +0 directories, 1 file +``` + +``` +aff31f8aa08e2a7360228c9ce5886b98 exp/normalized.txt +``` + +``` +今天的最低气温达到零下十度. +只要有四分之三十三的人同意,就可以通过决议。 +一九四五年五月二日,苏联士兵在德国国会大厦上升起了胜利旗,象征着攻占柏林并战胜了纳粹德国。 +四月十六日,清晨的战斗以炮击揭幕,数以千计的大炮和喀秋莎火箭炮开始炮轰德军阵地,炮击持续了数天之久。 +如果剩下的百分之三十点六是过去,那么还有百分之六十九点四. +事情发生在二零二零年三月三十一日的上午八点. 
+警方正在找一支点二二口径的手枪。 +欢迎致电中国联通,北京二零二二年冬奥会官方合作伙伴为您服务 +充值缴费请按一,查询话费及余量请按二,跳过本次提醒请按井号键。 +快速解除流量封顶请按星号键,腾讯王卡产品介绍、使用说明、特权及活动请按九,查询话费、套餐余量、积分及活动返款请按一,手机上网流量开通及取消请按二,查询本机号码及本号所使用套餐请按四,密码修改及重置请按五,紧急开机请按六,挂失请按七,查询充值记录请按八,其它自助服务及工服务请按零 +``` diff --git a/examples/text_normalization/data/sentences.txt b/examples/tn/data/sentences.txt similarity index 100% rename from examples/text_normalization/data/sentences.txt rename to examples/tn/data/sentences.txt diff --git a/examples/text_normalization/local/test_normalization.py b/examples/tn/local/test_normalization.py similarity index 100% rename from examples/text_normalization/local/test_normalization.py rename to examples/tn/local/test_normalization.py diff --git a/examples/text_normalization/path.sh b/examples/tn/path.sh similarity index 100% rename from examples/text_normalization/path.sh rename to examples/tn/path.sh diff --git a/examples/text_normalization/run.sh b/examples/tn/run.sh similarity index 100% rename from examples/text_normalization/run.sh rename to examples/tn/run.sh diff --git a/requirements.txt b/requirements.txt index 7c3da37e1..ebf879b51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,9 @@ coverage gpustat jsonlines +jsonlines kaldiio +loguru Pillow pre-commit pybind11 @@ -14,5 +16,7 @@ SoundFile==0.9.0.post1 sox tensorboardX textgrid +tqdm typeguard +visualdl==2.2.0 yacs
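A note on the `local/train.sh` edits repeated throughout this patch: each script gets a new default seed (10086) and an explicit `!= 0` test. The old `if [ ${seed} ]` guard is true for any non-empty value, including 0, so there was no way to opt out of deterministic cuDNN; with the new comparison, `seed=0` becomes a real escape hatch. A minimal standalone sketch of the shared pattern (assuming, as in the recipes, that `path.sh` exports `BIN_DIR`):

```
#!/bin/bash
# Sketch of the seed handling shared by the train.sh scripts above.
# Assumption: BIN_DIR is exported by the recipe's path.sh.
seed=10086

# A nonzero seed requests a reproducible run, so pin cuDNN to
# deterministic kernels (usually slower than the autotuned defaults).
if [ ${seed} != 0 ]; then
    export FLAGS_cudnn_deterministic=True
fi

python3 -u ${BIN_DIR}/train.py --seed ${seed} "$@"

# Restore the default, faster non-deterministic kernel selection.
if [ ${seed} != 0 ]; then
    unset FLAGS_cudnn_deterministic
fi
```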
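The `ngram_lm` results above encode their hyper-parameters in the file names: `text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa` reads as order 5, prune thresholds 0 1 2 4 4, and a binary built with `-a 22 -q 8 -b 8`. The recipe's `run.sh` is not shown in this patch, so the following is only a hypothetical kenlm invocation consistent with those names, using the standard `lmplz` and `build_binary` tools:

```
# Hypothetical sketch matching the char-LM file names above; the actual
# run.sh may differ. Train a 5-gram char LM with per-order pruning:
lmplz -o 5 --prune 0 1 2 4 4 \
    --text exp/text.char.tn \
    --arpa exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa

# Compress to a trie binary: 8-bit probability (-q) and backoff (-b)
# quantization, 22-bit pointer compression (-a).
build_binary -a 22 -q 8 -b 8 trie \
    exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa \
    exp/text_zh_char_o5_p0_1_2_4_4_a22_q8_b8.arpa.klm.bin
```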
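Likewise for the `spm` example: the file names point at a 100-piece unigram SentencePiece model, and the identical checksums of `input.txt` and `input.decode` in the results confirm that the piece segmentation round-trips losslessly. A hypothetical reconstruction of that pipeline with the standard SentencePiece CLI (again, the recipe's `run.sh` is not shown here):

```
# Hypothetical sketch inferred from the output names above; the actual
# run.sh may differ. Train a 100-piece unigram model:
spm_train --input=data/lang_char/input.txt \
    --model_prefix=data/lang_char/train_unigram100 \
    --model_type=unigram --vocab_size=100

# Segment the corpus into pieces, then decode to verify the round trip
# (input.decode should be byte-identical to input.txt).
spm_encode --model=data/lang_char/train_unigram100.model \
    --output_format=piece \
    < data/lang_char/input.txt > data/lang_char/input.bpe
spm_decode --model=data/lang_char/train_unigram100.model \
    --input_format=piece \
    < data/lang_char/input.bpe > data/lang_char/input.decode
```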