refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils

pull/578/head
Hui Zhang 5 years ago
parent fcd91c62d0
commit b31a1f46d9

@ -509,7 +509,7 @@
" print(audio_len.shape)\n",
" \n",
" #eouts, eouts_len = model.encoder(audio, audio_len)\n",
" #probs = model.decoder.probs(eouts)\n",
" #probs = model.decoder.softmax(eouts)\n",
" probs = model.forward(audio, audio_len)\n",
" print('paddle:', probs.numpy())\n",
" \n",
@ -666,4 +666,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

@ -109,7 +109,7 @@ def tune(config, args):
# model infer
eouts, eouts_len = model.encoder(audio, audio_len)
probs = model.decoder.probs(eouts)
probs = model.decoder.softmax(eouts)
# grid search
for index, (alpha, beta) in enumerate(params_grid):

@ -203,7 +203,7 @@ class DeepSpeech2Model(nn.Layer):
decoding_method=decoding_method)
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.probs(eouts)
probs = self.decoder.softmax(eouts)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
@ -234,7 +234,9 @@ class DeepSpeech2Model(nn.Layer):
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
infos = checkpoint.load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)
return model
@ -268,5 +270,5 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
probs: probs after softmax
"""
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.probs(eouts)
probs = self.decoder.softmax(eouts)
return probs

@ -20,10 +20,12 @@ from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.loss import CTCLoss
from deepspeech.utils import ctc_utils
from deepspeech.decoders.swig_wrapper import Scorer
from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder
from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch
from deepspeech.modules.loss import CTCLoss
logger = logging.getLogger(__name__)
@ -67,38 +69,31 @@ class CTCDecoder(nn.Layer):
ys_pad (Tenosr): batch of padded character id sequence tensor (B, Lmax)
ys_lens (Tensor): batch of lengths of character sequence (B)
Returns:
loss (Tenosr): scalar.
loss (Tenosr): ctc loss value, scalar.
"""
logits = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
loss = self.criterion(logits, ys_pad, hlens, ys_lens)
return loss
def probs(self, eouts: paddle.Tensor, temperature: float=1.0):
def softmax(self, eouts: paddle.Tensor, temperature: float=1.0):
"""Get CTC probabilities.
Args:
eouts (FloatTensor): `[B, T, enc_units]`
Returns:
probs (FloatTensor): `[B, T, odim]`
"""
return F.softmax(self.ctc_lo(eouts) / temperature, axis=-1)
def scores(self, eouts: paddle.Tensor, temperature: float=1.0):
"""Get log-scale CTC probabilities.
Args:
eouts (FloatTensor): `[B, T, enc_units]`
Returns:
log_probs (FloatTensor): `[B, T, odim]`
"""
return F.log_softmax(self.ctc_lo(eouts) / temperature, axis=-1)
self.probs = F.softmax(self.ctc_lo(eouts) / temperature, axis=2)
return self.probs
def log_softmax(self, hs_pad: paddle.Tensor) -> paddle.Tensor:
def log_softmax(self, hs_pad: paddle.Tensor,
temperature: float=1.0) -> paddle.Tensor:
"""log_softmax of frame activations
Args:
Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Returns:
paddle.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
"""
return self.scores(hs_pad)
return F.log_softmax(self.ctc_lo(hs_pad) / temperature, axis=2)
def argmax(self, hs_pad: paddle.Tensor) -> paddle.Tensor:
"""argmax of frame activations
@ -109,6 +104,20 @@ class CTCDecoder(nn.Layer):
"""
return paddle.argmax(self.ctc_lo(hs_pad), dim=2)
def forced_align(self,
ctc_probs: paddle.Tensor,
y: paddle.Tensor,
blank_id=0) -> list:
"""ctc forced alignment.
Args:
ctc_probs (paddle.Tensor): hidden state sequence, 2d tensor (T, D)
y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
blank_id (int): blank symbol index
Returns:
paddle.Tensor: best alignment result, (T).
"""
return ctc_utils.forced_align(ctc_probs, y, blank_id)
def _decode_batch_greedy(self, probs_split, vocab_list):
"""Decode by best path for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrix, and each consists

@ -0,0 +1,56 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import paddle
from paddle.optimizer.lr import LRScheduler
logger = logging.getLogger(__name__)
__all__ = ["WarmupLR"]
class WarmupLR(LRScheduler):
"""The WarmupLR scheduler
This scheduler is almost same as NoamLR Scheduler except for following
difference:
NoamLR:
lr = optimizer.lr * model_size ** -0.5
* min(step ** -0.5, step * warmup_step ** -1.5)
WarmupLR:
lr = optimizer.lr * warmup_step ** 0.5
* min(step ** -0.5, step * warmup_step ** -1.5)
Note that the maximum lr equals to optimizer.lr in this scheduler.
"""
def __init__(self,
warmup_steps: Union[int, float]=25000,
learning_rate=1.0,
last_epoch=-1,
verbose=False):
assert check_argument_types()
self.warmup_steps = warmup_steps
super().__init__(learning_rate, last_epoch, verbose)
def __repr__(self):
return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
def get_lr(self):
step_num = self.last_epoch + 1
return self.base_lr * self.warmup_steps**0.5 * min(
step_num**-0.5, step_num * self.warmup_steps**-1.5)
def set_step(self, step: int):
self.last_epoch = step

@ -131,8 +131,13 @@ class Trainer():
def save(self):
"""Save checkpoint (model parameters and optimizer states).
"""
infos = {
"step": self.iteration,
"epoch": self.epoch,
"lr": self.optimizer.get_lr(),
}
checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
self.model, self.optimizer)
self.model, self.optimizer, infos)
def resume_or_load(self):
"""Resume from latest checkpoint at checkpoints in the output
@ -141,12 +146,13 @@ class Trainer():
If ``args.checkpoint_path`` is not None, load the checkpoint, else
resume training.
"""
iteration = checkpoint.load_parameters(
infos = checkpoint.load_parameters(
self.model,
self.optimizer,
checkpoint_dir=self.checkpoint_dir,
checkpoint_path=self.args.checkpoint_path)
self.iteration = iteration
self.iteration = infos["step"]
self.epoch = infos["epoch"]
def new_epoch(self):
"""Reset the train loader and increment ``epoch``.

