refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils

pull/578/head
Hui Zhang 5 years ago
parent fcd91c62d0
commit b31a1f46d9

@ -509,7 +509,7 @@
" print(audio_len.shape)\n", " print(audio_len.shape)\n",
" \n", " \n",
" #eouts, eouts_len = model.encoder(audio, audio_len)\n", " #eouts, eouts_len = model.encoder(audio, audio_len)\n",
" #probs = model.decoder.probs(eouts)\n", " #probs = model.decoder.softmax(eouts)\n",
" probs = model.forward(audio, audio_len)\n", " probs = model.forward(audio, audio_len)\n",
" print('paddle:', probs.numpy())\n", " print('paddle:', probs.numpy())\n",
" \n", " \n",
@ -666,4 +666,4 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 2 "nbformat_minor": 2
} }

@ -109,7 +109,7 @@ def tune(config, args):
# model infer # model infer
eouts, eouts_len = model.encoder(audio, audio_len) eouts, eouts_len = model.encoder(audio, audio_len)
probs = model.decoder.probs(eouts) probs = model.decoder.softmax(eouts)
# grid search # grid search
for index, (alpha, beta) in enumerate(params_grid): for index, (alpha, beta) in enumerate(params_grid):

@ -203,7 +203,7 @@ class DeepSpeech2Model(nn.Layer):
decoding_method=decoding_method) decoding_method=decoding_method)
eouts, eouts_len = self.encoder(audio, audio_len) eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.probs(eouts) probs = self.decoder.softmax(eouts)
return self.decoder.decode_probs( return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method, probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
@ -234,7 +234,9 @@ class DeepSpeech2Model(nn.Layer):
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru, use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights) share_rnn_weights=config.model.share_rnn_weights)
checkpoint.load_parameters(model, checkpoint_path=checkpoint_path) infos = checkpoint.load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model) layer_tools.summary(model)
return model return model
@ -268,5 +270,5 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
probs: probs after softmax probs: probs after softmax
""" """
eouts, eouts_len = self.encoder(audio, audio_len) eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.probs(eouts) probs = self.decoder.softmax(eouts)
return probs return probs

@ -20,10 +20,12 @@ from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
from paddle.nn import initializer as I from paddle.nn import initializer as I
from deepspeech.modules.loss import CTCLoss
from deepspeech.utils import ctc_utils
from deepspeech.decoders.swig_wrapper import Scorer from deepspeech.decoders.swig_wrapper import Scorer
from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder
from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch
from deepspeech.modules.loss import CTCLoss
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -67,38 +69,31 @@ class CTCDecoder(nn.Layer):
ys_pad (Tenosr): batch of padded character id sequence tensor (B, Lmax) ys_pad (Tenosr): batch of padded character id sequence tensor (B, Lmax)
ys_lens (Tensor): batch of lengths of character sequence (B) ys_lens (Tensor): batch of lengths of character sequence (B)
Returns: Returns:
loss (Tenosr): scalar. loss (Tenosr): ctc loss value, scalar.
""" """
logits = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate)) logits = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
loss = self.criterion(logits, ys_pad, hlens, ys_lens) loss = self.criterion(logits, ys_pad, hlens, ys_lens)
return loss return loss
def probs(self, eouts: paddle.Tensor, temperature: float=1.0): def softmax(self, eouts: paddle.Tensor, temperature: float=1.0):
"""Get CTC probabilities. """Get CTC probabilities.
Args: Args:
eouts (FloatTensor): `[B, T, enc_units]` eouts (FloatTensor): `[B, T, enc_units]`
Returns: Returns:
probs (FloatTensor): `[B, T, odim]` probs (FloatTensor): `[B, T, odim]`
""" """
return F.softmax(self.ctc_lo(eouts) / temperature, axis=-1) self.probs = F.softmax(self.ctc_lo(eouts) / temperature, axis=2)
return self.probs
def scores(self, eouts: paddle.Tensor, temperature: float=1.0):
"""Get log-scale CTC probabilities.
Args:
eouts (FloatTensor): `[B, T, enc_units]`
Returns:
log_probs (FloatTensor): `[B, T, odim]`
"""
return F.log_softmax(self.ctc_lo(eouts) / temperature, axis=-1)
def log_softmax(self, hs_pad: paddle.Tensor) -> paddle.Tensor: def log_softmax(self, hs_pad: paddle.Tensor,
temperature: float=1.0) -> paddle.Tensor:
"""log_softmax of frame activations """log_softmax of frame activations
Args: Args:
Tensor hs_pad: 3d tensor (B, Tmax, eprojs) Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Returns: Returns:
paddle.Tensor: log softmax applied 3d tensor (B, Tmax, odim) paddle.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
""" """
return self.scores(hs_pad) return F.log_softmax(self.ctc_lo(hs_pad) / temperature, axis=2)
def argmax(self, hs_pad: paddle.Tensor) -> paddle.Tensor: def argmax(self, hs_pad: paddle.Tensor) -> paddle.Tensor:
"""argmax of frame activations """argmax of frame activations
@ -109,6 +104,20 @@ class CTCDecoder(nn.Layer):
""" """
return paddle.argmax(self.ctc_lo(hs_pad), dim=2) return paddle.argmax(self.ctc_lo(hs_pad), dim=2)
def forced_align(self,
                 ctc_probs: paddle.Tensor,
                 y: paddle.Tensor,
                 blank_id=0) -> list:
    """Run CTC forced alignment by delegating to ``ctc_utils.forced_align``.

    This method does not read any decoder state; it forwards its three
    arguments unchanged.

    Args:
        ctc_probs (paddle.Tensor): per-frame CTC scores, 2d tensor (T, D).
            NOTE(review): presumably log-probabilities rather than hidden
            states -- confirm against ``ctc_utils.forced_align``.
        y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
        blank_id (int): blank symbol index
    Returns:
        list: whatever ``ctc_utils.forced_align`` returns; the annotation
            declares a list, presumably one aligned id per frame (T).
    """
    return ctc_utils.forced_align(ctc_probs, y, blank_id)
