# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import numpy as np
import paddle

from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ["forced_align", "remove_duplicates_and_blank", "insert_blank"]


def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
    """ctc alignment to ctc label ids.

    Collapse every run of repeated ids to a single id, then drop blank ids.
    "abaa-acee-" -> "abaace"

    Args:
        hyp (List[int]): hypotheses ids, (L)
        blank_id (int, optional): blank id. Defaults to 0.

    Returns:
        List[int]: remove duplicate ids, then remove blank id.
    """
    new_hyp: List[int] = []
    prev = None
    for token in hyp:
        # Keep only the first id of each run, and only if it is not blank.
        if token != prev and token != blank_id:
            new_hyp.append(token)
        prev = token
    return new_hyp


def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray:
    """Insert a blank token before, between and after the label tokens.

    "abcdefg" -> "-a-b-c-d-e-f-g-"

    Args:
        label ([np.ndarray]): label ids, List[int], (L).
        blank_id (int, optional): blank id. Defaults to 0.

    Returns:
        [np.ndarray]: (2L+1). An empty label yields just [blank_id].
    """
    label = np.expand_dims(np.asarray(label), 1)  # [L, 1]
    blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
    label = np.concatenate([blanks, label], axis=1)  # [L, 2]
    label = label.reshape(-1)  # [2L], -l-l-l
    # Append the trailing blank explicitly: reusing label[0] would raise
    # IndexError when the label sequence is empty.
    label = np.append(label, blank_id)  # [2L + 1], -l-l-l-
    return label


def _to_numpy(x) -> np.ndarray:
    """Return ``x`` as a numpy array; paddle tensors are copied to host."""
    if isinstance(x, paddle.Tensor):
        return x.numpy()
    return np.asarray(x)


def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
                 blank_id=0) -> List[int]:
    """ctc forced alignment.

    https://distill.pub/2017/ctc/

    Viterbi search over the blank-extended label sequence "-y1-y2-...-yL-".
    Frame scores are accumulated by addition, so ``ctc_probs`` is expected
    to hold log-probabilities.

    Args:
        ctc_probs (paddle.Tensor): hidden state sequence, 2d tensor (T, D)
        y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
        blank_id (int): blank symbol index

    Returns:
        List[int]: best alignment result, (T): one token id (a label or the
        blank) per frame. Empty when T == 0.
    """
    # Run the DP on host numpy arrays. Touching paddle tensors cell by cell
    # creates a tensor per access, which dominates the run time; here only
    # the frame loop is in Python and each frame is vectorized over states.
    probs = _to_numpy(ctc_probs)  # (T, D)
    y_insert_blank = insert_blank(_to_numpy(y), blank_id)  # (2L+1)
    num_frames = probs.shape[0]
    num_states = len(y_insert_blank)
    if num_frames == 0:
        return []

    # Emission score of every extended state at every frame, (T, 2L+1).
    emit = probs[:, y_insert_blank]
    dtype = np.promote_types(emit.dtype, np.float32)
    # -inf is log of zero, i.e. the state is unreachable.
    log_alpha = np.full((num_frames, num_states), -np.inf, dtype=dtype)
    # Predecessor state of the best path into (t, s); row 0 is never read.
    state_path = np.full((num_frames, num_states), -1, dtype=np.int64)

    # init start state
    log_alpha[0, 0] = emit[0, 0]  # State-b, Sb
    if num_states > 1:  # an empty label sequence has only the blank state
        log_alpha[0, 1] = emit[0, 1]  # State-nb, Snb

    # A skip s-2 -> s is allowed only into a non-blank label that differs
    # from the label two states back.
    can_skip = np.zeros(num_states, dtype=bool)
    can_skip[2:] = ((y_insert_blank[2:] != blank_id) &
                    (y_insert_blank[2:] != y_insert_blank[:-2]))
    states = np.arange(num_states)

    for t in range(1, num_frames):
        prev = log_alpha[t - 1]
        # candidates[k, s]: score of arriving at state s from state s - k.
        # Missing predecessors stay -inf. In particular state 0 has no s-1;
        # a plain prev[s - 1] would wrap around to the LAST state.
        candidates = np.full((3, num_states), -np.inf, dtype=dtype)
        candidates[0] = prev
        candidates[1, 1:] = prev[:-1]
        candidates[2, 2:] = np.where(can_skip[2:], prev[:-2], -np.inf)
        # argmax returns the first maximum, so ties prefer staying (k=0),
        # then stepping (k=1), then skipping (k=2).
        best = np.argmax(candidates, axis=0)
        log_alpha[t] = candidates.max(axis=0) + emit[t]
        state_path[t] = states - best

    # The path must end in the trailing blank (Sb) or in the last label
    # (Snb); on a tie the trailing blank wins.
    last = num_states - 1
    state_seq = np.empty(num_frames, dtype=np.int64)
    if num_states > 1 and log_alpha[-1, last - 1] > log_alpha[-1, last]:
        state_seq[-1] = last - 1
    else:
        state_seq[-1] = last
    for t in range(num_frames - 2, -1, -1):
        state_seq[t] = state_path[t + 1, state_seq[t + 1]]

    return y_insert_blank[state_seq].tolist()