@ -16,6 +16,8 @@ import os
import time
import logging
import numpy as np
import re
import json
import paddle
from paddle import distributed as dist
@ -37,15 +39,13 @@ def _load_latest_checkpoint(checkpoint_dir: str) -> int:
int: the latest iteration number.
"""
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
if (not os.path.isfile(checkpoint_record)):
if not os.path.isfile(checkpoint_record):
return 0
# Fetch the latest checkpoint index.
with open(checkpoint_record, "rt") as handle:
latest_checkpoint = handle.readlines()[-1].strip()
step = latest_checkpoint.split(":")[-1]
iteration = int(step.split("-")[-1])
iteration = int(latest_checkpoint.split(":")[-1])
return iteration
@ -60,7 +60,7 @@ def _save_checkpoint(checkpoint_dir: str, iteration: int):
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
# Update the latest checkpoint index.
with open(checkpoint_record, "a+") as handle:
handle.write("model_checkpoint_path:step-{}\n".format(iteration))
handle.write("model_checkpoint_path:{}\n".format(iteration))
def load_parameters(model,
@ -74,20 +74,16 @@ def load_parameters(model,
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path and the argument 'checkpoint_dir' will
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
Returns:
iteration (int): number of iterations that the loaded checkpoint has
been trained.
configs (dict): epoch or step, lr and other meta info should be saved.
"""
if checkpoint_path is not None:
iteration = int(os.path.basename(checkpoint_path).split("-")[-1])
iteration = int(os.path.basename(checkpoint_path).split(":")[-1])
elif checkpoint_dir is not None:
iteration = _load_latest_checkpoint(checkpoint_dir)
if iteration == 0:
return iteration
checkpoint_path = os.path.join(checkpoint_dir,
"step-{}".format(iteration))
checkpoint_path = os.path.join(checkpoint_dir, "-{}".format(iteration))
else:
raise ValueError(
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
@ -98,43 +94,58 @@ def load_parameters(model,
params_path = checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path)
model.set_state_dict(model_dict)
logger.info(
"[checkpoint] Rank {}: loaded model from {}".format(rank, params_path))
logger.info("Rank {}: loaded model from {}".format(rank, params_path))
optimizer_path = checkpoint_path + ".pdopt"
if optimizer and os.path.isfile(optimizer_path):
optimizer_dict = paddle.load(optimizer_path)
optimizer.set_state_dict(optimizer_dict)
logger.info("[checkpoint] Rank {}: loaded optimizer state from {}".
format(rank, optimizer_path))
logger.info("Rank {}: loaded optimizer state from {}".format(
rank, optimizer_path))
return iteration
info_path = re.sub('.pdparams$', '.json', params_path)
configs = {}
if os.path.exists(info_path):
with open(info_path, 'r') as fin:
configs = json.load(fin)
return configs
@mp_tools.rank_zero_only
def save_parameters(checkpoint_dir, iteration, model, optimizer=None):
def save_parameters(checkpoint_dir: str,
iteration: int,
model: paddle.nn.Layer,
optimizer: Optimizer=None,
infos: dict=None):
"""Checkpoint the latest trained model parameters.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
iteration (int): the latest iteration(step or epoch) number.
model (Layer): model to be checkpointed.
optimizer (Optimizer, optional): optimizer to be checkpointed.
Defaults to None.
infos (dict or None): any info you want to save.
Returns:
None
"""
checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration))
checkpoint_path = os.path.join(checkpoint_dir, "-{}".format(iteration))
model_dict = model.state_dict()
params_path = checkpoint_path + ".pdparams"
paddle.save(model_dict, params_path)
logger.info("[checkpoint] Saved model to {}".format(params_path))
logger.info("Saved model to {}".format(params_path))
if optimizer:
opt_dict = optimizer.state_dict()
optimizer_path = checkpoint_path + ".pdopt"
paddle.save(opt_dict, optimizer_path)
logger.info(
"[checkpoint] Saved optimzier state to {}".format(optimizer_path))
logger.info("Saved optimzier state to {}".format(optimizer_path))
info_path = re.sub('.pdparams$', '.json', params_path)
if infos is None:
infos = {}
with open(info_path, 'w') as fout:
data = json.dumps(infos)
fout.write(data)
_save_checkpoint(checkpoint_dir, iteration)