def _decode_batch_greedy(self, probs_split, vocab_list): def _decode_batch_greedy(self, probs_split, vocab_list):
"""Decode by best path for a batch of probs matrix input. """Decode by best path for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrix, and each consists :param probs_split: List of 2-D probability matrix, and each consists

@ -0,0 +1,56 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Union

import paddle
from paddle.optimizer.lr import LRScheduler
logger = logging.getLogger(__name__)
__all__ = ["WarmupLR"]
class WarmupLR(LRScheduler):
    """Warm up the learning rate, then decay it with the inverse square root.

    This scheduler is almost same as NoamLR Scheduler except for following
    difference:

    NoamLR:
        lr = optimizer.lr * model_size ** -0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)
    WarmupLR:
        lr = optimizer.lr * warmup_step ** 0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)

    Note that the maximum lr equals to optimizer.lr in this scheduler: at
    ``step == warmup_step`` both arguments of ``min`` are
    ``warmup_step ** -0.5``, which cancels the ``warmup_step ** 0.5`` factor.

    Args:
        warmup_steps (int or float): number of steps over which the rate
            grows before it starts to decay. Defaults to 25000.
        learning_rate: forwarded to ``LRScheduler.__init__``; presumably
            stored there as ``base_lr`` -- confirm against the Paddle API.
        last_epoch: forwarded to ``LRScheduler.__init__``.
        verbose: forwarded to ``LRScheduler.__init__``.

    Raises:
        TypeError: if ``warmup_steps`` is neither an int nor a float.
    """

    def __init__(self,
                 warmup_steps: Union[int, float]=25000,
                 learning_rate=1.0,
                 last_epoch=-1,
                 verbose=False):
        # Validate by hand so that this module does not need typeguard's
        # `check_argument_types`, which is not imported in this file.
        if not isinstance(warmup_steps, (int, float)):
            raise TypeError(
                f"warmup_steps must be int or float, got "
                f"{type(warmup_steps).__name__}")
        # Must be set before the base initialiser runs, in case it already
        # evaluates `get_lr()`.
        self.warmup_steps = warmup_steps
        super().__init__(learning_rate, last_epoch, verbose)

    def __repr__(self):
        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"

    def get_lr(self):
        """Return the learning rate for the current value of ``last_epoch``."""
        # The step is counted from 1; NOTE(review): this assumes
        # last_epoch >= 0 here, otherwise `step_num ** -0.5` divides by zero.
        step_num = self.last_epoch + 1
        return self.base_lr * self.warmup_steps**0.5 * min(
            step_num**-0.5, step_num * self.warmup_steps**-1.5)

    def set_step(self, step: int):
        """Overwrite the internal step counter (``last_epoch``) with ``step``."""
        self.last_epoch = step

