From caaa44e368f1e453b969ed96f7e7bc228cf0b624 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 16 Sep 2021 06:10:27 +0000 Subject: [PATCH] varbase getitem support np.longlong since paddle 2.2.0RC --- deepspeech/utils/ctc_utils.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/deepspeech/utils/ctc_utils.py b/deepspeech/utils/ctc_utils.py index 09543d48d..6201233df 100644 --- a/deepspeech/utils/ctc_utils.py +++ b/deepspeech/utils/ctc_utils.py @@ -86,15 +86,13 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, log_alpha = paddle.zeros( (ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1) log_alpha = log_alpha - float('inf') # log of zero - # TODO(Hui Zhang): zeros not support paddle.int16 state_path = (paddle.zeros( - (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1 + (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1 ) # state path, Tuple((T, 2L+1)) # init start state - # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 - log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb - log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb + log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # State-b, Sb + log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # State-nb, Snb for t in range(1, ctc_probs.size(0)): # T for s in range(len(y_insert_blank)): # 2L+1 @@ -110,13 +108,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, log_alpha[t - 1, s - 2], ]) prev_state = [s, s - 1, s - 2] - # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 - log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int( - y_insert_blank[s])] + log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][ + y_insert_blank[s]] state_path[t, s] = prev_state[paddle.argmax(candidates)] - # TODO(Hui Zhang): zeros not support paddle.int16 - state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32) + state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16) candidates = paddle.to_tensor([ log_alpha[-1, len(y_insert_blank) - 1], # Sb