From b31a1f46d9df6f5c29f245a8a66841a6ea7ca4b1 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 31 Mar 2021 08:15:45 +0000 Subject: [PATCH] refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils --- .notebook/jit_infer.ipynb | 4 +- deepspeech/exps/deepspeech2/bin/tune.py | 2 +- deepspeech/models/deepspeech2.py | 8 +- deepspeech/modules/ctc.py | 39 ++++-- deepspeech/training/scheduler.py | 56 ++++++++ deepspeech/training/trainer.py | 12 +- deepspeech/utils/checkpoint.py | 59 ++++---- deepspeech/utils/cmvn.py | 93 +++++++++++++ deepspeech/utils/ctc_utils.py | 128 ++++++++++++++++++ deepspeech/utils/metric.py | 43 ------ .../utils/{common.py => tensor_utils.py} | 39 +++--- deepspeech/utils/utility.py | 41 +++++- 12 files changed, 410 insertions(+), 114 deletions(-) create mode 100644 deepspeech/training/scheduler.py create mode 100644 deepspeech/utils/cmvn.py create mode 100644 deepspeech/utils/ctc_utils.py delete mode 100644 deepspeech/utils/metric.py rename deepspeech/utils/{common.py => tensor_utils.py} (79%) diff --git a/.notebook/jit_infer.ipynb b/.notebook/jit_infer.ipynb index 397c59603..a62e76a2e 100644 --- a/.notebook/jit_infer.ipynb +++ b/.notebook/jit_infer.ipynb @@ -509,7 +509,7 @@ " print(audio_len.shape)\n", " \n", " #eouts, eouts_len = model.encoder(audio, audio_len)\n", - " #probs = model.decoder.probs(eouts)\n", + " #probs = model.decoder.softmax(eouts)\n", " probs = model.forward(audio, audio_len)\n", " print('paddle:', probs.numpy())\n", " \n", @@ -666,4 +666,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py index 1fc8dc0c1..33b83283f 100644 --- a/deepspeech/exps/deepspeech2/bin/tune.py +++ b/deepspeech/exps/deepspeech2/bin/tune.py @@ -109,7 +109,7 @@ def tune(config, args): # model infer eouts, eouts_len = model.encoder(audio, audio_len) - probs = model.decoder.probs(eouts) + probs = model.decoder.softmax(eouts) # grid search for index, (alpha, beta) in enumerate(params_grid): diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py index ffe678a69..cab1e45e1 100644 --- a/deepspeech/models/deepspeech2.py +++ b/deepspeech/models/deepspeech2.py @@ -203,7 +203,7 @@ class DeepSpeech2Model(nn.Layer): decoding_method=decoding_method) eouts, eouts_len = self.encoder(audio, audio_len) - probs = self.decoder.probs(eouts) + probs = self.decoder.softmax(eouts) return self.decoder.decode_probs( probs.numpy(), eouts_len, vocab_list, decoding_method, lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, @@ -234,7 +234,9 @@ class DeepSpeech2Model(nn.Layer): rnn_size=config.model.rnn_layer_size, use_gru=config.model.use_gru, share_rnn_weights=config.model.share_rnn_weights) - checkpoint.load_parameters(model, checkpoint_path=checkpoint_path) + infos = checkpoint.load_parameters( + model, checkpoint_path=checkpoint_path) + logger.info(f"checkpoint info: {infos}") layer_tools.summary(model) return model @@ -268,5 +270,5 @@ class DeepSpeech2InferModel(DeepSpeech2Model): probs: probs after softmax """ eouts, eouts_len = self.encoder(audio, audio_len) - probs = self.decoder.probs(eouts) + probs = self.decoder.softmax(eouts) return probs diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py index 66737f599..f11924909 100644 --- a/deepspeech/modules/ctc.py +++ b/deepspeech/modules/ctc.py @@ -20,10 +20,12 @@ from paddle import nn from paddle.nn import functional as F from paddle.nn import initializer as I +from deepspeech.modules.loss import CTCLoss +from deepspeech.utils import ctc_utils + from deepspeech.decoders.swig_wrapper import Scorer from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch -from deepspeech.modules.loss import CTCLoss logger = logging.getLogger(__name__) @@ -67,38 +69,31 @@ class CTCDecoder(nn.Layer): ys_pad (Tenosr): batch of padded character id sequence tensor (B, Lmax) ys_lens (Tensor): batch of lengths of character sequence (B) Returns: - loss (Tenosr): scalar. + loss (Tenosr): ctc loss value, scalar. """ logits = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate)) loss = self.criterion(logits, ys_pad, hlens, ys_lens) return loss - def probs(self, eouts: paddle.Tensor, temperature: float=1.0): + def softmax(self, eouts: paddle.Tensor, temperature: float=1.0): """Get CTC probabilities. Args: eouts (FloatTensor): `[B, T, enc_units]` Returns: probs (FloatTensor): `[B, T, odim]` """ - return F.softmax(self.ctc_lo(eouts) / temperature, axis=-1) - - def scores(self, eouts: paddle.Tensor, temperature: float=1.0): - """Get log-scale CTC probabilities. - Args: - eouts (FloatTensor): `[B, T, enc_units]` - Returns: - log_probs (FloatTensor): `[B, T, odim]` - """ - return F.log_softmax(self.ctc_lo(eouts) / temperature, axis=-1) + self.probs = F.softmax(self.ctc_lo(eouts) / temperature, axis=2) + return self.probs - def log_softmax(self, hs_pad: paddle.Tensor) -> paddle.Tensor: + def log_softmax(self, hs_pad: paddle.Tensor, + temperature: float=1.0) -> paddle.Tensor: """log_softmax of frame activations Args: Tensor hs_pad: 3d tensor (B, Tmax, eprojs) Returns: paddle.Tensor: log softmax applied 3d tensor (B, Tmax, odim) """ - return self.scores(hs_pad) + return F.log_softmax(self.ctc_lo(hs_pad) / temperature, axis=2) def argmax(self, hs_pad: paddle.Tensor) -> paddle.Tensor: """argmax of frame activations @@ -109,6 +104,20 @@ class CTCDecoder(nn.Layer): """ return paddle.argmax(self.ctc_lo(hs_pad), dim=2) + def forced_align(self, + ctc_probs: paddle.Tensor, + y: paddle.Tensor, + blank_id=0) -> list: + """ctc forced alignment. + Args: + ctc_probs (paddle.Tensor): hidden state sequence, 2d tensor (T, D) + y (paddle.Tensor): label id sequence tensor, 1d tensor (L) + blank_id (int): blank symbol index + Returns: + paddle.Tensor: best alignment result, (T). + """ + return ctc_utils.forced_align(ctc_probs, y, blank_id) + def _decode_batch_greedy(self, probs_split, vocab_list): """Decode by best path for a batch of probs matrix input. :param probs_split: List of 2-D probability matrix, and each consists diff --git a/deepspeech/training/scheduler.py b/deepspeech/training/scheduler.py new file mode 100644 index 000000000..8eb8096fe --- /dev/null +++ b/deepspeech/training/scheduler.py @@ -0,0 +1,56 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import paddle +from paddle.optimizer.lr import LRScheduler + +logger = logging.getLogger(__name__) + +__all__ = ["WarmupLR"] + + +class WarmupLR(LRScheduler): + """The WarmupLR scheduler + This scheduler is almost same as NoamLR Scheduler except for following + difference: + NoamLR: + lr = optimizer.lr * model_size ** -0.5 + * min(step ** -0.5, step * warmup_step ** -1.5) + WarmupLR: + lr = optimizer.lr * warmup_step ** 0.5 + * min(step ** -0.5, step * warmup_step ** -1.5) + Note that the maximum lr equals to optimizer.lr in this scheduler. + """ + + def __init__(self, + warmup_steps: Union[int, float]=25000, + learning_rate=1.0, + last_epoch=-1, + verbose=False): + assert check_argument_types() + self.warmup_steps = warmup_steps + super().__init__(learning_rate, last_epoch, verbose) + + def __repr__(self): + return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})" + + def get_lr(self): + step_num = self.last_epoch + 1 + return self.base_lr * self.warmup_steps**0.5 * min( + step_num**-0.5, step_num * self.warmup_steps**-1.5) + + def set_step(self, step: int): + self.last_epoch = step diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index f472200b8..3a381b2b7 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -131,8 +131,13 @@ class Trainer(): def save(self): """Save checkpoint (model parameters and optimizer states). """ + infos = { + "step": self.iteration, + "epoch": self.epoch, + "lr": self.optimizer.get_lr(), + } checkpoint.save_parameters(self.checkpoint_dir, self.iteration, - self.model, self.optimizer) + self.model, self.optimizer, infos) def resume_or_load(self): """Resume from latest checkpoint at checkpoints in the output @@ -141,12 +146,13 @@ class Trainer(): If ``args.checkpoint_path`` is not None, load the checkpoint, else resume training. """ - iteration = checkpoint.load_parameters( + infos = checkpoint.load_parameters( self.model, self.optimizer, checkpoint_dir=self.checkpoint_dir, checkpoint_path=self.args.checkpoint_path) - self.iteration = iteration + self.iteration = infos["step"] + self.epoch = infos["epoch"] def new_epoch(self): """Reset the train loader and increment ``epoch``. diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index f2066fdec..d265358b2 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -16,6 +16,8 @@ import os import time import logging import numpy as np +import re +import json import paddle from paddle import distributed as dist @@ -37,15 +39,13 @@ def _load_latest_checkpoint(checkpoint_dir: str) -> int: int: the latest iteration number. """ checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") - if (not os.path.isfile(checkpoint_record)): + if not os.path.isfile(checkpoint_record): return 0 # Fetch the latest checkpoint index. with open(checkpoint_record, "rt") as handle: latest_checkpoint = handle.readlines()[-1].strip() - step = latest_checkpoint.split(":")[-1] - iteration = int(step.split("-")[-1]) - + iteration = int(latest_checkpoint.split(":")[-1]) return iteration @@ -60,7 +60,7 @@ def _save_checkpoint(checkpoint_dir: str, iteration: int): checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") # Update the latest checkpoint index. with open(checkpoint_record, "a+") as handle: - handle.write("model_checkpoint_path:step-{}\n".format(iteration)) + handle.write("model_checkpoint_path:{}\n".format(iteration)) def load_parameters(model, @@ -74,20 +74,16 @@ def load_parameters(model, Defaults to None. checkpoint_dir (str, optional): the directory where checkpoint is saved. checkpoint_path (str, optional): if specified, load the checkpoint - stored in the checkpoint_path and the argument 'checkpoint_dir' will + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will be ignored. Defaults to None. Returns: - iteration (int): number of iterations that the loaded checkpoint has - been trained. + configs (dict): epoch or step, lr and other meta info should be saved. """ if checkpoint_path is not None: - iteration = int(os.path.basename(checkpoint_path).split("-")[-1]) + iteration = int(os.path.basename(checkpoint_path).split(":")[-1]) elif checkpoint_dir is not None: iteration = _load_latest_checkpoint(checkpoint_dir) - if iteration == 0: - return iteration - checkpoint_path = os.path.join(checkpoint_dir, - "step-{}".format(iteration)) + checkpoint_path = os.path.join(checkpoint_dir, "-{}".format(iteration)) else: raise ValueError( "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" @@ -98,43 +94,58 @@ def load_parameters(model, params_path = checkpoint_path + ".pdparams" model_dict = paddle.load(params_path) model.set_state_dict(model_dict) - logger.info( - "[checkpoint] Rank {}: loaded model from {}".format(rank, params_path)) + logger.info("Rank {}: loaded model from {}".format(rank, params_path)) optimizer_path = checkpoint_path + ".pdopt" if optimizer and os.path.isfile(optimizer_path): optimizer_dict = paddle.load(optimizer_path) optimizer.set_state_dict(optimizer_dict) - logger.info("[checkpoint] Rank {}: loaded optimizer state from {}". - format(rank, optimizer_path)) + logger.info("Rank {}: loaded optimizer state from {}".format( + rank, optimizer_path)) - return iteration + info_path = re.sub('.pdparams$', '.json', params_path) + configs = {} + if os.path.exists(info_path): + with open(info_path, 'r') as fin: + configs = json.load(fin) + return configs @mp_tools.rank_zero_only -def save_parameters(checkpoint_dir, iteration, model, optimizer=None): +def save_parameters(checkpoint_dir: str, + iteration: int, + model: paddle.nn.Layer, + optimizer: Optimizer=None, + infos: dict=None): """Checkpoint the latest trained model parameters. Args: checkpoint_dir (str): the directory where checkpoint is saved. - iteration (int): the latest iteration number. + iteration (int): the latest iteration(step or epoch) number. model (Layer): model to be checkpointed. optimizer (Optimizer, optional): optimizer to be checkpointed. Defaults to None. + infos (dict or None): any info you want to save. Returns: None """ - checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration)) + checkpoint_path = os.path.join(checkpoint_dir, "-{}".format(iteration)) model_dict = model.state_dict() params_path = checkpoint_path + ".pdparams" paddle.save(model_dict, params_path) - logger.info("[checkpoint] Saved model to {}".format(params_path)) + logger.info("Saved model to {}".format(params_path)) if optimizer: opt_dict = optimizer.state_dict() optimizer_path = checkpoint_path + ".pdopt" paddle.save(opt_dict, optimizer_path) - logger.info( - "[checkpoint] Saved optimzier state to {}".format(optimizer_path)) + logger.info("Saved optimzier state to {}".format(optimizer_path)) + + info_path = re.sub('.pdparams$', '.json', params_path) + if infos is None: + infos = {} + with open(info_path, 'w') as fout: + data = json.dumps(infos) + fout.write(data) _save_checkpoint(checkpoint_dir, iteration) diff --git a/deepspeech/utils/cmvn.py b/deepspeech/utils/cmvn.py new file mode 100644 index 000000000..5c5573ee9 --- /dev/null +++ b/deepspeech/utils/cmvn.py @@ -0,0 +1,93 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import math +import logging +import numpy as np + +logger = logging.getLogger(__name__) + +__all__ = ['load_cmvn'] + + +def _load_json_cmvn(json_cmvn_file): + """ Load the json format cmvn stats file and calculate cmvn + Args: + json_cmvn_file: cmvn stats file in json format + Returns: + a numpy array of [means, vars] + """ + with open(json_cmvn_file) as f: + cmvn_stats = json.load(f) + + means = cmvn_stats['mean_stat'] + variance = cmvn_stats['var_stat'] + count = cmvn_stats['frame_num'] + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + cmvn = np.array([means, variance]) + return cmvn + + +def _load_kaldi_cmvn(kaldi_cmvn_file): + """ Load the kaldi format cmvn stats file and calculate cmvn + Args: + kaldi_cmvn_file: kaldi text style global cmvn file, which + is generated by: + compute-cmvn-stats --binary=false scp:feats.scp global_cmvn + Returns: + a numpy array of [means, vars] + """ + means = [] + variance = [] + with open(kaldi_cmvn_file, 'r') as fid: + # kaldi binary file start with '\0B' + if fid.read(2) == '\0B': + logger.error('kaldi cmvn binary file is not supported, please ' + 'recompute it by: compute-cmvn-stats --binary=false ' + ' scp:feats.scp global_cmvn') + sys.exit(1) + fid.seek(0) + arr = fid.read().split() + assert (arr[0] == '[') + assert (arr[-2] == '0') + assert (arr[-1] == ']') + feat_dim = int((len(arr) - 2 - 2) / 2) + for i in range(1, feat_dim + 1): + means.append(float(arr[i])) + count = float(arr[feat_dim + 1]) + for i in range(feat_dim + 2, 2 * feat_dim + 2): + variance.append(float(arr[i])) + + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + cmvn = np.array([means, variance]) + return cmvn + + +def load_cmvn(cmvn_file, is_json): + if is_json: + cmvn = _load_json_cmvn(cmvn_file) + else: + cmvn = _load_kaldi_cmvn(cmvn_file) + return cmvn[0], cmvn[1] diff --git a/deepspeech/utils/ctc_utils.py b/deepspeech/utils/ctc_utils.py new file mode 100644 index 000000000..9517b8d64 --- /dev/null +++ b/deepspeech/utils/ctc_utils.py @@ -0,0 +1,128 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import numpy as np +from typing import List + +import paddle + +logger = logging.getLogger(__name__) + +__all__ = ["forced_align", "remove_duplicates_and_blank", "insert_blank"] + + +def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]: + """ctc alignment to ctc label ids. + + "abaa-acee-" -> "abaace" + + Args: + hyp (List[int]): hypotheses ids, (L) + blank_id (int, optional): blank id. Defaults to 0. + + Returns: + List[int]: remove dupicate ids, then remove blank id. + """ + new_hyp: List[int] = [] + cur = 0 + while cur < len(hyp): + if hyp[cur] != blank_id: + new_hyp.append(hyp[cur]) + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + cur += 1 + return new_hyp + + +def insert_blank(label: np.ndarray, blank_id: int=0): + """Insert blank token between every two label token. + + "abcdefg" -> "-a-b-c-d-e-f-g-" + + Args: + label ([np.ndarray]): label ids, (L). + blank_id (int, optional): blank id. Defaults to 0. + + Returns: + [np.ndarray]: (2L+1). + """ + label = np.expand_dims(label, 1) #[L, 1] + blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id + label = np.concatenate([blanks, label], axis=1) #[L, 2] + label = label.reshape(-1) #[2L] + label = np.append(label, label[0]) #[2L + 1] + return label + + +def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, + blank_id=0) -> list: + """ctc forced alignment. + + https://distill.pub/2017/ctc/ + + Args: + ctc_probs (paddle.Tensor): hidden state sequence, 2d tensor (T, D) + y (paddle.Tensor): label id sequence tensor, 1d tensor (L) + blank_id (int): blank symbol index + Returns: + paddle.Tensor: best alignment result, (T). + """ + y_insert_blank = insert_blank(y, blank_id) + + log_alpha = paddle.zeros( + (ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1) + log_alpha = log_alpha - float('inf') # log of zero + state_path = (paddle.zeros( + (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1 + ) # state path + + # init start state + log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # Sb + log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # Snb + + for t in range(1, ctc_probs.size(0)): + for s in range(len(y_insert_blank)): + if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ + s] == y_insert_blank[s - 2]: + candidates = paddle.to_tensor( + [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]]) + prev_state = [s, s - 1] + else: + candidates = paddle.to_tensor([ + log_alpha[t - 1, s], + log_alpha[t - 1, s - 1], + log_alpha[t - 1, s - 2], + ]) + prev_state = [s, s - 1, s - 2] + log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][ + y_insert_blank[s]] + state_path[t, s] = prev_state[paddle.argmax(candidates)] + + state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16) + + candidates = paddle.to_tensor([ + log_alpha[-1, len(y_insert_blank) - 1], # Sb + log_alpha[-1, len(y_insert_blank) - 2] # Snb + ]) + prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2] + state_seq[-1] = prev_state[paddle.argmax(candidates)] + for t in range(ctc_probs.size(0) - 2, -1, -1): + state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] + + output_alignment = [] + for t in range(0, ctc_probs.size(0)): + output_alignment.append(y_insert_blank[state_seq[t, 0]]) + + return output_alignment diff --git a/deepspeech/utils/metric.py b/deepspeech/utils/metric.py deleted file mode 100644 index e53b24056..000000000 --- a/deepspeech/utils/metric.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import logging -from typing import Tuple, List - -import paddle - -logger = logging.getLogger(__name__) - -__all__ = ["th_accuracy"] - - -def th_accuracy(pad_outputs: paddle.Tensor, - pad_targets: paddle.Tensor, - ignore_label: int) -> float: - """Calculate accuracy. - Args: - pad_outputs (Tensor): Prediction tensors (B * Lmax, D). - pad_targets (LongTensor): Target label tensors (B, Lmax, D). - ignore_label (int): Ignore label id. - Returns: - float: Accuracy value (0.0 - 1.0). - """ - pad_pred = pad_outputs.view( - pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2) - mask = pad_targets != ignore_label - numerator = paddle.sum( - pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) - denominator = paddle.sum(mask) - return float(numerator) / float(denominator) diff --git a/deepspeech/utils/common.py b/deepspeech/utils/tensor_utils.py similarity index 79% rename from deepspeech/utils/common.py rename to deepspeech/utils/tensor_utils.py index b4673e2be..627f51630 100644 --- a/deepspeech/utils/common.py +++ b/deepspeech/utils/tensor_utils.py @@ -20,7 +20,7 @@ import paddle logger = logging.getLogger(__name__) -__all__ = ["pad_list", "add_sos_eos", "remove_duplicates_and_blank", "log_add"] +__all__ = ["pad_list", "add_sos_eos", "th_accuracy"] IGNORE_ID = -1 @@ -90,24 +90,21 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, return pad_list(ys_in, eos), pad_list(ys_out, ignore_id) -def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: - new_hyp: List[int] = [] - cur = 0 - while cur < len(hyp): - if hyp[cur] != 0: - new_hyp.append(hyp[cur]) - prev = cur - while cur < len(hyp) and hyp[cur] == hyp[prev]: - cur += 1 - return new_hyp - - -def log_add(args: List[int]) -> float: - """ - Stable log add +def th_accuracy(pad_outputs: paddle.Tensor, + pad_targets: paddle.Tensor, + ignore_label: int) -> float: + """Calculate accuracy. + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax, D). + ignore_label (int): Ignore label id. + Returns: + float: Accuracy value (0.0 - 1.0). """ - if all(a == -float('inf') for a in args): - return -float('inf') - a_max = max(args) - lsp = math.log(sum(math.exp(a - a_max) for a in args)) - return a_max + lsp \ No newline at end of file + pad_pred = pad_outputs.view( + pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2) + mask = pad_targets != ignore_label + numerator = paddle.sum( + pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + denominator = paddle.sum(mask) + return float(numerator) / float(denominator) diff --git a/deepspeech/utils/utility.py b/deepspeech/utils/utility.py index 72a45e29a..20da878b9 100644 --- a/deepspeech/utils/utility.py +++ b/deepspeech/utils/utility.py @@ -13,10 +13,13 @@ # limitations under the License. """Contains common utility functions.""" +import math import numpy as np import distutils.util -__all__ = ['print_arguments', 'add_arguments'] +__all__ = [ + 'print_arguments', 'add_arguments', "log_add", "remove_duplicates_and_blank" +] def print_arguments(args): @@ -57,4 +60,38 @@ def add_arguments(argname, type, default, help, argparser, **kwargs): default=default, type=type, help=help + ' Default: %(default)s.', - **kwargs) \ No newline at end of file + **kwargs) + + +def log_add(args: List[int]) -> float: + """ + Stable log add + """ + if all(a == -float('inf') for a in args): + return -float('inf') + a_max = max(args) + lsp = math.log(sum(math.exp(a - a_max) for a in args)) + return a_max + lsp + + +def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]: + """ctc alignment to ctc label ids. + + "abaa-acee-" -> "abaace" + + Args: + hyp (List[int]): hypotheses ids, (L) + blank_id (int, optional): blank id. Defaults to 0. + + Returns: + List[int]: remove dupicate ids, then remove blank id. + """ + new_hyp: List[int] = [] + cur = 0 + while cur < len(hyp): + if hyp[cur] != blank_id: + new_hyp.append(hyp[cur]) + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + cur += 1 + return new_hyp