@ -0,0 +1,93 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import logging
import numpy as np
logger = logging.getLogger(__name__)
__all__ = ['load_cmvn']
def _load_json_cmvn(json_cmvn_file):
""" Load the json format cmvn stats file and calculate cmvn
Args:
json_cmvn_file: cmvn stats file in json format
Returns:
a numpy array of [means, vars]
"""
with open(json_cmvn_file) as f:
cmvn_stats = json.load(f)
means = cmvn_stats['mean_stat']
variance = cmvn_stats['var_stat']
count = cmvn_stats['frame_num']
for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn
def _load_kaldi_cmvn(kaldi_cmvn_file):
""" Load the kaldi format cmvn stats file and calculate cmvn
Args:
kaldi_cmvn_file: kaldi text style global cmvn file, which
is generated by:
compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
Returns:
a numpy array of [means, vars]
"""
means = []
variance = []
with open(kaldi_cmvn_file, 'r') as fid:
# kaldi binary file start with '\0B'
if fid.read(2) == '\0B':
logger.error('kaldi cmvn binary file is not supported, please '
'recompute it by: compute-cmvn-stats --binary=false '
' scp:feats.scp global_cmvn')
sys.exit(1)
fid.seek(0)
arr = fid.read().split()
assert (arr[0] == '[')
assert (arr[-2] == '0')
assert (arr[-1] == ']')
feat_dim = int((len(arr) - 2 - 2) / 2)
for i in range(1, feat_dim + 1):
means.append(float(arr[i]))
count = float(arr[feat_dim + 1])
for i in range(feat_dim + 2, 2 * feat_dim + 2):
variance.append(float(arr[i]))
for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn
def load_cmvn(cmvn_file, is_json):
if is_json:
cmvn = _load_json_cmvn(cmvn_file)
else:
cmvn = _load_kaldi_cmvn(cmvn_file)
return cmvn[0], cmvn[1]