@ -131,8 +131,13 @@ class Trainer():
def save(self): def save(self):
"""Save checkpoint (model parameters and optimizer states). """Save checkpoint (model parameters and optimizer states).
""" """
infos = {
"step": self.iteration,
"epoch": self.epoch,
"lr": self.optimizer.get_lr(),
}
checkpoint.save_parameters(self.checkpoint_dir, self.iteration, checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
self.model, self.optimizer) self.model, self.optimizer, infos)
def resume_or_load(self): def resume_or_load(self):
"""Resume from latest checkpoint at checkpoints in the output """Resume from latest checkpoint at checkpoints in the output
@ -141,12 +146,13 @@ class Trainer():
If ``args.checkpoint_path`` is not None, load the checkpoint, else If ``args.checkpoint_path`` is not None, load the checkpoint, else
resume training. resume training.
""" """
iteration = checkpoint.load_parameters( infos = checkpoint.load_parameters(
self.model, self.model,
self.optimizer, self.optimizer,
checkpoint_dir=self.checkpoint_dir, checkpoint_dir=self.checkpoint_dir,
checkpoint_path=self.args.checkpoint_path) checkpoint_path=self.args.checkpoint_path)
self.iteration = iteration self.iteration = infos["step"]
self.epoch = infos["epoch"]
def new_epoch(self): def new_epoch(self):
"""Reset the train loader and increment ``epoch``. """Reset the train loader and increment ``epoch``.

