From 2430545d454d2e1cbb89cb74ef0b46d92c11dcde Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 13 Oct 2021 10:40:20 +0000 Subject: [PATCH] update vector ctc prefix score --- deepspeech/decoders/scores/ctc.py | 8 +- .../decoders/scores/ctc_prefix_score.py | 136 +++++++++--------- 2 files changed, 71 insertions(+), 73 deletions(-) diff --git a/deepspeech/decoders/scores/ctc.py b/deepspeech/decoders/scores/ctc.py index 6e5a4c53..aaa3dc86 100644 --- a/deepspeech/decoders/scores/ctc.py +++ b/deepspeech/decoders/scores/ctc.py @@ -4,7 +4,7 @@ import numpy as np import paddle from .ctc_prefix_score import CTCPrefixScore -from .ctc_prefix_score import CTCPrefixScoreTH +from .ctc_prefix_score import CTCPrefixScorePD from .scorer_interface import BatchPartialScorerInterface @@ -34,7 +34,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface): """ logp = self.ctc.log_softmax(x.unsqueeze(0)).squeeze(0).numpy() - # TODO(karita): use CTCPrefixScoreTH + # TODO(karita): use CTCPrefixScorePD self.impl = CTCPrefixScore(logp, 0, self.eos, np) return 0, self.impl.initial_state() @@ -54,7 +54,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface): if len(state) == 2: # for CTCPrefixScore sc, st = state return sc[i], st[i] - else: # for CTCPrefixScoreTH (need new_id > 0) + else: # for CTCPrefixScorePD (need new_id > 0) r, log_psi, f_min, f_max, scoring_idmap = state s = log_psi[i, new_id].expand(log_psi.size(1)) if scoring_idmap is not None: @@ -96,7 +96,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface): """ logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1 xlen = paddle.to_tensor([logp.size(1)]) - self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos) + self.impl = CTCPrefixScorePD(logp, xlen, 0, self.eos) return None def batch_score_partial(self, y, ids, state, x): diff --git a/deepspeech/decoders/scores/ctc_prefix_score.py b/deepspeech/decoders/scores/ctc_prefix_score.py index 2ca00ebf..77ac09cd 100644 --- a/deepspeech/decoders/scores/ctc_prefix_score.py +++ b/deepspeech/decoders/scores/ctc_prefix_score.py @@ -3,13 +3,13 @@ # Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import torch +import paddle import numpy as np import six -class CTCPrefixScoreTH(): +class CTCPrefixScorePD(): """Batch processing of CTCPrefixScore which is based on Algorithm 2 in WATANABE et al. 
@@ -23,8 +23,10 @@ class CTCPrefixScoreTH():
     def __init__(self, x, xlens, blank, eos, margin=0):
         """Construct CTC prefix scorer
 
-        :param torch.Tensor x: input label posterior sequences (B, T, O)
-        :param torch.Tensor xlens: input lengths (B,)
+        `margin` is the margin M of the CTC window in eqs. (22)-(23)
+
+        :param paddle.Tensor x: input label posterior sequences (B, T, O)
+        :param paddle.Tensor xlens: input lengths (B,)
         :param int blank: blank label id
         :param int eos: end-of-sequence id
         :param int margin: margin parameter for windowing (0 means no windowing)
@@ -38,11 +40,8 @@ class CTCPrefixScoreTH():
         self.input_length = x.size(1)
         self.odim = x.size(2)
         self.dtype = x.dtype
-        self.device = (
-            torch.device("cuda:%d" % x.get_device())
-            if x.is_cuda
-            else torch.device("cpu")
-        )
+        self.device = x.place
+
         # Pad the rest of posteriors in the batch
         # TODO(takaaki-hori): need a better way without for-loops
         for i, l in enumerate(xlens):
@@ -50,20 +49,21 @@ class CTCPrefixScoreTH():
                 x[i, l:, :] = self.logzero
                 x[i, l:, blank] = 0
         # Reshape input x
-        xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
-        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
-        self.x = torch.stack([xn, xb])  # (2, T, B, O)
-        self.end_frames = torch.as_tensor(xlens) - 1
+        xn = x.transpose([1, 0, 2])  # (B, T, O) -> (T, B, O)
+        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)  # (T, B, O)
+        self.x = paddle.stack([xn, xb])  # (2, T, B, O)
+        self.end_frames = paddle.to_tensor(xlens) - 1  # (B,)
         # Setup CTC windowing
         self.margin = margin
         if margin > 0:
-            self.frame_ids = torch.arange(
-                self.input_length, dtype=self.dtype, device=self.device
-            )
+            self.frame_ids = paddle.arange(self.input_length, dtype=self.dtype)
         # Base indices for index conversion
-        self.idx_bh = None
-        self.idx_b = torch.arange(self.batch, device=self.device)
+        # batch/hypothesis index, shape (B*W, 1); built lazily in __call__
+        self.idx_bh = None
+        # batch index, shape (B,)
+        self.idx_b = paddle.arange(self.batch)
+        # per-batch vocabulary offsets, shape (B, 1)
         self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)
 
     def __call__(self, y, state, scoring_ids=None, att_w=None):
@@ -71,8 +71,8 @@ class CTCPrefixScoreTH():
 
         :param list y: prefix label sequences
         :param tuple state: previous CTC state
-        :param torch.Tensor pre_scores: scores for pre-selection of hypotheses (BW, O)
-        :param torch.Tensor att_w: attention weights to decide CTC window
+        :param paddle.Tensor scoring_ids: selected next ids to score (BW, O'), O' <= O
+        :param paddle.Tensor att_w: attention weights to decide CTC window
         :return new_state, ctc_local_scores (BW, O)
         """
         output_length = len(y[0]) - 1  # ignore sos
         n_bh = len(y)
         n_hyps = n_bh // self.batch
         self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
         # prepare state info
         if state is None:
-            r_prev = torch.full(
+            r_prev = paddle.full(
                 (self.input_length, 2, self.batch, n_hyps),
                 self.logzero,
                 dtype=self.dtype,
-                device=self.device,
-            )
-            r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
-            r_prev = r_prev.view(-1, 2, n_bh)
-            s_prev = 0.0
-            f_min_prev = 0
-            f_max_prev = 1
+            )  # (T, 2, B, W)
+            r_prev[:, 1] = paddle.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
+            r_prev = r_prev.view(-1, 2, n_bh)  # (T, 2, BW)
+            s_prev = 0.0  # accumulated prefix score
+            f_min_prev = 0  # frame window bounds, eqs. (22)-(23)
+            f_max_prev = 1  # frame window bounds, eqs. (22)-(23)
         else:
             r_prev, s_prev, f_min_prev, f_max_prev = state
 
         # select input dimensions for scoring
         if self.scoring_num > 0:
-            scoring_idmap = torch.full(
-                (n_bh, self.odim), -1, dtype=torch.long, device=self.device
-            )
+            # (BW, O)
+            scoring_idmap = paddle.full((n_bh, self.odim), -1, dtype=paddle.int64)
             snum = self.scoring_num
             if self.idx_bh is None or n_bh > len(self.idx_bh):
-                self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
-            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
-                snum, device=self.device
-            )
+                self.idx_bh = paddle.arange(n_bh).view(-1, 1)  # (BW, 1)
+            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = paddle.arange(snum)
             scoring_idx = (
-                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
-            ).view(-1)
-            x_ = torch.index_select(
-                self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
+                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)  # offsets (BW, 1)
+            ).view(-1)  # (BW * O',)
+            # x_ shape (2, T, BW, O')
+            x_ = paddle.index_select(
+                self.x.view(2, -1, self.batch * self.odim), scoring_idx, 2
             ).view(2, -1, n_bh, snum)
         else:
             scoring_ids = None
             scoring_idmap = None
             snum = self.odim
+            # x_ shape (2, T, BW, O)
             x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)
 
         # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
         # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
-        r = torch.full(
+        r = paddle.full(
             (self.input_length, 2, n_bh, snum),
             self.logzero,
             dtype=self.dtype,
-            device=self.device,
         )
         if output_length == 0:
             r[0, 0] = x_[0, 0]
-        r_sum = torch.logsumexp(r_prev, 1)
-        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
+        r_sum = paddle.logsumexp(r_prev, 1)  # (T, BW)
+        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)  # (T, BW, O')
         if scoring_ids is not None:
             for idx in range(n_bh):
                 pos = scoring_idmap[idx, last_ids[idx]]
@@ -143,40 +140,39 @@ class CTCPrefixScoreTH():
 
         # decide start and end frames based on attention weights
         if att_w is not None and self.margin > 0:
-            f_arg = torch.matmul(att_w, self.frame_ids)
+            f_arg = paddle.matmul(att_w, self.frame_ids)
             f_min = max(int(f_arg.min().cpu()), f_min_prev)
             f_max = max(int(f_arg.max().cpu()), f_max_prev)
             start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
             end = min(f_max + self.margin, self.input_length)
         else:
             f_min = f_max = 0
+            # no windowing: emitting output_length labels consumes at least output_length frames, so start there
             start = max(output_length, 1)
             end = self.input_length
 
         # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
         for t in range(start, end):
-            rp = r[t - 1]
-            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
+            rp = r[t - 1]  # (2 x BW x O')
+            rr = paddle.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
                 2, 2, n_bh, snum
-            )
-            r[t] = torch.logsumexp(rr, 1) + x_[:, t]
+            )  # (2, 2, BW, O')
+            r[t] = paddle.logsumexp(rr, 1) + x_[:, t]
 
         # compute log prefix probabilities log(psi)
-        log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
+        log_phi_x = paddle.concat((log_phi[0].unsqueeze(0), log_phi[:-1]), axis=0) + x_[0]
         if scoring_ids is not None:
-            log_psi = torch.full(
-                (n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device
-            )
-            log_psi_ = torch.logsumexp(
-                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
-                dim=0,
+            log_psi = paddle.full((n_bh, self.odim), self.logzero, dtype=self.dtype)
+            log_psi_ = paddle.logsumexp(
+                paddle.concat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), axis=0),
+                axis=0,
             )
             for si in range(n_bh):
                 log_psi[si, scoring_ids[si]] = log_psi_[si]
         else:
-            log_psi = torch.logsumexp(
-                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
-                dim=0,
+            log_psi = paddle.logsumexp(
+                paddle.concat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), axis=0),
+                axis=0,
             )
 
         for si in range(n_bh):
@@ -200,7 +196,7 @@ class CTCPrefixScoreTH():
         n_hyps = n_bh // self.batch
         vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
         # select hypothesis scores
-        s_new = torch.index_select(s.view(-1), 0, vidx)
+        s_new = paddle.index_select(s.view(-1), vidx, 0)
         s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
         # convert ids to BHS space (S: scoring_num)
         if scoring_idmap is not None:
@@ -208,14 +204,14 @@ class CTCPrefixScoreTH():
             hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
                 -1
             )
-            label_ids = torch.fmod(best_ids, self.odim).view(-1)
+            label_ids = (best_ids % self.odim).view(-1)
             score_idx = scoring_idmap[hyp_idx, label_ids]
             score_idx[score_idx == -1] = 0
             vidx = score_idx + hyp_idx * snum
         else:
             snum = self.odim
         # select forward probabilities
-        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
+        r_new = paddle.index_select(r.view(-1, 2, n_bh * snum), vidx, 2).view(
             -1, 2, n_bh
         )
         return r_new, s_new, f_min, f_max
@@ -223,7 +219,7 @@ class CTCPrefixScoreTH():
     def extend_prob(self, x):
         """Extend CTC prob.
 
-        :param torch.Tensor x: input label posterior sequences (B, T, O)
+        :param paddle.Tensor x: input label posterior sequences (B, T, O)
         """
         if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
@@ -235,12 +231,12 @@ class CTCPrefixScoreTH():
                 x[i, l:, :] = self.logzero
                 x[i, l:, self.blank] = 0
             tmp_x = self.x
-            xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
+            xn = x.transpose([1, 0, 2])  # (B, T, O) -> (T, B, O)
             xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
-            self.x = torch.stack([xn, xb])  # (2, T, B, O)
+            self.x = paddle.stack([xn, xb])  # (2, T, B, O)
             self.x[:, : tmp_x.shape[1], :, :] = tmp_x
             self.input_length = x.size(1)
-            self.end_frames = torch.as_tensor(xlens) - 1
+            self.end_frames = paddle.to_tensor(xlens) - 1
 
     def extend_state(self, state):
         """Compute CTC prefix state.
@@ -256,15 +252,14 @@ class CTCPrefixScoreTH(): else: r_prev, s_prev, f_min_prev, f_max_prev = state - r_prev_new = torch.full( + r_prev_new = paddle.full( (self.input_length, 2), self.logzero, dtype=self.dtype, - device=self.device, ) start = max(r_prev.shape[0], 1) r_prev_new[0:start] = r_prev - for t in six.moves.range(start, self.input_length): + for t in range(start, self.input_length): r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank] return (r_prev_new, s_prev, f_min_prev, f_max_prev) @@ -285,7 +280,7 @@ class CTCPrefixScore(): self.blank = blank self.eos = eos self.input_length = len(x) - self.x = x + self.x = x # (T, O) def initial_state(self): """Obtain an initial CTC state @@ -295,6 +290,7 @@ class CTCPrefixScore(): # initial CTC state is made of a frame x 2 tensor that corresponds to # r_t^n() and r_t^b(), where 0 and 1 of axis=1 represent # superscripts n and b (non-blank and blank), respectively. + # r shape (T, 2) r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32) r[0, 1] = self.x[0, self.blank] for i in six.moves.range(1, self.input_length): @@ -313,6 +309,7 @@ class CTCPrefixScore(): output_length = len(y) - 1 # ignore sos # new CTC states are prepared as a frame x (n or b) x n_labels tensor # that corresponds to r_t^n(h) and r_t^b(h). + # r shape (T, 2, n_labels) r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32) xs = self.x[:, cs] if output_length == 0: @@ -356,4 +353,5 @@ class CTCPrefixScore(): # return the log prefix probability and CTC states, where the label axis # of the CTC states is moved to the first axis to slice it easily + # log_psi shape (n_labels,), state shape (n_labels, T, 2) return log_psi, self.xp.rollaxis(r, 2)
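For reference, the per-prefix recursion that both classes implement (Algorithm 2 in Watanabe et al., per the class docstring) is compact once the batching is stripped away. The NumPy sketch below mirrors CTCPrefixScore.__call__ for a single candidate label c; it is illustrative only: it drops the start-frame optimization and the eos special case, and none of these names belong to the patched API.

import numpy as np

LOGZERO = -1e10  # same convention as self.logzero above

def ctc_prefix_score(x, r_prev, last, c, empty_prefix, blank=0):
    # x: (T, O) log posteriors; r_prev: (T, 2) forward variables of prefix g,
    # [:, 0] = log r^n (paths ending in non-blank), [:, 1] = log r^b (ending in blank)
    T = x.shape[0]
    r = np.full((T, 2), LOGZERO, dtype=np.float32)
    # phi_t: mass of g after which a new c may start at frame t + 1; a repeated
    # label must be separated from the prefix by a blank, so drop r^n if c == last
    if c == last:
        log_phi = r_prev[:, 1].copy()
    else:
        log_phi = np.logaddexp(r_prev[:, 0], r_prev[:, 1])
    r[0, 0] = x[0, c] if empty_prefix else LOGZERO
    log_psi = r[0, 0]  # log prefix probability psi, accumulated over frames
    for t in range(1, T):
        r[t, 0] = np.logaddexp(r[t - 1, 0], log_phi[t - 1]) + x[t, c]       # ends in c
        r[t, 1] = np.logaddexp(r[t - 1, 0], r[t - 1, 1]) + x[t, blank]      # ends in blank
        log_psi = np.logaddexp(log_psi, log_phi[t - 1] + x[t, c])
    return log_psi, r

# toy usage: T = 5 frames, O = 4 labels, label 0 is blank
rng = np.random.default_rng(0)
T, O = 5, 4
x = np.log(rng.dirichlet(np.ones(O), size=T)).astype(np.float32)
r0 = np.full((T, 2), LOGZERO, dtype=np.float32)  # initial_state(): all-blank paths
r0[0, 1] = x[0, 0]
for t in range(1, T):
    r0[t, 1] = r0[t - 1, 1] + x[t, 0]
score, r_c = ctc_prefix_score(x, r0, last=None, c=2, empty_prefix=True)

The batched class runs exactly this recurrence, but for all hypotheses and all candidate labels at once: r becomes a (T, 2, BW, O') tensor and the time loop is the only remaining Python-level loop.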
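The least obvious bookkeeping in the vectorized __call__ is the id conversion between the full vocabulary and the pruned scoring dimension. A hypothetical two-prefix, five-label example of how scoring_idmap is filled (this mirrors the paddle.arange assignment in the patch; the toy sizes are assumptions):

import numpy as np

n_bh, odim, snum = 2, 5, 2        # BW prefixes, O vocabulary ids, O' kept per prefix
scoring_ids = np.array([[1, 3],
                        [0, 4]])  # ids selected for scoring, shape (BW, O')

# -1 marks "not scored"; each selected id gets its column index in the
# pruned (BW, O') score matrix
scoring_idmap = np.full((n_bh, odim), -1)
scoring_idmap[np.arange(n_bh)[:, None], scoring_ids] = np.arange(snum)
print(scoring_idmap)
# [[-1  0 -1  1 -1]
#  [ 0 -1 -1 -1  1]]

The state-reselection method patched at old lines 200-218 (index_select_state in the upstream ESPnet code) inverts this map: it converts a best id in vocabulary space back into a (hypothesis, scoring-column) pair so the kept forward variables can be sliced out of r.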
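Finally, a minimal smoke test for the renamed scorer. This is a sketch under three assumptions: the torch-style tensor methods this file relies on (.view, .repeat, varargs .expand, .size) are supplied by the deepspeech package's paddle compatibility patches (vanilla paddle spells them .reshape, .tile, .expand([...]), .shape); __call__ returns (local_scores, new_state) as in the upstream ESPnet implementation (its return statement lies outside the hunks shown above); and the state-reselection method keeps its upstream name, index_select_state.

import paddle
from deepspeech.decoders.scores.ctc_prefix_score import CTCPrefixScorePD

B, W, T, O = 2, 3, 10, 6  # batch, beam width, frames, vocabulary size
blank, eos = 0, O - 1

# fake CTC posteriors; in ctc.py these come from self.ctc.log_softmax(x)
logp = paddle.nn.functional.log_softmax(paddle.randn([B, T, O]), axis=-1)
xlens = paddle.to_tensor([T, T - 2])
scorer = CTCPrefixScorePD(logp, xlens, blank, eos)

ys = [[eos] for _ in range(B * W)]      # B*W prefixes, each starting with sos
local_scores, state = scorer(ys, None)  # local_scores: (B*W, O)

# keep the argmax label per prefix and gather the matching CTC states
best_ids = paddle.argmax(local_scores, axis=-1, keepdim=True)  # (B*W, 1)
state = scorer.index_select_state(state, best_ids)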