update vector ctc prefix score

pull/882/head
Hui Zhang 3 years ago
parent 331bd9eaae
commit 2430545d45

@ -4,7 +4,7 @@ import numpy as np
import paddle import paddle
from .ctc_prefix_score import CTCPrefixScore from .ctc_prefix_score import CTCPrefixScore
from .ctc_prefix_score import CTCPrefixScoreTH from .ctc_prefix_score import CTCPrefixScorePD
from .scorer_interface import BatchPartialScorerInterface from .scorer_interface import BatchPartialScorerInterface
@ -34,7 +34,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface):
""" """
logp = self.ctc.log_softmax(x.unsqueeze(0)).squeeze(0).numpy() logp = self.ctc.log_softmax(x.unsqueeze(0)).squeeze(0).numpy()
# TODO(karita): use CTCPrefixScoreTH # TODO(karita): use CTCPrefixScorePD
self.impl = CTCPrefixScore(logp, 0, self.eos, np) self.impl = CTCPrefixScore(logp, 0, self.eos, np)
return 0, self.impl.initial_state() return 0, self.impl.initial_state()
@ -54,7 +54,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface):
if len(state) == 2: # for CTCPrefixScore if len(state) == 2: # for CTCPrefixScore
sc, st = state sc, st = state
return sc[i], st[i] return sc[i], st[i]
else: # for CTCPrefixScoreTH (need new_id > 0) else: # for CTCPrefixScorePD (need new_id > 0)
r, log_psi, f_min, f_max, scoring_idmap = state r, log_psi, f_min, f_max, scoring_idmap = state
s = log_psi[i, new_id].expand(log_psi.size(1)) s = log_psi[i, new_id].expand(log_psi.size(1))
if scoring_idmap is not None: if scoring_idmap is not None:
@ -96,7 +96,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface):
""" """
logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1 logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1
xlen = paddle.to_tensor([logp.size(1)]) xlen = paddle.to_tensor([logp.size(1)])
self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos) self.impl = CTCPrefixScorePD(logp, xlen, 0, self.eos)
return None return None
def batch_score_partial(self, y, ids, state, x): def batch_score_partial(self, y, ids, state, x):