@ -16,6 +16,8 @@ import os
import time import time
import logging import logging
import numpy as np import numpy as np
import re
import json
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
@ -37,15 +39,13 @@ def _load_latest_checkpoint(checkpoint_dir: str) -> int:
int: the latest iteration number. int: the latest iteration number.
""" """
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
if (not os.path.isfile(checkpoint_record)): if not os.path.isfile(checkpoint_record):
return 0 return 0
# Fetch the latest checkpoint index. # Fetch the latest checkpoint index.
with open(checkpoint_record, "rt") as handle: with open(checkpoint_record, "rt") as handle:
latest_checkpoint = handle.readlines()[-1].strip() latest_checkpoint = handle.readlines()[-1].strip()
step = latest_checkpoint.split(":")[-1] iteration = int(latest_checkpoint.split(":")[-1])
iteration = int(step.split("-")[-1])
return iteration return iteration
@ -60,7 +60,7 @@ def _save_checkpoint(checkpoint_dir: str, iteration: int):
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
# Update the latest checkpoint index. # Update the latest checkpoint index.
with open(checkpoint_record, "a+") as handle: with open(checkpoint_record, "a+") as handle:
handle.write("model_checkpoint_path:step-{}\n".format(iteration)) handle.write("model_checkpoint_path:{}\n".format(iteration))
def load_parameters(model, def load_parameters(model,
@ -74,20 +74,16 @@ def load_parameters(model,
Defaults to None. Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved. checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path and the argument 'checkpoint_dir' will stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None. be ignored. Defaults to None.
Returns: Returns:
iteration (int): number of iterations that the loaded checkpoint has configs (dict): epoch or step, lr and other meta info should be saved.
been trained.
""" """
if checkpoint_path is not None: if checkpoint_path is not None:
iteration = int(os.path.basename(checkpoint_path).split("-")[-1]) iteration = int(os.path.basename(checkpoint_path).split(":")[-1])
elif checkpoint_dir is not None: elif checkpoint_dir is not None:
iteration = _load_latest_checkpoint(checkpoint_dir) iteration = _load_latest_checkpoint(checkpoint_dir)
if iteration == 0: checkpoint_path = os.path.join(checkpoint_dir, "-{}".format(iteration))
return iteration
checkpoint_path = os.path.join(checkpoint_dir,
"step-{}".format(iteration))
else: else:
raise ValueError( raise ValueError(
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
@ -98,43 +94,58 @@ def load_parameters(model,
params_path = checkpoint_path + ".pdparams" params_path = checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path) model_dict = paddle.load(params_path)
model.set_state_dict(model_dict) model.set_state_dict(model_dict)
logger.info( logger.info("Rank {}: loaded model from {}".format(rank, params_path))
"[checkpoint] Rank {}: loaded model from {}".format(rank, params_path))
optimizer_path = checkpoint_path + ".pdopt" optimizer_path = checkpoint_path + ".pdopt"
if optimizer and os.path.isfile(optimizer_path): if optimizer and os.path.isfile(optimizer_path):
optimizer_dict = paddle.load(optimizer_path) optimizer_dict = paddle.load(optimizer_path)
optimizer.set_state_dict(optimizer_dict) optimizer.set_state_dict(optimizer_dict)
logger.info("[checkpoint] Rank {}: loaded optimizer state from {}". logger.info("Rank {}: loaded optimizer state from {}".format(
format(rank, optimizer_path)) rank, optimizer_path))
return iteration info_path = re.sub('.pdparams$', '.json', params_path)
configs = {}
if os.path.exists(info_path):
with open(info_path, 'r') as fin:
configs = json.load(fin)
return configs
@mp_tools.rank_zero_only @mp_tools.rank_zero_only
def save_parameters(checkpoint_dir, iteration, model, optimizer=None): def save_parameters(checkpoint_dir: str,
iteration: int,
model: paddle.nn.Layer,
optimizer: Optimizer=None,
infos: dict=None):
"""Checkpoint the latest trained model parameters. """Checkpoint the latest trained model parameters.
Args: Args:
checkpoint_dir (str): the directory where checkpoint is saved. checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number. iteration (int): the latest iteration(step or epoch) number.
model (Layer): model to be checkpointed. model (Layer): model to be checkpointed.
optimizer (Optimizer, optional): optimizer to be checkpointed. optimizer (Optimizer, optional): optimizer to be checkpointed.
Defaults to None. Defaults to None.
infos (dict or None): any info you want to save.
Returns: Returns:
None None
""" """
checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration)) checkpoint_path = os.path.join(checkpoint_dir, "-{}".format(iteration))
model_dict = model.state_dict() model_dict = model.state_dict()
params_path = checkpoint_path + ".pdparams" params_path = checkpoint_path + ".pdparams"
paddle.save(model_dict, params_path) paddle.save(model_dict, params_path)
logger.info("[checkpoint] Saved model to {}".format(params_path)) logger.info("Saved model to {}".format(params_path))
if optimizer: if optimizer:
opt_dict = optimizer.state_dict() opt_dict = optimizer.state_dict()
optimizer_path = checkpoint_path + ".pdopt" optimizer_path = checkpoint_path + ".pdopt"
paddle.save(opt_dict, optimizer_path) paddle.save(opt_dict, optimizer_path)
logger.info( logger.info("Saved optimzier state to {}".format(optimizer_path))
"[checkpoint] Saved optimzier state to {}".format(optimizer_path))
info_path = re.sub('.pdparams$', '.json', params_path)
if infos is None:
infos = {}
with open(info_path, 'w') as fout:
data = json.dumps(infos)
fout.write(data)
_save_checkpoint(checkpoint_dir, iteration) _save_checkpoint(checkpoint_dir, iteration)

@ -0,0 +1,93 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import logging
import numpy as np
logger = logging.getLogger(__name__)
__all__ = ['load_cmvn']
def _load_json_cmvn(json_cmvn_file):
    """Read accumulated CMVN statistics from a JSON file and normalise them.

    The JSON object must provide ``mean_stat`` and ``var_stat`` (per-dimension
    accumulators, treated as the sum and the sum of squares of the features)
    and ``frame_num`` (the number of accumulated frames).

    Args:
        json_cmvn_file: cmvn stats file in json format
    Returns:
        a numpy array with two rows: the per-dimension mean, and the
        per-dimension inverse standard deviation ``1 / sqrt(var)``.
    """
    with open(json_cmvn_file) as fin:
        stats = json.load(fin)

    frame_num = stats['frame_num']
    means = stats['mean_stat']
    variance = stats['var_stat']

    for dim, mean_sum in enumerate(means):
        mean = mean_sum / frame_num
        means[dim] = mean
        # E[x^2] - E[x]^2, floored so the square root below stays finite.
        var = variance[dim] / frame_num - mean * mean
        if var < 1.0e-20:
            var = 1.0e-20
        variance[dim] = 1.0 / math.sqrt(var)

    return np.array([means, variance])
def _load_kaldi_cmvn(kaldi_cmvn_file):
    """Read a Kaldi text-format global CMVN stats file and normalise it.

    The whitespace-separated token layout parsed here is::

        [ <feat_dim accumulators> <count> <feat_dim accumulators> 0 ]

    where the first group is divided by ``count`` to obtain the mean and the
    second group is treated as the sum of squares.

    Args:
        kaldi_cmvn_file: kaldi text style global cmvn file, which
            is generated by:
            compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
    Returns:
        a numpy array with two rows: the per-dimension mean, and the
        per-dimension inverse standard deviation ``1 / sqrt(var)``.
    Raises:
        SystemExit: if the file starts with the Kaldi binary marker.
        AssertionError: if the bracket / trailing-zero layout is not found.
    """
    # Local import: the binary-format branch calls `sys.exit`, and `sys` is
    # not among this module's top-level imports.
    import sys

    with open(kaldi_cmvn_file, 'r') as fid:
        # kaldi binary file start with '\0B'
        if fid.read(2) == '\0B':
            logger.error('kaldi cmvn binary file is not supported, please '
                         'recompute it by: compute-cmvn-stats --binary=false '
                         ' scp:feats.scp global_cmvn')
            sys.exit(1)
        fid.seek(0)
        arr = fid.read().split()

    assert (arr[0] == '[')
    assert (arr[-2] == '0')
    assert (arr[-1] == ']')

    # Tokens that are not data: '[' , count, trailing '0', ']'.
    feat_dim = int((len(arr) - 2 - 2) / 2)
    means = [float(tok) for tok in arr[1:feat_dim + 1]]
    count = float(arr[feat_dim + 1])
    variance = [float(tok) for tok in arr[feat_dim + 2:2 * feat_dim + 2]]

    for i in range(len(means)):
        means[i] /= count
        # E[x^2] - E[x]^2, floored so the square root below stays finite.
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])

    cmvn = np.array([means, variance])
    return cmvn
def load_cmvn(cmvn_file, is_json):
    """Load CMVN statistics and return them as two separate rows.

    Args:
        cmvn_file: path of the stats file.
        is_json: truthy to parse the file with the JSON loader, falsy to
            parse it with the Kaldi text loader.
    Returns:
        tuple: row 0 and row 1 of the array produced by the chosen loader.
    """
    loader = _load_json_cmvn if is_json else _load_kaldi_cmvn
    stats = loader(cmvn_file)
    return stats[0], stats[1]

@ -0,0 +1,128 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from typing import List
import paddle
logger = logging.getLogger(__name__)
__all__ = ["forced_align", "remove_duplicates_and_blank", "insert_blank"]
def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
    """Collapse a CTC alignment into its label sequence.

    Consecutive repeats are merged first, then blanks are dropped:
    "abaa-acee-" -> "abaace"

    Args:
        hyp (List[int]): hypotheses ids, (L)
        blank_id (int, optional): blank id. Defaults to 0.
    Returns:
        List[int]: ids with duplicate runs merged and blank ids removed.
    """
    collapsed: List[int] = []
    for position, token in enumerate(hyp):
        # Only the first element of each run of equal ids is considered.
        if position > 0 and token == hyp[position - 1]:
            continue
        if token != blank_id:
            collapsed.append(token)
    return collapsed
def insert_blank(label: np.ndarray, blank_id: int=0):
    """Insert blank token between every two label token.

    A blank is also placed before the first and after the last token:
    "abcdefg" -> "-a-b-c-d-e-f-g-"

    An empty label yields a single blank.

    Args:
        label ([np.ndarray]): label ids, (L).
        blank_id (int, optional): blank id. Defaults to 0.
    Returns:
        [np.ndarray]: (2L+1).
    """
    label = np.expand_dims(label, 1)  #[L, 1]
    blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
    # Pair every token with a leading blank, then flatten row by row.
    interleaved = np.concatenate([blanks, label], axis=1).reshape(-1)  #[2L]
    # Build the closing blank explicitly instead of copying element 0, which
    # does not exist when the label is empty.
    closing = np.full((1, ), blank_id, dtype=interleaved.dtype)
    return np.concatenate([interleaved, closing])  #[2L + 1]
def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
                 blank_id=0) -> list:
    """ctc forced alignment.

    Finds the single best path through the blank-augmented label sequence
    with a max-over-predecessors dynamic programme, then backtracks it.
    https://distill.pub/2017/ctc/

    Args:
        ctc_probs (paddle.Tensor): per-frame scores, 2d tensor (T, D).
            NOTE(review): scores are summed along the path below, so these
            are presumably log-probabilities -- confirm with the caller.
        y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
        blank_id (int): blank symbol index
    Returns:
        list: one entry per frame, each taken from the blank-augmented label
            sequence at the state chosen for that frame, (T).
    """
    # Label sequence with blanks around every token, length 2L+1.
    y_insert_blank = insert_blank(y, blank_id)

    # NOTE(review): `ctc_probs.size(0)` is called as a method; this assumes
    # the tensor type exposes a callable `size` -- confirm.
    log_alpha = paddle.zeros(
        (ctc_probs.size(0), len(y_insert_blank)))  #(T, 2L+1)
    log_alpha = log_alpha - float('inf')  # log of zero
    # Best predecessor state for every (frame, state); -1 means "unset".
    state_path = (paddle.zeros(
        (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1
                  )  # state path

    # init start state
    # A path may start either on the leading blank or on the first token.
    log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]  # Sb
    log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]  # Snb

    for t in range(1, ctc_probs.size(0)):
        for s in range(len(y_insert_blank)):
            # Blank states, the first two states, and repeated tokens may not
            # skip over the preceding state, so only two predecessors apply.
            # NOTE(review): for s == 0 the second candidate is index s - 1
            # == -1; with Python-style negative indexing that would read the
            # last state column instead of "no predecessor" -- confirm.
            if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
                    s] == y_insert_blank[s - 2]:
                candidates = paddle.to_tensor(
                    [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]])
                prev_state = [s, s - 1]
            else:
                # A token that differs from the previous token may also be
                # reached by skipping the blank between them.
                candidates = paddle.to_tensor([
                    log_alpha[t - 1, s],
                    log_alpha[t - 1, s - 1],
                    log_alpha[t - 1, s - 2],
                ])
                prev_state = [s, s - 1, s - 2]
            log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][
                y_insert_blank[s]]
            state_path[t, s] = prev_state[paddle.argmax(candidates)]

    # Chosen state per frame, filled in from the last frame backwards.
    state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16)

    # A valid path ends on the trailing blank or on the last token.
    candidates = paddle.to_tensor([
        log_alpha[-1, len(y_insert_blank) - 1],  # Sb
        log_alpha[-1, len(y_insert_blank) - 2]  # Snb
    ])
    prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
    state_seq[-1] = prev_state[paddle.argmax(candidates)]
    # Backtrack: the state at frame t is the recorded predecessor of the
    # state chosen at frame t + 1.
    for t in range(ctc_probs.size(0) - 2, -1, -1):
        state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]

    # Map every chosen state back to the id it stands for.
    output_alignment = []
    for t in range(0, ctc_probs.size(0)):
        output_alignment.append(y_insert_blank[state_seq[t, 0]])

    return output_alignment