@ -0,0 +1,128 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from typing import List
import paddle
logger = logging.getLogger(__name__)
__all__ = ["forced_align", "remove_duplicates_and_blank", "insert_blank"]
def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
"""ctc alignment to ctc label ids.
"abaa-acee-" -> "abaace"
Args:
hyp (List[int]): hypotheses ids, (L)
blank_id (int, optional): blank id. Defaults to 0.
Returns:
List[int]: remove dupicate ids, then remove blank id.
"""
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
if hyp[cur] != blank_id:
new_hyp.append(hyp[cur])
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp
def insert_blank(label: np.ndarray, blank_id: int=0):
"""Insert blank token between every two label token.
"abcdefg" -> "-a-b-c-d-e-f-g-"
Args:
label ([np.ndarray]): label ids, (L).
blank_id (int, optional): blank id. Defaults to 0.
Returns:
[np.ndarray]: (2L+1).
"""
label = np.expand_dims(label, 1) #[L, 1]
blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
label = np.concatenate([blanks, label], axis=1) #[L, 2]
label = label.reshape(-1) #[2L]
label = np.append(label, label[0]) #[2L + 1]
return label
def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
blank_id=0) -> list:
"""ctc forced alignment.
https://distill.pub/2017/ctc/
Args:
ctc_probs (paddle.Tensor): hidden state sequence, 2d tensor (T, D)
y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
blank_id (int): blank symbol index
Returns:
paddle.Tensor: best alignment result, (T).
"""
y_insert_blank = insert_blank(y, blank_id)
log_alpha = paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1)
log_alpha = log_alpha - float('inf') # log of zero
state_path = (paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1
) # state path
# init start state
log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # Sb
log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # Snb
for t in range(1, ctc_probs.size(0)):
for s in range(len(y_insert_blank)):
if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
s] == y_insert_blank[s - 2]:
candidates = paddle.to_tensor(
[log_alpha[t - 1, s], log_alpha[t - 1, s - 1]])
prev_state = [s, s - 1]
else:
candidates = paddle.to_tensor([
log_alpha[t - 1, s],
log_alpha[t - 1, s - 1],
log_alpha[t - 1, s - 2],
])
prev_state = [s, s - 1, s - 2]
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][
y_insert_blank[s]]
state_path[t, s] = prev_state[paddle.argmax(candidates)]
state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16)
candidates = paddle.to_tensor([
log_alpha[-1, len(y_insert_blank) - 1], # Sb
log_alpha[-1, len(y_insert_blank) - 2] # Snb
])
prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
state_seq[-1] = prev_state[paddle.argmax(candidates)]
for t in range(ctc_probs.size(0) - 2, -1, -1):
state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
output_alignment = []
for t in range(0, ctc_probs.size(0)):
output_alignment.append(y_insert_blank[state_seq[t, 0]])
return output_alignment

@ -1,43 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import logging
from typing import Tuple, List
import paddle
logger = logging.getLogger(__name__)
__all__ = ["th_accuracy"]
def th_accuracy(pad_outputs: paddle.Tensor,
pad_targets: paddle.Tensor,
ignore_label: int) -> float:
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(
pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label
numerator = paddle.sum(
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = paddle.sum(mask)
return float(numerator) / float(denominator)

@ -20,7 +20,7 @@ import paddle
logger = logging.getLogger(__name__)
__all__ = ["pad_list", "add_sos_eos", "remove_duplicates_and_blank", "log_add"]
__all__ = ["pad_list", "add_sos_eos", "th_accuracy"]
IGNORE_ID = -1
@ -90,24 +90,21 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
if hyp[cur] != 0:
new_hyp.append(hyp[cur])
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp
def log_add(args: List[int]) -> float:
"""
Stable log add
def th_accuracy(pad_outputs: paddle.Tensor,
pad_targets: paddle.Tensor,
ignore_label: int) -> float:
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
if all(a == -float('inf') for a in args):
return -float('inf')
a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max) for a in args))
return a_max + lsp
pad_pred = pad_outputs.view(
pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label
numerator = paddle.sum(
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = paddle.sum(mask)
return float(numerator) / float(denominator)

@ -13,10 +13,13 @@
# limitations under the License.
"""Contains common utility functions."""
import math
import numpy as np
import distutils.util
__all__ = ['print_arguments', 'add_arguments']
__all__ = [
'print_arguments', 'add_arguments', "log_add", "remove_duplicates_and_blank"
]
def print_arguments(args):
@ -57,4 +60,38 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
**kwargs)
def log_add(args: List[int]) -> float:
"""
Stable log add
"""
if all(a == -float('inf') for a in args):
return -float('inf')
a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max) for a in args))
return a_max + lsp
def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
"""ctc alignment to ctc label ids.
"abaa-acee-" -> "abaace"
Args:
hyp (List[int]): hypotheses ids, (L)
blank_id (int, optional): blank id. Defaults to 0.
Returns:
List[int]: remove dupicate ids, then remove blank id.
"""
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
if hyp[cur] != blank_id:
new_hyp.append(hyp[cur])
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp

Loading…
Cancel
Save