@ -3,13 +3,13 @@
# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori) # Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch import paddle
import numpy as np import numpy as np
import six import six
class CTCPrefixScoreTH(): class CTCPrefixScorePD():
"""Batch processing of CTCPrefixScore """Batch processing of CTCPrefixScore
which is based on Algorithm 2 in WATANABE et al. which is based on Algorithm 2 in WATANABE et al.
@ -23,8 +23,10 @@ class CTCPrefixScoreTH():
def __init__(self, x, xlens, blank, eos, margin=0): def __init__(self, x, xlens, blank, eos, margin=0):
"""Construct CTC prefix scorer """Construct CTC prefix scorer
:param torch.Tensor x: input label posterior sequences (B, T, O) `margin` is M in eq.(22,23)
:param torch.Tensor xlens: input lengths (B,)
:param paddle.Tensor x: input label posterior sequences (B, T, O)
:param paddle.Tensor xlens: input lengths (B,)
:param int blank: blank label id :param int blank: blank label id
:param int eos: end-of-sequence id :param int eos: end-of-sequence id
:param int margin: margin parameter for windowing (0 means no windowing) :param int margin: margin parameter for windowing (0 means no windowing)
@ -38,11 +40,8 @@ class CTCPrefixScoreTH():
self.input_length = x.size(1) self.input_length = x.size(1)
self.odim = x.size(2) self.odim = x.size(2)
self.dtype = x.dtype self.dtype = x.dtype
self.device = ( self.device = x.place
torch.device("cuda:%d" % x.get_device())
if x.is_cuda
else torch.device("cpu")
)
# Pad the rest of posteriors in the batch # Pad the rest of posteriors in the batch
# TODO(takaaki-hori): need a better way without for-loops # TODO(takaaki-hori): need a better way without for-loops
for i, l in enumerate(xlens): for i, l in enumerate(xlens):
@ -50,20 +49,21 @@ class CTCPrefixScoreTH():
x[i, l:, :] = self.logzero x[i, l:, :] = self.logzero
x[i, l:, blank] = 0 x[i, l:, blank] = 0
# Reshape input x # Reshape input x
xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O) xn = x.transpose([1, 0, 2]) # (B, T, O) -> (T, B, O)
xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim) xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim) # (T,B,O)
self.x = torch.stack([xn, xb]) # (2, T, B, O) self.x = paddle.stack([xn, xb]) # (2, T, B, O)
self.end_frames = torch.as_tensor(xlens) - 1 self.end_frames = paddle.to_tensor(xlens) - 1 # (B,)
# Setup CTC windowing # Setup CTC windowing
self.margin = margin self.margin = margin
if margin > 0: if margin > 0:
self.frame_ids = torch.arange( self.frame_ids = paddle.arange(self.input_length, dtype=self.dtype)
self.input_length, dtype=self.dtype, device=self.device
)
# Base indices for index conversion # Base indices for index conversion
self.idx_bh = None # B idx, hyp idx. shape (B*W, 1)
self.idx_b = torch.arange(self.batch, device=self.device) self.idx_bh = None
# B idx. shape (B,)
self.idx_b = paddle.arange(self.batch, place=self.device)
# B idx, O idx. shape (B, 1)
self.idx_bo = (self.idx_b * self.odim).unsqueeze(1) self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)
def __call__(self, y, state, scoring_ids=None, att_w=None): def __call__(self, y, state, scoring_ids=None, att_w=None):
@ -71,8 +71,8 @@ class CTCPrefixScoreTH():
:param list y: prefix label sequences :param list y: prefix label sequences
:param tuple state: previous CTC state :param tuple state: previous CTC state
:param torch.Tensor pre_scores: scores for pre-selection of hypotheses (BW, O) :param paddle.Tensor scoring_ids: selected next ids to score (BW, O'), O' <= O
:param torch.Tensor att_w: attention weights to decide CTC window :param paddle.Tensor att_w: attention weights to decide CTC window
:return new_state, ctc_local_scores (BW, O) :return new_state, ctc_local_scores (BW, O)
""" """
output_length = len(y[0]) - 1 # ignore sos output_length = len(y[0]) - 1 # ignore sos
@ -82,56 +82,53 @@ class CTCPrefixScoreTH():
self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0 self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
# prepare state info # prepare state info
if state is None: if state is None:
r_prev = torch.full( r_prev = paddle.full(
(self.input_length, 2, self.batch, n_hyps), (self.input_length, 2, self.batch, n_hyps),
self.logzero, self.logzero,
dtype=self.dtype, dtype=self.dtype,
device=self.device, ) # (T, 2, B, W)
) r_prev[:, 1] = paddle.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2) r_prev = r_prev.view(-1, 2, n_bh) # (T, 2, BW)
r_prev = r_prev.view(-1, 2, n_bh) s_prev = 0.0 # score
s_prev = 0.0 f_min_prev = 0 # eq. 22-23
f_min_prev = 0 f_max_prev = 1 # eq. 22-23
f_max_prev = 1
else: else:
r_prev, s_prev, f_min_prev, f_max_prev = state r_prev, s_prev, f_min_prev, f_max_prev = state
# select input dimensions for scoring # select input dimensions for scoring
if self.scoring_num > 0: if self.scoring_num > 0:
scoring_idmap = torch.full( # (BW, O)
(n_bh, self.odim), -1, dtype=torch.long, device=self.device scoring_idmap = paddle.full((n_bh, self.odim), -1, dtype=paddle.long)
)
snum = self.scoring_num snum = self.scoring_num
if self.idx_bh is None or n_bh > len(self.idx_bh): if self.idx_bh is None or n_bh > len(self.idx_bh):
self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1) self.idx_bh = paddle.arange(n_bh).view(-1, 1) # (BW, 1)
scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange( scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = paddle.arange(snum)
snum, device=self.device
)
scoring_idx = ( scoring_idx = (
scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1) scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1) # (BW,1)
).view(-1) ).view(-1) # (BWO)
x_ = torch.index_select( # x_ shape (2, T, B*W, O)
self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx x_ = paddle.index_select(
self.x.view(2, -1, self.batch * self.odim), scoring_idx, 2
).view(2, -1, n_bh, snum) ).view(2, -1, n_bh, snum)
else: else:
scoring_ids = None scoring_ids = None
scoring_idmap = None scoring_idmap = None
snum = self.odim snum = self.odim
# x_ shape (2, T, B*W, O)
x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum) x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)
# new CTC forward probs are prepared as a (T x 2 x BW x S) tensor # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
# that corresponds to r_t^n(h) and r_t^b(h) in a batch. # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
r = torch.full( r = paddle.full(
(self.input_length, 2, n_bh, snum), (self.input_length, 2, n_bh, snum),
self.logzero, self.logzero,
dtype=self.dtype, dtype=self.dtype,
device=self.device,
) )
if output_length == 0: if output_length == 0:
r[0, 0] = x_[0, 0] r[0, 0] = x_[0, 0]
r_sum = torch.logsumexp(r_prev, 1) r_sum = paddle.logsumexp(r_prev, 1) #(T,BW)
log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum) log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum) # (T, BW, O)
if scoring_ids is not None: if scoring_ids is not None:
for idx in range(n_bh): for idx in range(n_bh):
pos = scoring_idmap[idx, last_ids[idx]] pos = scoring_idmap[idx, last_ids[idx]]
@ -143,40 +140,39 @@ class CTCPrefixScoreTH():
# decide start and end frames based on attention weights # decide start and end frames based on attention weights
if att_w is not None and self.margin > 0: if att_w is not None and self.margin > 0:
f_arg = torch.matmul(att_w, self.frame_ids) f_arg = paddle.matmul(att_w, self.frame_ids)
f_min = max(int(f_arg.min().cpu()), f_min_prev) f_min = max(int(f_arg.min().cpu()), f_min_prev)
f_max = max(int(f_arg.max().cpu()), f_max_prev) f_max = max(int(f_arg.max().cpu()), f_max_prev)
start = min(f_max_prev, max(f_min - self.margin, output_length, 1)) start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
end = min(f_max + self.margin, self.input_length) end = min(f_max + self.margin, self.input_length)
else: else:
f_min = f_max = 0 f_min = f_max = 0
# if one frame one out, the output_length is the eating frame num now.
start = max(output_length, 1) start = max(output_length, 1)
end = self.input_length end = self.input_length
# compute forward probabilities log(r_t^n(h)) and log(r_t^b(h)) # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
for t in range(start, end): for t in range(start, end):
rp = r[t - 1] rp = r[t - 1] # (2 x BW x O')
rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view( rr = paddle.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
2, 2, n_bh, snum 2, 2, n_bh, snum
) ) # (2,2,BW,O')
r[t] = torch.logsumexp(rr, 1) + x_[:, t] r[t] = paddle.logsumexp(rr, 1) + x_[:, t]
# compute log prefix probabilities log(psi) # compute log prefix probabilities log(psi)
log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0] log_phi_x = paddle.concat((log_phi[0].unsqueeze(0), log_phi[:-1]), axis=0) + x_[0]
if scoring_ids is not None: if scoring_ids is not None:
log_psi = torch.full( log_psi = paddle.full((n_bh, self.odim), self.logzero, dtype=self.dtype)
(n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device log_psi_ = paddle.logsumexp(
) paddle.concat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), axis=0),
log_psi_ = torch.logsumexp( axis=0,
torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
dim=0,
) )
for si in range(n_bh): for si in range(n_bh):
log_psi[si, scoring_ids[si]] = log_psi_[si] log_psi[si, scoring_ids[si]] = log_psi_[si]
else: else:
log_psi = torch.logsumexp( log_psi = paddle.logsumexp(
torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0), paddle.concat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), axis=0),
dim=0, axis=0,
) )
for si in range(n_bh): for si in range(n_bh):
@ -200,7 +196,7 @@ class CTCPrefixScoreTH():
n_hyps = n_bh // self.batch n_hyps = n_bh // self.batch
vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1) vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
# select hypothesis scores # select hypothesis scores
s_new = torch.index_select(s.view(-1), 0, vidx) s_new = paddle.index_select(s.view(-1), vidx, 0)
s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim) s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
# convert ids to BHS space (S: scoring_num) # convert ids to BHS space (S: scoring_num)
if scoring_idmap is not None: if scoring_idmap is not None:
@ -208,14 +204,14 @@ class CTCPrefixScoreTH():
hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view( hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
-1 -1
) )
label_ids = torch.fmod(best_ids, self.odim).view(-1) label_ids = paddle.fmod(best_ids, self.odim).view(-1)
score_idx = scoring_idmap[hyp_idx, label_ids] score_idx = scoring_idmap[hyp_idx, label_ids]
score_idx[score_idx == -1] = 0 score_idx[score_idx == -1] = 0
vidx = score_idx + hyp_idx * snum vidx = score_idx + hyp_idx * snum
else: else:
snum = self.odim snum = self.odim
# select forward probabilities # select forward probabilities
r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view( r_new = paddle.index_select(r.view(-1, 2, n_bh * snum), vidx, 2).view(
-1, 2, n_bh -1, 2, n_bh
) )
return r_new, s_new, f_min, f_max return r_new, s_new, f_min, f_max
@ -223,7 +219,7 @@ class CTCPrefixScoreTH():
def extend_prob(self, x): def extend_prob(self, x):
"""Extend CTC prob. """Extend CTC prob.
:param torch.Tensor x: input label posterior sequences (B, T, O) :param paddle.Tensor x: input label posterior sequences (B, T, O)
""" """
if self.x.shape[1] < x.shape[1]: # self.x (2,T,B,O); x (B,T,O) if self.x.shape[1] < x.shape[1]: # self.x (2,T,B,O); x (B,T,O)
@ -235,12 +231,12 @@ class CTCPrefixScoreTH():
x[i, l:, :] = self.logzero x[i, l:, :] = self.logzero
x[i, l:, self.blank] = 0 x[i, l:, self.blank] = 0
tmp_x = self.x tmp_x = self.x
xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O) xn = x.transpose([1, 0, 2]) # (B, T, O) -> (T, B, O)
xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim) xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
self.x = torch.stack([xn, xb]) # (2, T, B, O) self.x = paddle.stack([xn, xb]) # (2, T, B, O)
self.x[:, : tmp_x.shape[1], :, :] = tmp_x self.x[:, : tmp_x.shape[1], :, :] = tmp_x
self.input_length = x.size(1) self.input_length = x.size(1)
self.end_frames = torch.as_tensor(xlens) - 1 self.end_frames = paddle.to_tensor(xlens) - 1
def extend_state(self, state): def extend_state(self, state):
"""Compute CTC prefix state. """Compute CTC prefix state.
@ -256,15 +252,14 @@ class CTCPrefixScoreTH():
else: else:
r_prev, s_prev, f_min_prev, f_max_prev = state r_prev, s_prev, f_min_prev, f_max_prev = state
r_prev_new = torch.full( r_prev_new = paddle.full(
(self.input_length, 2), (self.input_length, 2),
self.logzero, self.logzero,
dtype=self.dtype, dtype=self.dtype,
device=self.device,
) )
start = max(r_prev.shape[0], 1) start = max(r_prev.shape[0], 1)
r_prev_new[0:start] = r_prev r_prev_new[0:start] = r_prev
for t in six.moves.range(start, self.input_length): for t in range(start, self.input_length):
r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank] r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]
return (r_prev_new, s_prev, f_min_prev, f_max_prev) return (r_prev_new, s_prev, f_min_prev, f_max_prev)
@ -285,7 +280,7 @@ class CTCPrefixScore():
self.blank = blank self.blank = blank
self.eos = eos self.eos = eos
self.input_length = len(x) self.input_length = len(x)
self.x = x self.x = x # (T, O)
def initial_state(self): def initial_state(self):
"""Obtain an initial CTC state """Obtain an initial CTC state
@ -295,6 +290,7 @@ class CTCPrefixScore():
# initial CTC state is made of a frame x 2 tensor that corresponds to # initial CTC state is made of a frame x 2 tensor that corresponds to
# r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
# superscripts n and b (non-blank and blank), respectively. # superscripts n and b (non-blank and blank), respectively.
# r shape (T, 2)
r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32) r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
r[0, 1] = self.x[0, self.blank] r[0, 1] = self.x[0, self.blank]
for i in six.moves.range(1, self.input_length): for i in six.moves.range(1, self.input_length):
@ -313,6 +309,7 @@ class CTCPrefixScore():
output_length = len(y) - 1 # ignore sos output_length = len(y) - 1 # ignore sos
# new CTC states are prepared as a frame x (n or b) x n_labels tensor # new CTC states are prepared as a frame x (n or b) x n_labels tensor
# that corresponds to r_t^n(h) and r_t^b(h). # that corresponds to r_t^n(h) and r_t^b(h).
# r shape (T, 2, n_labels)
r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32) r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
xs = self.x[:, cs] xs = self.x[:, cs]
if output_length == 0: if output_length == 0:
@ -356,4 +353,5 @@ class CTCPrefixScore():
# return the log prefix probability and CTC states, where the label axis # return the log prefix probability and CTC states, where the label axis
# of the CTC states is moved to the first axis to slice it easily # of the CTC states is moved to the first axis to slice it easily
# log_psi shape (n_labels,), state shape (n_labels, T, 2)
return log_psi, self.xp.rollaxis(r, 2) return log_psi, self.xp.rollaxis(r, 2)

Loading…
Cancel
Save