From 4a3768ad185dd8998f09c1f58a5293712a398bee Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Tue, 2 Mar 2021 10:02:07 +0000
Subject: [PATCH] refactor model; add JIT model export

---
 {notebook => .notebook}/dataloader.ipynb  |   0
 {notebook => .notebook}/train_test.ipynb  |   0
 deepspeech/exps/deepspeech2/bin/export.py |  58 +++++
 deepspeech/exps/deepspeech2/bin/tune.py   |  29 +--
 deepspeech/exps/deepspeech2/model.py      | 175 +++++++------
 deepspeech/models/deepspeech2.py          | 296 +++++++++++++++-------
 deepspeech/modules/activation.py          |   6 +-
 deepspeech/{training => modules}/loss.py  |   0
 deepspeech/training/cli.py                |   5 +-
 deepspeech/utils/utility.py               |   1 +
 deploy/_init_paths.py                     |  29 ---
 examples/tiny/local/export.sh             |  20 ++
 examples/tiny/local/infer_golden.sh       |  46 ----
 examples/tiny/local/test_golden.sh        |  47 ----
 14 files changed, 399 insertions(+), 313 deletions(-)
 rename {notebook => .notebook}/dataloader.ipynb (100%)
 rename {notebook => .notebook}/train_test.ipynb (100%)
 create mode 100644 deepspeech/exps/deepspeech2/bin/export.py
 rename deepspeech/{training => modules}/loss.py (100%)
 delete mode 100644 deploy/_init_paths.py
 create mode 100644 examples/tiny/local/export.sh
 delete mode 100644 examples/tiny/local/infer_golden.sh
 delete mode 100644 examples/tiny/local/test_golden.sh

diff --git a/notebook/dataloader.ipynb b/.notebook/dataloader.ipynb
similarity index 100%
rename from notebook/dataloader.ipynb
rename to .notebook/dataloader.ipynb
diff --git a/notebook/train_test.ipynb b/.notebook/train_test.ipynb
similarity index 100%
rename from notebook/train_test.ipynb
rename to .notebook/train_test.ipynb
+"""Export for DeepSpeech2 model.""" + +import io +import logging +import argparse +import functools + +from paddle import distributed as dist + +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.utility import print_arguments +from deepspeech.utils.error_rate import char_errors, word_errors + +from deepspeech.exps.deepspeech2.config import get_cfg_defaults +from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester as Tester + + +def main_sp(config, args): + exp = Tester(config, args) + exp.setup() + exp.run_export() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args) + + # https://yaml.org/type/float.html + config = get_cfg_defaults() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py index 33ecfe926..eb6bddd9a 100644 --- a/deepspeech/exps/deepspeech2/bin/tune.py +++ b/deepspeech/exps/deepspeech2/bin/tune.py @@ -27,13 +27,10 @@ from deepspeech.training.cli import default_argument_parser from deepspeech.utils.error_rate import char_errors, word_errors from deepspeech.utils.utility import add_arguments, print_arguments -from deepspeech.models.network import DeepSpeech2 -from deepspeech.models.network import DeepSpeech2Loss +from deepspeech.models.deepspeech2 import DeepSpeech2Model +from deepspeech.io.collator import SpeechCollator +from deepspeech.io.dataset import ManifestDataset -from deepspeech.exps.deepspeech2.dataset import SpeechCollator -from deepspeech.exps.deepspeech2.dataset import DeepSpeech2Dataset -from deepspeech.exps.deepspeech2.dataset import DeepSpeech2DistributedBatchSampler -from deepspeech.exps.deepspeech2.dataset import DeepSpeech2BatchSampler from deepspeech.exps.deepspeech2.config import get_cfg_defaults @@ -44,7 +41,7 @@ def tune(config, args): if not args.num_betas >= 0: raise ValueError("num_betas must be non-negative!") - dev_dataset = DeepSpeech2Dataset( + dev_dataset = ManifestDataset( config.data.dev_manifest, config.data.vocab_filepath, config.data.mean_std_filepath, @@ -69,7 +66,7 @@ def tune(config, args): drop_last=False, collate_fn=SpeechCollator(is_training=False)) - model = DeepSpeech2( + model = DeepSpeech2Model( feat_size=valid_loader.dataset.feature_size, dict_size=valid_loader.dataset.vocab_size, num_conv_layers=config.model.num_conv_layers, @@ -94,9 +91,9 @@ def tune(config, args): num_ins, len_refs, cur_batch = 0, 0, 0 # initialize external scorer - model.init_decode(args.alpha_from, args.beta_from, - config.decoding.lang_model_path, vocab_list, - config.decoding.decoding_method) + model.decoder.init_decode(args.alpha_from, args.beta_from, + config.decoding.lang_model_path, vocab_list, + config.decoding.decoding_method) ## incremental tuning parameters over multiple batches print("start tuning ...") for infer_data in valid_loader(): @@ -113,15 +110,17 @@ def tune(config, args): return trans audio, text, audio_len, text_len = infer_data - _, probs, logits_lens = model.predict(audio, audio_len) target_transcripts = ordid2token(text, text_len) num_ins += audio.shape[0] + eouts, eouts_len = model.encoder(audio, audio_len) + probs = model.decoder.probs(eouts) + # grid search for index, (alpha, beta) in 
         for index, (alpha, beta) in enumerate(params_grid):
             print(f"tuning: alpha={alpha} beta={beta}")
-            result_transcripts = model.decode_probs(
-                probs.numpy(), logits_lens, vocab_list,
+            result_transcripts = model.decoder.decode_probs(
+                probs.numpy(), eouts_len, vocab_list,
                 config.decoding.decoding_method,
                 config.decoding.lang_model_path, alpha, beta,
                 config.decoding.beam_size, config.decoding.cutoff_prob,
@@ -165,7 +164,7 @@ def tune(config, args):
           (cur_batch, "%.3f" % params_grid[min_index][0],
            "%.3f" % params_grid[min_index][1]))
 
-    ds2_model.logger.info("finish tuning")
+    print("finish tuning")
 
 
 def main_sp(config, args):
diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index 32d4387b3..eb34c43e5 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -43,8 +43,9 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.io.sampler import SortagradBatchSampler
 from deepspeech.io.dataset import ManifestDataset
 
-from deepspeech.training.loss import CTCLoss
+from deepspeech.modules.loss import CTCLoss
 from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.models.deepspeech2 import DeepSpeech2InferModel
 
 logger = logging.getLogger(__name__)
 
@@ -53,19 +54,10 @@ class DeepSpeech2Trainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
 
-    def compute_losses(self, inputs, outputs):
-        _, texts, _, texts_len = inputs
-        logits, _, logits_len = outputs
-        loss = self.criterion(logits, texts, logits_len, texts_len)
-        return loss
-
     def train_batch(self, batch_data):
         start = time.time()
         self.model.train()
-
-        outputs = self.model(*batch_data)
-        loss = self.compute_losses(batch_data, outputs)
-
+        loss = self.model(*batch_data)
         loss.backward()
         print_grads(self.model, logger=None)
         self.optimizer.step()
@@ -99,10 +91,7 @@ class DeepSpeech2Trainer(Trainer):
         self.model.eval()
         valid_losses = defaultdict(list)
         for i, batch in enumerate(self.valid_loader):
-            audio, text, audio_len, text_len = batch
-            outputs = self.model(*batch)
-            loss = self.compute_losses(batch, outputs)
-            #metrics = self.compute_metrics(batch, outputs)
+            loss = self.model(*batch)
 
             valid_losses['val_loss'].append(float(loss))
             valid_losses['val_loss_div_batchsize'].append(
@@ -152,13 +141,10 @@ class DeepSpeech2Trainer(Trainer):
                 config.training.weight_decay),
             grad_clip=grad_clip)
 
-        criterion = CTCLoss(self.train_loader.dataset.vocab_size)
-
         self.model = model
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
-        self.criterion = criterion
-        self.logger.info("Setup model/optimizer/lr_scheduler/criterion!")
+        self.logger.info("Setup model/optimizer/lr_scheduler!")
 
     def setup_dataloader(self):
         config = self.config
@@ -248,12 +234,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             trans.append(''.join([chr(i) for i in ids]))
         return trans
 
-    def compute_metrics(self, inputs, outputs):
+    def compute_metrics(self, audio, texts, audio_len, texts_len):
         cfg = self.config.decoding
-
-        _, texts, _, texts_len = inputs
-        logits, probs, logits_len = outputs
-
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         errors_func = char_errors if cfg.error_rate_type == 'cer' else word_errors
         error_rate_func = cer if cfg.error_rate_type == 'cer' else wer
 
         vocab_list = self.test_loader.dataset.vocab_list
 
         target_transcripts = self.ordid2token(texts, texts_len)
-        result_transcripts = self.model.decode_probs(
-            probs.numpy(),
-            logits_len,
+        result_transcripts = self.model.decode(
+            audio,
+            audio_len,
             vocab_list,
             decoding_method=cfg.decoding_method,
             lang_model_path=cfg.lang_model_path,
@@ -298,24 +280,12 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.logger.info(
             f"Test Total Examples: {len(self.test_loader.dataset)}")
         self.model.eval()
-        cfg = self.config
-        # decoders only accept string encoded in utf-8
-        vocab_list = self.test_loader.dataset.vocab_list
-        self.model.init_decode(
-            beam_alpha=cfg.decoding.alpha,
-            beam_beta=cfg.decoding.beta,
-            lang_model_path=cfg.decoding.lang_model_path,
-            vocab_list=vocab_list,
-            decoding_method=cfg.decoding.decoding_method)
-
         error_rate_type = None
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         for i, batch in enumerate(self.test_loader):
-            audio, text, audio_len, text_len = batch
-            outputs = self.model.predict(audio, audio_len)
-            metrics = self.compute_metrics(batch, outputs)
+            metrics = self.compute_metrics(*batch)
 
             errors_sum += metrics['errors_sum']
             len_refs += metrics['len_refs']
@@ -332,48 +302,44 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             error_rate_type, num_ins, num_ins, errors_sum / len_refs)
         self.logger.info(msg)
 
-    def setup_output_dir(self):
-        """Create a directory used for output.
-        """
-        # output dir
-        if self.args.output:
-            output_dir = Path(self.args.output).expanduser()
-            output_dir.mkdir(parents=True, exist_ok=True)
-        else:
-            output_dir = Path(
-                self.args.checkpoint_path).expanduser().parent.parent
-            output_dir.mkdir(parents=True, exist_ok=True)
+    def export(self):
+        self.infer_model.eval()
+        feat_dim = self.test_loader.dataset.feature_size
+        # static_model = paddle.jit.to_static(
+        #     self.infer_model,
+        #     input_spec=[
+        #         paddle.static.InputSpec(shape=[None, feat_dim, None], dtype='float32'),  # audio, [B,D,T]
+        #         paddle.static.InputSpec(shape=[None], dtype='int64'),  # audio_length, [B]
+        #     ])
+        paddle.jit.save(
+            self.infer_model,
+            self.args.export_path,
+            input_spec=[
+                paddle.static.InputSpec(
+                    shape=[None, feat_dim, None],
+                    dtype='float32'),  # audio, [B,D,T]
+                paddle.static.InputSpec(shape=[None],
+                                        dtype='int64'),  # audio_length, [B]
+            ])
 
-        self.output_dir = output_dir
-
-    def setup_logger(self):
-        """Initialize a text logger to log the experiment.
-
-        Each process has its own text logger. The logging message is write to
-        the standard output and a text file named ``worker_n.log`` in the
-        output directory, where ``n`` means the rank of the process.
-        """
-        format = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
-        formatter = logging.Formatter(fmt=format, datefmt='%Y/%m/%d %H:%M:%S')
-
-        logger.setLevel("INFO")
+    def run_test(self):
+        self.resume_or_load()
+        try:
+            self.test()
+        except KeyboardInterrupt:
+            exit(-1)
 
-        # global logger
-        stdout = True
-        save_path = ""
-        logging.basicConfig(
-            level=logging.DEBUG if stdout else logging.INFO,
-            format=format,
-            datefmt='%Y/%m/%d %H:%M:%S',
-            filename=save_path if not stdout else None)
-        self.logger = logger
+    def run_export(self):
+        self.resume_or_load()
+        try:
+            self.export()
+        except KeyboardInterrupt:
+            exit(-1)
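
About the `input_spec` passed to `paddle.jit.save` above: a `None` dimension marks an axis whose size is only known at run time, so the traced graph accepts any batch size and any number of time steps. A minimal, self-contained sketch of the same pattern with a toy layer (everything here is illustrative, not part of this patch):

    import paddle
    from paddle.static import InputSpec

    class Toy(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.fc = paddle.nn.Linear(8, 4)

        def forward(self, x):
            return self.fc(x)

    # None marks the batch axis as dynamic; the saved graph accepts any batch size.
    paddle.jit.save(
        Toy(), "/tmp/toy_model",
        input_spec=[InputSpec(shape=[None, 8], dtype='float32')])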
""" paddle.set_device(self.args.device) - if self.parallel: - self.init_parallel() self.setup_output_dir() self.setup_checkpointer() @@ -385,13 +351,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): self.iteration = 0 self.epoch = 0 - def run_test(self): - self.resume_or_load() - try: - self.test() - except KeyboardInterrupt: - exit(-1) - def setup_model(self): config = self.config model = DeepSpeech2Model( @@ -403,14 +362,18 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): use_gru=config.model.use_gru, share_rnn_weights=config.model.share_rnn_weights) - if self.parallel: - model = paddle.DataParallel(model) - - criterion = CTCLoss(self.test_loader.dataset.vocab_size) + infer_model = DeepSpeech2InferModel( + feat_size=self.test_loader.dataset.feature_size, + dict_size=self.test_loader.dataset.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, + share_rnn_weights=config.model.share_rnn_weights) self.model = model - self.criterion = criterion - self.logger.info("Setup model/criterion!") + self.infer_model = infer_model + self.logger.info("Setup model!") def setup_dataloader(self): config = self.config @@ -441,3 +404,39 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): drop_last=False, collate_fn=SpeechCollator(is_training=False)) self.logger.info("Setup test Dataloader!") + + def setup_output_dir(self): + """Create a directory used for output. + """ + # output dir + if self.args.output: + output_dir = Path(self.args.output).expanduser() + output_dir.mkdir(parents=True, exist_ok=True) + else: + output_dir = Path( + self.args.checkpoint_path).expanduser().parent.parent + output_dir.mkdir(parents=True, exist_ok=True) + + self.output_dir = output_dir + + def setup_logger(self): + """Initialize a text logger to log the experiment. + + Each process has its own text logger. The logging message is write to + the standard output and a text file named ``worker_n.log`` in the + output directory, where ``n`` means the rank of the process. + """ + format = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' + formatter = logging.Formatter(fmt=format, datefmt='%Y/%m/%d %H:%M:%S') + + logger.setLevel("INFO") + + # global logger + stdout = True + save_path = "" + logging.basicConfig( + level=logging.DEBUG if stdout else logging.INFO, + format=format, + datefmt='%Y/%m/%d %H:%M:%S', + filename=save_path if not stdout else None) + self.logger = logger diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py index 31a7f5588..374d15ef4 100644 --- a/deepspeech/models/deepspeech2.py +++ b/deepspeech/models/deepspeech2.py @@ -33,55 +33,14 @@ from deepspeech.decoders.swig_wrapper import Scorer from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch +from deepspeech.modules.loss import CTCLoss + logger = logging.getLogger(__name__) __all__ = ['DeepSpeech2Model'] -class DeepSpeech2Model(nn.Layer): - """The DeepSpeech2 network structure. - - :param audio_data: Audio spectrogram data layer. - :type audio_data: Variable - :param text_data: Transcription text data layer. - :type text_data: Variable - :param audio_len: Valid sequence length data layer. - :type audio_len: Variable - :param masks: Masks data layer to reset padding. - :type masks: Variable - :param dict_size: Dictionary size for tokenized transcription. 
diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py
index 31a7f5588..374d15ef4 100644
--- a/deepspeech/models/deepspeech2.py
+++ b/deepspeech/models/deepspeech2.py
@@ -33,55 +33,14 @@ from deepspeech.decoders.swig_wrapper import Scorer
 from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder
 from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch
 
+from deepspeech.modules.loss import CTCLoss
+
 logger = logging.getLogger(__name__)
 
 __all__ = ['DeepSpeech2Model']
 
 
-class DeepSpeech2Model(nn.Layer):
-    """The DeepSpeech2 network structure.
-
-    :param audio_data: Audio spectrogram data layer.
-    :type audio_data: Variable
-    :param text_data: Transcription text data layer.
-    :type text_data: Variable
-    :param audio_len: Valid sequence length data layer.
-    :type audio_len: Variable
-    :param masks: Masks data layer to reset padding.
-    :type masks: Variable
-    :param dict_size: Dictionary size for tokenized transcription.
-    :type dict_size: int
-    :param num_conv_layers: Number of stacking convolution layers.
-    :type num_conv_layers: int
-    :param num_rnn_layers: Number of stacking RNN layers.
-    :type num_rnn_layers: int
-    :param rnn_size: RNN layer size (dimension of RNN cells).
-    :type rnn_size: int
-    :param use_gru: Use gru if set True. Use simple rnn if set False.
-    :type use_gru: bool
-    :param share_rnn_weights: Whether to share input-hidden weights between
-                              forward and backward direction RNNs.
-                              It is only available when use_gru=False.
-    :type share_weights: bool
-    :return: A tuple of an output unnormalized log probability layer (
-             before softmax) and a ctc cost layer.
-    :rtype: tuple of LayerOutput
-    """
-
-    @classmethod
-    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
-        default = CfgNode(
-            dict(
-                num_conv_layers=2,  #Number of stacking convolution layers.
-                num_rnn_layers=3,  #Number of stacking RNN layers.
-                rnn_layer_size=1024,  #RNN layer size (number of RNN cells).
-                use_gru=True,  #Use gru if set True. Use simple rnn if set False.
-                share_rnn_weights=True  #Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported.
-            ))
-        if config is not None:
-            config.merge_from_other_cfg(default)
-        return default
-
+class CRNNEncoder(nn.Layer):
     def __init__(self,
                  feat_size,
                  dict_size,
@@ -91,6 +50,7 @@ class DeepSpeech2Model(nn.Layer):
                  use_gru=False,
                  share_rnn_weights=True):
         super().__init__()
+        self.rnn_size = rnn_size
         self.feat_size = feat_size  # 161 for linear
         self.dict_size = dict_size
 
@@ -103,49 +63,89 @@ class DeepSpeech2Model(nn.Layer):
             num_stacks=num_rnn_layers,
             use_gru=use_gru,
             share_rnn_weights=share_rnn_weights)
-        self.fc = nn.Linear(rnn_size * 2, dict_size + 1)
 
-        self.logger = logging.getLogger(__name__)
-        self._ext_scorer = None
+    @property
+    def output_size(self):
+        return self.rnn_size * 2
 
-    def infer(self, audio, audio_len):
+    def forward(self, audio, audio_len):
+        """Compute encoder outputs.
+
+        Args:
+            audio (Tensor): [B, D, T]
+            audio_len (Tensor): [B]
+        Returns:
+            x (Tensor): encoder outputs, [B, T, D]
+            x_lens (Tensor): encoder output lengths, [B]
+        """
         # [B, D, T] -> [B, C=1, D, T]
-        audio = audio.unsqueeze(1)
+        x = audio.unsqueeze(1)
+        x_lens = audio_len
 
         # convolution group
-        x, audio_len = self.conv(audio, audio_len)
-        #print('conv out', x.shape)
+        x, x_lens = self.conv(x, x_lens)
 
         # convert data from convolution feature map to sequence of vectors
-        B, C, D, T = paddle.shape(x)
+        #B, C, D, T = paddle.shape(x)  # does not work under jit
         x = x.transpose([0, 3, 1, 2])  #[B, T, C, D]
-        x = x.reshape([B, T, C * D])  #[B, T, C*D]
-        #print('rnn input', x.shape)
+        #x = x.reshape([B, T, C * D])  #[B, T, C*D]  # does not work under jit
+        x = x.reshape([0, 0, -1])  #[B, T, C*D]
 
         # remove padding part
-        x, audio_len = self.rnn(x, audio_len)  #[B, T, D]
-        #print('rnn output', x.shape)
+        x, x_lens = self.rnn(x, x_lens)  #[B, T, D]
+        return x, x_lens
+
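
A note on `x.reshape([0, 0, -1])` in the encoder above: in Paddle's reshape, a `0` in the target shape copies the corresponding input dimension and `-1` is inferred, which avoids unpacking `paddle.shape(x)` in Python, the step that breaks under jit tracing. A quick illustration with arbitrary example shapes:

    import paddle

    x = paddle.randn([4, 10, 32, 41])  # e.g. [B, T, C, D] after the transpose
    y = x.reshape([0, 0, -1])          # 0 keeps B and T; -1 infers C*D
    print(y.shape)                     # [4, 10, 1312]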
 
-        logits = self.fc(x)  #[B, T, V + 1]
+class CTCDecoder(nn.Layer):
+    def __init__(self, enc_n_units, vocab_size):
+        super().__init__()
+        self.blank_id = vocab_size
+        self.output = nn.Linear(enc_n_units,
+                                vocab_size + 1)  # blank id is last id
+        self.criterion = CTCLoss(self.blank_id)
 
-        #ctcdecoder need probs, not log_probs
-        probs = F.softmax(logits)
+        self._ext_scorer = None
 
-        return logits, probs, audio_len
+    def forward(self, eout, eout_lens, texts, texts_len):
+        """Compute CTC loss.
 
-    def forward(self, audio, text, audio_len, text_len):
+        Args:
+            eout (Tensor):
+            eout_lens (Tensor):
+            texts (Tensor):
+            texts_len (Tensor):
+        Returns:
+            loss (Tensor): [1]
         """
-        audio: shape [B, D, T]
-        text: shape [B, T]
-        audio_len: shape [B]
-        text_len: shape [B]
+        logits = self.output(eout)
+        loss = self.criterion(logits, texts, eout_lens, texts_len)
+        return loss
+
+    def probs(self, eouts, temperature=1.):
+        """Get CTC probabilities.
+        Args:
+            eouts (FloatTensor): `[B, T, enc_units]`
+        Returns:
+            probs (FloatTensor): `[B, T, vocab]`
         """
-        return self.infer(audio, audio_len)
-
-    @paddle.no_grad()
-    def predict(self, audio, audio_len):
-        """ Model infer """
-        return self.infer(audio, audio_len)
+        return F.softmax(self.output(eouts) / temperature, axis=-1)
+
+    def scores(self, eouts, temperature=1.):
+        """Get log-scale CTC probabilities.
+        Args:
+            eouts (FloatTensor): `[B, T, enc_units]`
+        Returns:
+            log_probs (FloatTensor): `[B, T, vocab]`
+        """
+        return F.log_softmax(self.output(eouts) / temperature, axis=-1)
 
     def _decode_batch_greedy(self, probs_split, vocab_list):
         """Decode by best path for a batch of probs matrix input.
@@ -184,22 +184,22 @@ class DeepSpeech2Model(nn.Layer):
             return
 
         if language_model_path != '':
-            self.logger.info("begin to initialize the external scorer "
-                             "for decoding")
+            logger.info("begin to initialize the external scorer "
+                        "for decoding")
             self._ext_scorer = Scorer(beam_alpha, beam_beta,
                                       language_model_path, vocab_list)
             lm_char_based = self._ext_scorer.is_character_based()
             lm_max_order = self._ext_scorer.get_max_order()
             lm_dict_size = self._ext_scorer.get_dict_size()
-            self.logger.info("language model: "
-                             "is_character_based = %d," % lm_char_based +
-                             " max_order = %d," % lm_max_order +
-                             " dict_size = %d" % lm_dict_size)
-            self.logger.info("end initializing scorer")
+            logger.info("language model: "
+                        "is_character_based = %d," % lm_char_based +
+                        " max_order = %d," % lm_max_order + " dict_size = %d" %
+                        lm_dict_size)
+            logger.info("end initializing scorer")
         else:
             self._ext_scorer = None
-            self.logger.info("no language model provided, "
-                             "decoding by pure beam search without scorer.")
+            logger.info("no language model provided, "
+                        "decoding by pure beam search without scorer.")
 
     def _decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
                                   beam_size, cutoff_prob, cutoff_top_n,
@@ -275,15 +275,108 @@ class DeepSpeech2Model(nn.Layer):
             raise ValueError(f"Not support: {decoding_method}")
         return result_transcripts
 
+
+class DeepSpeech2Model(nn.Layer):
+    """The DeepSpeech2 network structure.
+
+    :param audio_data: Audio spectrogram data layer.
+    :type audio_data: Variable
+    :param text_data: Transcription text data layer.
+    :type text_data: Variable
+    :param audio_len: Valid sequence length data layer.
+    :type audio_len: Variable
+    :param masks: Masks data layer to reset padding.
+    :type masks: Variable
+    :param dict_size: Dictionary size for tokenized transcription.
+    :type dict_size: int
+    :param num_conv_layers: Number of stacking convolution layers.
+    :type num_conv_layers: int
+    :param num_rnn_layers: Number of stacking RNN layers.
+    :type num_rnn_layers: int
+    :param rnn_size: RNN layer size (dimension of RNN cells).
+    :type rnn_size: int
+    :param use_gru: Use gru if set True. Use simple rnn if set False.
+    :type use_gru: bool
+    :param share_rnn_weights: Whether to share input-hidden weights between
+                              forward and backward direction RNNs.
+                              It is only available when use_gru=False.
+    :type share_weights: bool
+    :return: A tuple of an output unnormalized log probability layer (
+             before softmax) and a ctc cost layer.
+    :rtype: tuple of LayerOutput
+    """
+
+    @classmethod
+    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
+        default = CfgNode(
+            dict(
+                num_conv_layers=2,  #Number of stacking convolution layers.
+                num_rnn_layers=3,  #Number of stacking RNN layers.
+                rnn_layer_size=1024,  #RNN layer size (number of RNN cells).
+                use_gru=True,  #Use gru if set True. Use simple rnn if set False.
+                share_rnn_weights=True  #Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported.
+            ))
+        if config is not None:
+            config.merge_from_other_cfg(default)
+        return default
+
+    def __init__(self,
+                 feat_size,
+                 dict_size,
+                 num_conv_layers=2,
+                 num_rnn_layers=3,
+                 rnn_size=1024,
+                 use_gru=False,
+                 share_rnn_weights=True):
+        super().__init__()
+        self.encoder = CRNNEncoder(
+            feat_size=feat_size,
+            dict_size=dict_size,
+            num_conv_layers=num_conv_layers,
+            num_rnn_layers=num_rnn_layers,
+            rnn_size=rnn_size,
+            use_gru=use_gru,
+            share_rnn_weights=share_rnn_weights)
+        assert (self.encoder.output_size == rnn_size * 2)
+        self.decoder = CTCDecoder(
+            enc_n_units=self.encoder.output_size, vocab_size=dict_size)
+
+    def forward(self, audio, text, audio_len, text_len):
+        """Compute model loss.
+
+        Args:
+            audio (Tensor): [B, D, T]
+            text (Tensor): [B, T]
+            audio_len (Tensor): [B]
+            text_len (Tensor): [B]
+
+        Returns:
+            loss (Tensor): [1]
+        """
+        eouts, eouts_len = self.encoder(audio, audio_len)
+        loss = self.decoder(eouts, eouts_len, text, text_len)
+        return loss
+
     @paddle.no_grad()
     def decode(self, audio, audio_len, vocab_list, decoding_method,
                lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
                cutoff_top_n, num_processes):
-        _, probs, logits_lens = self.predict(audio, audio_len)
-        return self.decode_probs(probs.numpy(), logits_lens, vocab_list,
-                                 decoding_method, lang_model_path, beam_alpha,
-                                 beam_beta, beam_size, cutoff_prob,
-                                 cutoff_top_n, num_processes)
+        # init once
+        # decoders only accept string encoded in utf-8
+        self.decoder.init_decode(
+            beam_alpha=beam_alpha,
+            beam_beta=beam_beta,
+            lang_model_path=lang_model_path,
+            vocab_list=vocab_list,
+            decoding_method=decoding_method)
+
+        eouts, eouts_len = self.encoder(audio, audio_len)
+        probs = self.decoder.probs(eouts)
+        return self.decoder.decode_probs(
+            probs.numpy(), eouts_len, vocab_list, decoding_method,
+            lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
+            cutoff_top_n, num_processes)
 
     def from_pretrained(self, checkpoint_path):
         """Build a model from a pretrained model.
@@ -302,3 +395,36 @@ class DeepSpeech2Model(nn.Layer):
         """
         checkpoint.load_parameters(self, checkpoint_path=checkpoint_path)
         return
+
+
+class DeepSpeech2InferModel(DeepSpeech2Model):
+    def __init__(self,
+                 feat_size,
+                 dict_size,
+                 num_conv_layers=2,
+                 num_rnn_layers=3,
+                 rnn_size=1024,
+                 use_gru=False,
+                 share_rnn_weights=True):
+        super().__init__(
+            feat_size=feat_size,
+            dict_size=dict_size,
+            num_conv_layers=num_conv_layers,
+            num_rnn_layers=num_rnn_layers,
+            rnn_size=rnn_size,
+            use_gru=use_gru,
+            share_rnn_weights=share_rnn_weights)
+
+    def forward(self, audio, audio_len):
+        """Export model forward, used by the exported jit model.
+
+        Args:
+            audio (Tensor): [B, D, T]
+            audio_len (Tensor): [B]
+
+        Returns:
+            probs: probs after softmax
+        """
+        eouts, eouts_len = self.encoder(audio, audio_len)
+        probs = self.decoder.probs(eouts)
+        return probs
diff --git a/deepspeech/modules/activation.py b/deepspeech/modules/activation.py
index 72c2a5e2b..14861fcf7 100644
--- a/deepspeech/modules/activation.py
+++ b/deepspeech/modules/activation.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import numpy as np
 
 import paddle
 from paddle import nn
@@ -25,6 +26,7 @@ __all__ = ['brelu']
 
 
 def brelu(x, t_min=0.0, t_max=24.0, name=None):
-    t_min = paddle.to_tensor(t_min)
-    t_max = paddle.to_tensor(t_max)
+    # paddle.to_tensor is dygraph-only and does not work under JIT
+    t_min = paddle.full(shape=[1], fill_value=t_min, dtype='float32')
+    t_max = paddle.full(shape=[1], fill_value=t_max, dtype='float32')
     return x.maximum(t_min).minimum(t_max)
diff --git a/deepspeech/training/loss.py b/deepspeech/modules/loss.py
similarity index 100%
rename from deepspeech/training/loss.py
rename to deepspeech/modules/loss.py
diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py
index 1076fe0c7..0994f71f5 100644
--- a/deepspeech/training/cli.py
+++ b/deepspeech/training/cli.py
@@ -48,12 +48,15 @@ def default_argument_parser():
     # data and output
     parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
     parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.")
-    parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
+    # parser.add_argument("--data", metavar="DATA_DIR", help="path to the dataset.")
     parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.")
 
     # load from saved checkpoint
     parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
 
+    # save jit model to
+    parser.add_argument("--export_path", type=str, help="path of the jit model to save")
+
     # running
     parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"],
                         help="device type to use, cpu and gpu are supported.")
     parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.")
diff --git a/deepspeech/utils/utility.py b/deepspeech/utils/utility.py
index 28be4db03..7892f9150 100644
--- a/deepspeech/utils/utility.py
+++ b/deepspeech/utils/utility.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Contains common utility functions."""
 
+import numpy as np
 import distutils.util
 
 __all__ = ['print_arguments', 'add_arguments', 'print_grads', 'print_params']
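
The activation.py hunk above is what makes `brelu` traceable: `paddle.to_tensor` is a dygraph-only API, while `paddle.full` lowers to a regular graph op. A standalone sketch of the fixed function with a small sanity check (input values are arbitrary):

    import paddle

    def brelu(x, t_min=0.0, t_max=24.0):
        # paddle.full works in both dynamic and static graphs, so the clipped
        # activation survives paddle.jit.save / to_static tracing.
        t_min = paddle.full(shape=[1], fill_value=t_min, dtype='float32')
        t_max = paddle.full(shape=[1], fill_value=t_max, dtype='float32')
        return x.maximum(t_min).minimum(t_max)

    x = paddle.to_tensor([-1.0, 5.0, 30.0])
    print(brelu(x).numpy())  # [ 0.  5. 24.]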
diff --git a/deploy/_init_paths.py b/deploy/_init_paths.py
deleted file mode 100644
index c4b28c643..000000000
--- a/deploy/_init_paths.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Set up paths for DS2"""
-
-import os.path
-import sys
-
-
-def add_path(path):
-    if path not in sys.path:
-        sys.path.insert(0, path)
-
-
-this_dir = os.path.dirname(__file__)
-
-# Add project path to PYTHONPATH
-proj_path = os.path.join(this_dir, '..')
-add_path(proj_path)
diff --git a/examples/tiny/local/export.sh b/examples/tiny/local/export.sh
new file mode 100644
index 000000000..1b5533916
--- /dev/null
+++ b/examples/tiny/local/export.sh
@@ -0,0 +1,20 @@
+#! /usr/bin/env bash
+
+if [ $# != 2 ];then
+    echo "usage: export ckpt_path jit_model_path"
+    exit -1
+fi
+
+python3 -u ${BIN_DIR}/export.py \
+--config conf/deepspeech2.yaml \
+--checkpoint_path ${1} \
+--export_path ${2}
+
+
+if [ $? -ne 0 ]; then
+    echo "Failed in export!"
+    exit 1
+fi
+
+
+exit 0
diff --git a/examples/tiny/local/infer_golden.sh b/examples/tiny/local/infer_golden.sh
deleted file mode 100644
index d17b4328d..000000000
--- a/examples/tiny/local/infer_golden.sh
+++ /dev/null
@@ -1,46 +0,0 @@
-#! /usr/bin/env bash
-
-# download language model
-bash local/download_lm_en.sh
-if [ $? -ne 0 ]; then
-    exit 1
-fi
-
-# download well-trained model
-bash local/download_model.sh
-if [ $? -ne 0 ]; then
-    exit 1
-fi
-
-# infer
-CUDA_VISIBLE_DEVICES=0 \
-python3 -u ${MAIN_ROOT}/infer.py \
---num_samples=10 \
---beam_size=500 \
---num_proc_bsearch=8 \
---num_conv_layers=2 \
---num_rnn_layers=3 \
---rnn_layer_size=2048 \
---alpha=2.5 \
---beta=0.3 \
---cutoff_prob=1.0 \
---cutoff_top_n=40 \
---use_gru=False \
---use_gpu=True \
---share_rnn_weights=True \
---infer_manifest="data/manifest.test-clean" \
---mean_std_path="${MAIN_ROOT}/models/librispeech/mean_std.npz" \
---vocab_path="${MAIN_ROOT}/models/librispeech/vocab.txt" \
---model_path="${MAIN_ROOT}/models/librispeech" \
---lang_model_path="${MAIN_ROOT}/models/lm/common_crawl_00.prune01111.trie.klm" \
---decoding_method="ctc_beam_search" \
---error_rate_type="wer" \
---specgram_type="linear"
-
-if [ $? -ne 0 ]; then
-    echo "Failed in inference!"
-    exit 1
-fi
-
-
-exit 0
diff --git a/examples/tiny/local/test_golden.sh b/examples/tiny/local/test_golden.sh
deleted file mode 100644
index d6b1bc8e9..000000000
--- a/examples/tiny/local/test_golden.sh
+++ /dev/null
@@ -1,47 +0,0 @@
-#! /usr/bin/env bash
-
-# download language model
-bash local/download_lm_en.sh
-if [ $? -ne 0 ]; then
-    exit 1
-fi
-
-# download well-trained model
-bash local/download_model.sh
-if [ $? -ne 0 ]; then
-    exit 1
-fi
-
-
-# evaluate model
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
-python3 -u $MAIN_ROOT/test.py \
---batch_size=128 \
---beam_size=500 \
---num_proc_bsearch=8 \
---num_conv_layers=2 \
---num_rnn_layers=3 \
---rnn_layer_size=2048 \
---alpha=2.5 \
---beta=0.3 \
---cutoff_prob=1.0 \
---cutoff_top_n=40 \
---use_gru=False \
---use_gpu=True \
---share_rnn_weights=True \
---test_manifest="data/manifest.test-clean" \
---mean_std_path="$MAIN_ROOT/models/librispeech/mean_std.npz" \
---vocab_path="$MAIN_ROOT/models/librispeech/vocab.txt" \
---model_path="$MAIN_ROOT/models/librispeech" \
---lang_model_path="$MAIN_ROOT/models/lm/common_crawl_00.prune01111.trie.klm" \
---decoding_method="ctc_beam_search" \
---error_rate_type="wer" \
---specgram_type="linear"
-
-if [ $? -ne 0 ]; then
-    echo "Failed in evaluation!"
-    exit 1
-fi
-
-
-exit 0
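
For reference, a sketch of the round trip this patch enables: `run_export` saves the static-graph model with `paddle.jit.save`, after which it can be reloaded and called with the same `[B, D, T]` audio / `[B]` length signature. The path, feature size, and shapes below are illustrative assumptions, not values fixed by this patch:

    import numpy as np
    import paddle

    export_path = "exp/deepspeech2.jit"     # whatever was passed as --export_path
    feat_dim, batch, max_len = 161, 2, 300  # 161 = linear spectrogram feature size

    model = paddle.jit.load(export_path)    # the traced DeepSpeech2InferModel
    model.eval()

    audio = paddle.to_tensor(
        np.random.randn(batch, feat_dim, max_len).astype('float32'))  # [B, D, T]
    audio_len = paddle.to_tensor(
        np.array([max_len, max_len - 50], dtype='int64'))             # [B]

    probs = model(audio, audio_len)  # softmax probabilities, [B, T, vocab_size + 1]
    print(probs.shape)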