@ -1,43 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import logging
from typing import Tuple, List
import paddle
logger = logging.getLogger(__name__)
__all__ = ["th_accuracy"]
def th_accuracy(pad_outputs: paddle.Tensor,
                pad_targets: paddle.Tensor,
                ignore_label: int) -> float:
    """Compute the fraction of non-ignored targets that are predicted correctly.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors; only its first two
            dimensions are read, and it is compared against the (B, Lmax)
            argmax of the predictions.
        ignore_label (int): Ignore label id.
    Returns:
        float: Accuracy value (0.0 - 1.0).
    """
    batch_size = pad_targets.size(0)
    max_len = pad_targets.size(1)
    num_classes = pad_outputs.size(1)

    # Restore the batch layout, then take the best class per position.
    pad_pred = pad_outputs.view(batch_size, max_len, num_classes).argmax(2)

    # Positions holding the ignore label do not count.
    mask = pad_targets != ignore_label
    hits = pad_pred.masked_select(mask) == pad_targets.masked_select(mask)

    numerator = paddle.sum(hits)
    denominator = paddle.sum(mask)
    return float(numerator) / float(denominator)

@ -20,7 +20,7 @@ import paddle
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__all__ = ["pad_list", "add_sos_eos", "remove_duplicates_and_blank", "log_add"] __all__ = ["pad_list", "add_sos_eos", "th_accuracy"]
IGNORE_ID = -1 IGNORE_ID = -1
@ -90,24 +90,21 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
return pad_list(ys_in, eos), pad_list(ys_out, ignore_id) return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: def th_accuracy(pad_outputs: paddle.Tensor,
new_hyp: List[int] = [] pad_targets: paddle.Tensor,
cur = 0 ignore_label: int) -> float:
while cur < len(hyp): """Calculate accuracy.
if hyp[cur] != 0: Args:
new_hyp.append(hyp[cur]) pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
prev = cur pad_targets (LongTensor): Target label tensors (B, Lmax, D).
while cur < len(hyp) and hyp[cur] == hyp[prev]: ignore_label (int): Ignore label id.
cur += 1 Returns:
return new_hyp float: Accuracy value (0.0 - 1.0).
def log_add(args: List[int]) -> float:
"""
Stable log add
""" """
if all(a == -float('inf') for a in args): pad_pred = pad_outputs.view(
return -float('inf') pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2)
a_max = max(args) mask = pad_targets != ignore_label
lsp = math.log(sum(math.exp(a - a_max) for a in args)) numerator = paddle.sum(
return a_max + lsp pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = paddle.sum(mask)
return float(numerator) / float(denominator)

@ -13,10 +13,13 @@
# limitations under the License. # limitations under the License.
"""Contains common utility functions.""" """Contains common utility functions."""
import math
import numpy as np import numpy as np
import distutils.util import distutils.util
__all__ = ['print_arguments', 'add_arguments'] __all__ = [
'print_arguments', 'add_arguments', "log_add", "remove_duplicates_and_blank"
]
def print_arguments(args): def print_arguments(args):
@ -57,4 +60,38 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
default=default, default=default,
type=type, type=type,
help=help + ' Default: %(default)s.', help=help + ' Default: %(default)s.',
**kwargs) **kwargs)
def log_add(args: List[int]) -> float:
    """Return ``log(sum(exp(a) for a in args))`` computed stably.

    The largest argument is subtracted before exponentiating so the
    exponentials cannot overflow. If every argument is ``-inf`` (this
    includes an empty list) the result is ``-inf``.
    """
    neg_inf = float('-inf')
    if not any(term != neg_inf for term in args):
        return neg_inf
    peak = max(args)
    shifted_total = sum(math.exp(term - peak) for term in args)
    return peak + math.log(shifted_total)
def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
    """Turn a CTC alignment into CTC label ids.

    Each run of identical ids is reduced to one id, and runs of the blank id
    are discarded: "abaa-acee-" -> "abaace"

    Args:
        hyp (List[int]): hypotheses ids, (L)
        blank_id (int, optional): blank id. Defaults to 0.
    Returns:
        List[int]: ids after merging duplicate runs and removing blank ids.
    """
    labels: List[int] = []
    total = len(hyp)
    run_start = 0
    while run_start < total:
        token = hyp[run_start]
        # Advance to the first position that no longer repeats `token`.
        run_end = run_start + 1
        while run_end < total and hyp[run_end] == token:
            run_end += 1
        if token != blank_id:
            labels.append(token)
        run_start = run_end
    return labels

Loading…
Cancel
Save