parent
69bd17dcb2
commit
331bd9eaae
@ -0,0 +1,13 @@
|
||||
# Decoders
|
||||
|
||||
## Reference
|
||||
### CTC Prefix Beam Search
|
||||
* [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
|
||||
* [First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs](https://arxiv.org/pdf/1408.2873.pdf)
|
||||
|
||||
### CTC Prefix Score & Join CTC/ATT One-passing Decoding
|
||||
* [Hybrid CTC/Attention Architecture for End-to-End Speech Recognition](http://www.ifp.illinois.edu/speech/speech_web_lg/slides/2019/watanabe_hybridCTCAttention_2017.pdf)
|
||||
* [Vectorized Beam Search for CTC-Attention-based Speech Recognition](https://www.isca-speech.org/archive/pdfs/interspeech_2019/seki19b_interspeech.pdf)
|
||||
|
||||
### Streaming Join CTC/ATT Beam Search
|
||||
* [STREAMING TRANSFORMER ASR WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH](https://arxiv.org/abs/2006.14941)
|
@ -0,0 +1,158 @@
|
||||
"""ScorerInterface implementation for CTC."""
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from .ctc_prefix_score import CTCPrefixScore
|
||||
from .ctc_prefix_score import CTCPrefixScoreTH
|
||||
from .scorer_interface import BatchPartialScorerInterface
|
||||
|
||||
|
||||
class CTCPrefixScorer(BatchPartialScorerInterface):
|
||||
"""Decoder interface wrapper for CTCPrefixScore."""
|
||||
|
||||
def __init__(self, ctc: paddle.nn.Layer, eos: int):
|
||||
"""Initialize class.
|
||||
|
||||
Args:
|
||||
ctc (paddle.nn.Layer): The CTC implementation.
|
||||
For example, :class:`deepspeech.modules.ctc.CTC`
|
||||
eos (int): The end-of-sequence id.
|
||||
|
||||
"""
|
||||
self.ctc = ctc
|
||||
self.eos = eos
|
||||
self.impl = None
|
||||
|
||||
def init_state(self, x: paddle.Tensor):
|
||||
"""Get an initial state for decoding.
|
||||
|
||||
Args:
|
||||
x (paddle.Tensor): The encoded feature tensor
|
||||
|
||||
Returns: initial state
|
||||
|
||||
"""
|
||||
logp = self.ctc.log_softmax(x.unsqueeze(0)).squeeze(0).numpy()
|
||||
# TODO(karita): use CTCPrefixScoreTH
|
||||
self.impl = CTCPrefixScore(logp, 0, self.eos, np)
|
||||
return 0, self.impl.initial_state()
|
||||
|
||||
def select_state(self, state, i, new_id=None):
|
||||
"""Select state with relative ids in the main beam search.
|
||||
|
||||
Args:
|
||||
state: Decoder state for prefix tokens
|
||||
i (int): Index to select a state in the main beam search
|
||||
new_id (int): New label id to select a state if necessary
|
||||
|
||||
Returns:
|
||||
state: pruned state
|
||||
|
||||
"""
|
||||
if type(state) == tuple:
|
||||
if len(state) == 2: # for CTCPrefixScore
|
||||
sc, st = state
|
||||
return sc[i], st[i]
|
||||
else: # for CTCPrefixScoreTH (need new_id > 0)
|
||||
r, log_psi, f_min, f_max, scoring_idmap = state
|
||||
s = log_psi[i, new_id].expand(log_psi.size(1))
|
||||
if scoring_idmap is not None:
|
||||
return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
|
||||
else:
|
||||
return r[:, :, i, new_id], s, f_min, f_max
|
||||
return None if state is None else state[i]
|
||||
|
||||
def score_partial(self, y, ids, state, x):
|
||||
"""Score new token.
|
||||
|
||||
Args:
|
||||
y (paddle.Tensor): 1D prefix token
|
||||
next_tokens (paddle.Tensor): paddle.int64 next token to score
|
||||
state: decoder state for prefix tokens
|
||||
x (paddle.Tensor): 2D encoder feature that generates ys
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, Any]:
|
||||
Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
|
||||
and next state for ys
|
||||
|
||||
"""
|
||||
prev_score, state = state
|
||||
presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
|
||||
tscore = paddle.to_tensor(
|
||||
presub_score - prev_score, place=x.place, dtype=x.dtype
|
||||
)
|
||||
return tscore, (presub_score, new_st)
|
||||
|
||||
def batch_init_state(self, x: paddle.Tensor):
|
||||
"""Get an initial state for decoding.
|
||||
|
||||
Args:
|
||||
x (paddle.Tensor): The encoded feature tensor
|
||||
|
||||
Returns: initial state
|
||||
|
||||
"""
|
||||
logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1
|
||||
xlen = paddle.to_tensor([logp.size(1)])
|
||||
self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
|
||||
return None
|
||||
|
||||
def batch_score_partial(self, y, ids, state, x):
|
||||
"""Score new token.
|
||||
|
||||
Args:
|
||||
y (paddle.Tensor): 1D prefix token
|
||||
ids (paddle.Tensor): paddle.int64 next token to score
|
||||
state: decoder state for prefix tokens
|
||||
x (paddle.Tensor): 2D encoder feature that generates ys
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, Any]:
|
||||
Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
|
||||
and next state for ys
|
||||
|
||||
"""
|
||||
batch_state = (
|
||||
(
|
||||
paddle.stack([s[0] for s in state], axis=2),
|
||||
paddle.stack([s[1] for s in state]),
|
||||
state[0][2],
|
||||
state[0][3],
|
||||
)
|
||||
if state[0] is not None
|
||||
else None
|
||||
)
|
||||
return self.impl(y, batch_state, ids)
|
||||
|
||||
def extend_prob(self, x: paddle.Tensor):
|
||||
"""Extend probs for decoding.
|
||||
|
||||
This extension is for streaming decoding
|
||||
as in Eq (14) in https://arxiv.org/abs/2006.14941
|
||||
|
||||
Args:
|
||||
x (paddle.Tensor): The encoded feature tensor
|
||||
|
||||
"""
|
||||
logp = self.ctc.log_softmax(x.unsqueeze(0))
|
||||
self.impl.extend_prob(logp)
|
||||
|
||||
def extend_state(self, state):
|
||||
"""Extend state for decoding.
|
||||
|
||||
This extension is for streaming decoding
|
||||
as in Eq (14) in https://arxiv.org/abs/2006.14941
|
||||
|
||||
Args:
|
||||
state: The states of hyps
|
||||
|
||||
Returns: exteded state
|
||||
|
||||
"""
|
||||
new_state = []
|
||||
for s in state:
|
||||
new_state.append(self.impl.extend_state(s))
|
||||
|
||||
return new_state
|
@ -0,0 +1,359 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
|
||||
class CTCPrefixScoreTH():
|
||||
"""Batch processing of CTCPrefixScore
|
||||
|
||||
which is based on Algorithm 2 in WATANABE et al.
|
||||
"HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
|
||||
but extended to efficiently compute the label probablities for multiple
|
||||
hypotheses simultaneously
|
||||
See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
|
||||
Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
|
||||
"""
|
||||
|
||||
def __init__(self, x, xlens, blank, eos, margin=0):
|
||||
"""Construct CTC prefix scorer
|
||||
|
||||
:param torch.Tensor x: input label posterior sequences (B, T, O)
|
||||
:param torch.Tensor xlens: input lengths (B,)
|
||||
:param int blank: blank label id
|
||||
:param int eos: end-of-sequence id
|
||||
:param int margin: margin parameter for windowing (0 means no windowing)
|
||||
"""
|
||||
# In the comment lines,
|
||||
# we assume T: input_length, B: batch size, W: beam width, O: output dim.
|
||||
self.logzero = -10000000000.0
|
||||
self.blank = blank
|
||||
self.eos = eos
|
||||
self.batch = x.size(0)
|
||||
self.input_length = x.size(1)
|
||||
self.odim = x.size(2)
|
||||
self.dtype = x.dtype
|
||||
self.device = (
|
||||
torch.device("cuda:%d" % x.get_device())
|
||||
if x.is_cuda
|
||||
else torch.device("cpu")
|
||||
)
|
||||
# Pad the rest of posteriors in the batch
|
||||
# TODO(takaaki-hori): need a better way without for-loops
|
||||
for i, l in enumerate(xlens):
|
||||
if l < self.input_length:
|
||||
x[i, l:, :] = self.logzero
|
||||
x[i, l:, blank] = 0
|
||||
# Reshape input x
|
||||
xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O)
|
||||
xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
|
||||
self.x = torch.stack([xn, xb]) # (2, T, B, O)
|
||||
self.end_frames = torch.as_tensor(xlens) - 1
|
||||
|
||||
# Setup CTC windowing
|
||||
self.margin = margin
|
||||
if margin > 0:
|
||||
self.frame_ids = torch.arange(
|
||||
self.input_length, dtype=self.dtype, device=self.device
|
||||
)
|
||||
# Base indices for index conversion
|
||||
self.idx_bh = None
|
||||
self.idx_b = torch.arange(self.batch, device=self.device)
|
||||
self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)
|
||||
|
||||
def __call__(self, y, state, scoring_ids=None, att_w=None):
|
||||
"""Compute CTC prefix scores for next labels
|
||||
|
||||
:param list y: prefix label sequences
|
||||
:param tuple state: previous CTC state
|
||||
:param torch.Tensor pre_scores: scores for pre-selection of hypotheses (BW, O)
|
||||
:param torch.Tensor att_w: attention weights to decide CTC window
|
||||
:return new_state, ctc_local_scores (BW, O)
|
||||
"""
|
||||
output_length = len(y[0]) - 1 # ignore sos
|
||||
last_ids = [yi[-1] for yi in y] # last output label ids
|
||||
n_bh = len(last_ids) # batch * hyps
|
||||
n_hyps = n_bh // self.batch # assuming each utterance has the same # of hyps
|
||||
self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
|
||||
# prepare state info
|
||||
if state is None:
|
||||
r_prev = torch.full(
|
||||
(self.input_length, 2, self.batch, n_hyps),
|
||||
self.logzero,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
|
||||
r_prev = r_prev.view(-1, 2, n_bh)
|
||||
s_prev = 0.0
|
||||
f_min_prev = 0
|
||||
f_max_prev = 1
|
||||
else:
|
||||
r_prev, s_prev, f_min_prev, f_max_prev = state
|
||||
|
||||
# select input dimensions for scoring
|
||||
if self.scoring_num > 0:
|
||||
scoring_idmap = torch.full(
|
||||
(n_bh, self.odim), -1, dtype=torch.long, device=self.device
|
||||
)
|
||||
snum = self.scoring_num
|
||||
if self.idx_bh is None or n_bh > len(self.idx_bh):
|
||||
self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
|
||||
scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
|
||||
snum, device=self.device
|
||||
)
|
||||
scoring_idx = (
|
||||
scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
|
||||
).view(-1)
|
||||
x_ = torch.index_select(
|
||||
self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
|
||||
).view(2, -1, n_bh, snum)
|
||||
else:
|
||||
scoring_ids = None
|
||||
scoring_idmap = None
|
||||
snum = self.odim
|
||||
x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)
|
||||
|
||||
# new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
|
||||
# that corresponds to r_t^n(h) and r_t^b(h) in a batch.
|
||||
r = torch.full(
|
||||
(self.input_length, 2, n_bh, snum),
|
||||
self.logzero,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
if output_length == 0:
|
||||
r[0, 0] = x_[0, 0]
|
||||
|
||||
r_sum = torch.logsumexp(r_prev, 1)
|
||||
log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
|
||||
if scoring_ids is not None:
|
||||
for idx in range(n_bh):
|
||||
pos = scoring_idmap[idx, last_ids[idx]]
|
||||
if pos >= 0:
|
||||
log_phi[:, idx, pos] = r_prev[:, 1, idx]
|
||||
else:
|
||||
for idx in range(n_bh):
|
||||
log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]
|
||||
|
||||
# decide start and end frames based on attention weights
|
||||
if att_w is not None and self.margin > 0:
|
||||
f_arg = torch.matmul(att_w, self.frame_ids)
|
||||
f_min = max(int(f_arg.min().cpu()), f_min_prev)
|
||||
f_max = max(int(f_arg.max().cpu()), f_max_prev)
|
||||
start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
|
||||
end = min(f_max + self.margin, self.input_length)
|
||||
else:
|
||||
f_min = f_max = 0
|
||||
start = max(output_length, 1)
|
||||
end = self.input_length
|
||||
|
||||
# compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
|
||||
for t in range(start, end):
|
||||
rp = r[t - 1]
|
||||
rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
|
||||
2, 2, n_bh, snum
|
||||
)
|
||||
r[t] = torch.logsumexp(rr, 1) + x_[:, t]
|
||||
|
||||
# compute log prefix probabilities log(psi)
|
||||
log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
|
||||
if scoring_ids is not None:
|
||||
log_psi = torch.full(
|
||||
(n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device
|
||||
)
|
||||
log_psi_ = torch.logsumexp(
|
||||
torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
|
||||
dim=0,
|
||||
)
|
||||
for si in range(n_bh):
|
||||
log_psi[si, scoring_ids[si]] = log_psi_[si]
|
||||
else:
|
||||
log_psi = torch.logsumexp(
|
||||
torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
for si in range(n_bh):
|
||||
log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]
|
||||
|
||||
# exclude blank probs
|
||||
log_psi[:, self.blank] = self.logzero
|
||||
|
||||
return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)
|
||||
|
||||
def index_select_state(self, state, best_ids):
|
||||
"""Select CTC states according to best ids
|
||||
|
||||
:param state : CTC state
|
||||
:param best_ids : index numbers selected by beam pruning (B, W)
|
||||
:return selected_state
|
||||
"""
|
||||
r, s, f_min, f_max, scoring_idmap = state
|
||||
# convert ids to BHO space
|
||||
n_bh = len(s)
|
||||
n_hyps = n_bh // self.batch
|
||||
vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
|
||||
# select hypothesis scores
|
||||
s_new = torch.index_select(s.view(-1), 0, vidx)
|
||||
s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
|
||||
# convert ids to BHS space (S: scoring_num)
|
||||
if scoring_idmap is not None:
|
||||
snum = self.scoring_num
|
||||
hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
|
||||
-1
|
||||
)
|
||||
label_ids = torch.fmod(best_ids, self.odim).view(-1)
|
||||
score_idx = scoring_idmap[hyp_idx, label_ids]
|
||||
score_idx[score_idx == -1] = 0
|
||||
vidx = score_idx + hyp_idx * snum
|
||||
else:
|
||||
snum = self.odim
|
||||
# select forward probabilities
|
||||
r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
|
||||
-1, 2, n_bh
|
||||
)
|
||||
return r_new, s_new, f_min, f_max
|
||||
|
||||
def extend_prob(self, x):
|
||||
"""Extend CTC prob.
|
||||
|
||||
:param torch.Tensor x: input label posterior sequences (B, T, O)
|
||||
"""
|
||||
|
||||
if self.x.shape[1] < x.shape[1]: # self.x (2,T,B,O); x (B,T,O)
|
||||
# Pad the rest of posteriors in the batch
|
||||
# TODO(takaaki-hori): need a better way without for-loops
|
||||
xlens = [x.size(1)]
|
||||
for i, l in enumerate(xlens):
|
||||
if l < self.input_length:
|
||||
x[i, l:, :] = self.logzero
|
||||
x[i, l:, self.blank] = 0
|
||||
tmp_x = self.x
|
||||
xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O)
|
||||
xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
|
||||
self.x = torch.stack([xn, xb]) # (2, T, B, O)
|
||||
self.x[:, : tmp_x.shape[1], :, :] = tmp_x
|
||||
self.input_length = x.size(1)
|
||||
self.end_frames = torch.as_tensor(xlens) - 1
|
||||
|
||||
def extend_state(self, state):
|
||||
"""Compute CTC prefix state.
|
||||
|
||||
|
||||
:param state : CTC state
|
||||
:return ctc_state
|
||||
"""
|
||||
|
||||
if state is None:
|
||||
# nothing to do
|
||||
return state
|
||||
else:
|
||||
r_prev, s_prev, f_min_prev, f_max_prev = state
|
||||
|
||||
r_prev_new = torch.full(
|
||||
(self.input_length, 2),
|
||||
self.logzero,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
start = max(r_prev.shape[0], 1)
|
||||
r_prev_new[0:start] = r_prev
|
||||
for t in six.moves.range(start, self.input_length):
|
||||
r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]
|
||||
|
||||
return (r_prev_new, s_prev, f_min_prev, f_max_prev)
|
||||
|
||||
|
||||
class CTCPrefixScore():
|
||||
"""Compute CTC label sequence scores
|
||||
|
||||
which is based on Algorithm 2 in WATANABE et al.
|
||||
"HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
|
||||
but extended to efficiently compute the probablities of multiple labels
|
||||
simultaneously
|
||||
"""
|
||||
|
||||
def __init__(self, x, blank, eos, xp):
|
||||
self.xp = xp
|
||||
self.logzero = -10000000000.0
|
||||
self.blank = blank
|
||||
self.eos = eos
|
||||
self.input_length = len(x)
|
||||
self.x = x
|
||||
|
||||
def initial_state(self):
|
||||
"""Obtain an initial CTC state
|
||||
|
||||
:return: CTC state
|
||||
"""
|
||||
# initial CTC state is made of a frame x 2 tensor that corresponds to
|
||||
# r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
|
||||
# superscripts n and b (non-blank and blank), respectively.
|
||||
r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
|
||||
r[0, 1] = self.x[0, self.blank]
|
||||
for i in six.moves.range(1, self.input_length):
|
||||
r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
|
||||
return r
|
||||
|
||||
def __call__(self, y, cs, r_prev):
|
||||
"""Compute CTC prefix scores for next labels
|
||||
|
||||
:param y : prefix label sequence
|
||||
:param cs : array of next labels
|
||||
:param r_prev: previous CTC state
|
||||
:return ctc_scores, ctc_states
|
||||
"""
|
||||
# initialize CTC states
|
||||
output_length = len(y) - 1 # ignore sos
|
||||
# new CTC states are prepared as a frame x (n or b) x n_labels tensor
|
||||
# that corresponds to r_t^n(h) and r_t^b(h).
|
||||
r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
|
||||
xs = self.x[:, cs]
|
||||
if output_length == 0:
|
||||
r[0, 0] = xs[0]
|
||||
r[0, 1] = self.logzero
|
||||
else:
|
||||
r[output_length - 1] = self.logzero
|
||||
|
||||
# prepare forward probabilities for the last label
|
||||
r_sum = self.xp.logaddexp(
|
||||
r_prev[:, 0], r_prev[:, 1]
|
||||
) # log(r_t^n(g) + r_t^b(g))
|
||||
last = y[-1]
|
||||
if output_length > 0 and last in cs:
|
||||
log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
|
||||
for i in six.moves.range(len(cs)):
|
||||
log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
|
||||
else:
|
||||
log_phi = r_sum
|
||||
|
||||
# compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
|
||||
# and log prefix probabilities log(psi)
|
||||
start = max(output_length, 1)
|
||||
log_psi = r[start - 1, 0]
|
||||
for t in six.moves.range(start, self.input_length):
|
||||
r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
|
||||
r[t, 1] = (
|
||||
self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
|
||||
)
|
||||
log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
|
||||
|
||||
# get P(...eos|X) that ends with the prefix itself
|
||||
eos_pos = self.xp.where(cs == self.eos)[0]
|
||||
if len(eos_pos) > 0:
|
||||
log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g))
|
||||
|
||||
# exclude blank probs
|
||||
blank_pos = self.xp.where(cs == self.blank)[0]
|
||||
if len(blank_pos) > 0:
|
||||
log_psi[blank_pos] = self.logzero
|
||||
|
||||
# return the log prefix probability and CTC states, where the label axis
|
||||
# of the CTC states is moved to the first axis to slice it easily
|
||||
return log_psi, self.xp.rollaxis(r, 2)
|
@ -0,0 +1,61 @@
|
||||
"""Length bonus module."""
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import paddle
|
||||
|
||||
from .scorer_interface import BatchScorerInterface
|
||||
|
||||
|
||||
class LengthBonus(BatchScorerInterface):
|
||||
"""Length bonus in beam search."""
|
||||
|
||||
def __init__(self, n_vocab: int):
|
||||
"""Initialize class.
|
||||
|
||||
Args:
|
||||
n_vocab (int): The number of tokens in vocabulary for beam search
|
||||
|
||||
"""
|
||||
self.n = n_vocab
|
||||
|
||||
def score(self, y, state, x):
|
||||
"""Score new token.
|
||||
|
||||
Args:
|
||||
y (paddle.Tensor): 1D paddle.int64 prefix tokens.
|
||||
state: Scorer state for prefix tokens
|
||||
x (paddle.Tensor): 2D encoder feature that generates ys.
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, Any]: Tuple of
|
||||
paddle.float32 scores for next token (n_vocab)
|
||||
and None
|
||||
|
||||
"""
|
||||
return paddle.to_tensor([1.0], place=x.place, dtype=x.dtype).expand(self.n), None
|
||||
|
||||
def batch_score(
|
||||
self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor
|
||||
) -> Tuple[paddle.Tensor, List[Any]]:
|
||||
"""Score new token batch.
|
||||
|
||||
Args:
|
||||
ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
|
||||
states (List[Any]): Scorer states for prefix tokens.
|
||||
xs (paddle.Tensor):
|
||||
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, List[Any]]: Tuple of
|
||||
batchfied scores for next token with shape of `(n_batch, n_vocab)`
|
||||
and next state list for ys.
|
||||
|
||||
"""
|
||||
return (
|
||||
paddle.to_tensor([1.0], place=xs.place, dtype=xs.dtype).expand(
|
||||
ys.shape[0], self.n
|
||||
),
|
||||
None,
|
||||
)
|
@ -0,0 +1,102 @@
|
||||
"""Ngram lm implement."""
|
||||
|
||||
from abc import ABC
|
||||
|
||||
import kenlm
|
||||
import paddle
|
||||
|
||||
from .scorer_interface import BatchScorerInterface
|
||||
from .scorer_interface import PartialScorerInterface
|
||||
|
||||
|
||||
class Ngrambase(ABC):
|
||||
"""Ngram base implemented through ScorerInterface."""
|
||||
|
||||
def __init__(self, ngram_model, token_list):
|
||||
"""Initialize Ngrambase.
|
||||
|
||||
Args:
|
||||
ngram_model: ngram model path
|
||||
token_list: token list from dict or model.json
|
||||
|
||||
"""
|
||||
self.chardict = [x if x != "<eos>" else "</s>" for x in token_list]
|
||||
self.charlen = len(self.chardict)
|
||||
self.lm = kenlm.LanguageModel(ngram_model)
|
||||
self.tmpkenlmstate = kenlm.State()
|
||||
|
||||
def init_state(self, x):
|
||||
"""Initialize tmp state."""
|
||||
state = kenlm.State()
|
||||
self.lm.NullContextWrite(state)
|
||||
return state
|
||||
|
||||
def score_partial_(self, y, next_token, state, x):
|
||||
"""Score interface for both full and partial scorer.
|
||||
|
||||
Args:
|
||||
y: previous char
|
||||
next_token: next token need to be score
|
||||
state: previous state
|
||||
x: encoded feature
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, List[Any]]: Tuple of
|
||||
batchfied scores for next token with shape of `(n_batch, n_vocab)`
|
||||
and next state list for ys.
|
||||
|
||||
"""
|
||||
out_state = kenlm.State()
|
||||
ys = self.chardict[y[-1]] if y.shape[0] > 1 else "<s>"
|
||||
self.lm.BaseScore(state, ys, out_state)
|
||||
scores = paddle.empty_like(next_token, dtype=x.dtype)
|
||||
for i, j in enumerate(next_token):
|
||||
scores[i] = self.lm.BaseScore(
|
||||
out_state, self.chardict[j], self.tmpkenlmstate
|
||||
)
|
||||
return scores, out_state
|
||||
|
||||
|
||||
class NgramFullScorer(Ngrambase, BatchScorerInterface):
|
||||
"""Fullscorer for ngram."""
|
||||
|
||||
def score(self, y, state, x):
|
||||
"""Score interface for both full and partial scorer.
|
||||
|
||||
Args:
|
||||
y: previous char
|
||||
state: previous state
|
||||
x: encoded feature
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, List[Any]]: Tuple of
|
||||
batchfied scores for next token with shape of `(n_batch, n_vocab)`
|
||||
and next state list for ys.
|
||||
|
||||
"""
|
||||
return self.score_partial_(y, paddle.to_tensor(range(self.charlen)), state, x)
|
||||
|
||||
|
||||
class NgramPartScorer(Ngrambase, PartialScorerInterface):
|
||||
"""Partialscorer for ngram."""
|
||||
|
||||
def score_partial(self, y, next_token, state, x):
|
||||
"""Score interface for both full and partial scorer.
|
||||
|
||||
Args:
|
||||
y: previous char
|
||||
next_token: next token need to be score
|
||||
state: previous state
|
||||
x: encoded feature
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, List[Any]]: Tuple of
|
||||
batchfied scores for next token with shape of `(n_batch, n_vocab)`
|
||||
and next state list for ys.
|
||||
|
||||
"""
|
||||
return self.score_partial_(y, next_token, state, x)
|
||||
|
||||
def select_state(self, state, i):
|
||||
"""Empty select state for scorer interface."""
|
||||
return state
|
@ -0,0 +1,188 @@
|
||||
"""Scorer interface module."""
|
||||
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import paddle
|
||||
import warnings
|
||||
|
||||
|
||||
class ScorerInterface:
|
||||
"""Scorer interface for beam search.
|
||||
|
||||
The scorer performs scoring of the all tokens in vocabulary.
|
||||
|
||||
Examples:
|
||||
* Search heuristics
|
||||
* :class:`scorers.length_bonus.LengthBonus`
|
||||
* Decoder networks of the sequence-to-sequence models
|
||||
* :class:`transformer.decoder.Decoder`
|
||||
* :class:`rnn.decoders.Decoder`
|
||||
* Neural language models
|
||||
* :class:`lm.transformer.TransformerLM`
|
||||
* :class:`lm.default.DefaultRNNLM`
|
||||
* :class:`lm.seq_rnn.SequentialRNNLM`
|
||||
|
||||
"""
|
||||
|
||||
def init_state(self, x: paddle.Tensor) -> Any:
|
||||
"""Get an initial state for decoding (optional).
|
||||
|
||||
Args:
|
||||
x (paddle.Tensor): The encoded feature tensor
|
||||
|
||||
Returns: initial state
|
||||
|
||||
"""
|
||||
return None
|
||||
|
||||
def select_state(self, state: Any, i: int, new_id: int = None) -> Any:
|
||||
"""Select state with relative ids in the main beam search.
|
||||
|
||||
Args:
|
||||
state: Decoder state for prefix tokens
|
||||
i (int): Index to select a state in the main beam search
|
||||
new_id (int): New label index to select a state if necessary
|
||||
|
||||
Returns:
|
||||
state: pruned state
|
||||
|
||||
"""
|
||||
return None if state is None else state[i]
|
||||
|
||||
def score(
|
||||
self, y: paddle.Tensor, state: Any, x: paddle.Tensor
|
||||
) -> Tuple[paddle.Tensor, Any]:
|
||||
"""Score new token (required).
|
||||
|
||||
Args:
|
||||
y (paddle.Tensor): 1D paddle.int64 prefix tokens.
|
||||
state: Scorer state for prefix tokens
|
||||
x (paddle.Tensor): The encoder feature that generates ys.
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, Any]: Tuple of
|
||||
scores for next token that has a shape of `(n_vocab)`
|
||||
and next state for ys
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def final_score(self, state: Any) -> float:
|
||||
"""Score eos (optional).
|
||||
|
||||
Args:
|
||||
state: Scorer state for prefix tokens
|
||||
|
||||
Returns:
|
||||
float: final score
|
||||
|
||||
"""
|
||||
return 0.0
|
||||
|
||||
|
||||
class BatchScorerInterface(ScorerInterface):
|
||||
"""Batch scorer interface."""
|
||||
|
||||
def batch_init_state(self, x: paddle.Tensor) -> Any:
|
||||
"""Get an initial state for decoding (optional).
|
||||
|
||||
Args:
|
||||
x (paddle.Tensor): The encoded feature tensor
|
||||
|
||||
Returns: initial state
|
||||
|
||||
"""
|
||||
return self.init_state(x)
|
||||
|
||||
def batch_score(
|
||||
self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor
|
||||
) -> Tuple[paddle.Tensor, List[Any]]:
|
||||
"""Score new token batch (required).
|
||||
|
||||
Args:
|
||||
ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
|
||||
states (List[Any]): Scorer states for prefix tokens.
|
||||
xs (paddle.Tensor):
|
||||
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, List[Any]]: Tuple of
|
||||
batchfied scores for next token with shape of `(n_batch, n_vocab)`
|
||||
and next state list for ys.
|
||||
|
||||
"""
|
||||
warnings.warn(
|
||||
"{} batch score is implemented through for loop not parallelized".format(
|
||||
self.__class__.__name__
|
||||
)
|
||||
)
|
||||
scores = list()
|
||||
outstates = list()
|
||||
for i, (y, state, x) in enumerate(zip(ys, states, xs)):
|
||||
score, outstate = self.score(y, state, x)
|
||||
outstates.append(outstate)
|
||||
scores.append(score)
|
||||
scores = paddle.cat(scores, 0).view(ys.shape[0], -1)
|
||||
return scores, outstates
|
||||
|
||||
|
||||
class PartialScorerInterface(ScorerInterface):
|
||||
"""Partial scorer interface for beam search.
|
||||
|
||||
The partial scorer performs scoring when non-partial scorer finished scoring,
|
||||
and receives pre-pruned next tokens to score because it is too heavy to score
|
||||
all the tokens.
|
||||
|
||||
Examples:
|
||||
* Prefix search for connectionist-temporal-classification models
|
||||
* :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`
|
||||
|
||||
"""
|
||||
|
||||
def score_partial(
|
||||
self, y: paddle.Tensor, next_tokens: paddle.Tensor, state: Any, x: paddle.Tensor
|
||||
) -> Tuple[paddle.Tensor, Any]:
|
||||
"""Score new token (required).
|
||||
|
||||
Args:
|
||||
y (paddle.Tensor): 1D prefix token
|
||||
next_tokens (paddle.Tensor): paddle.int64 next token to score
|
||||
state: decoder state for prefix tokens
|
||||
x (paddle.Tensor): The encoder feature that generates ys
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, Any]:
|
||||
Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
|
||||
and next state for ys
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface):
|
||||
"""Batch partial scorer interface for beam search."""
|
||||
|
||||
def batch_score_partial(
|
||||
self,
|
||||
ys: paddle.Tensor,
|
||||
next_tokens: paddle.Tensor,
|
||||
states: List[Any],
|
||||
xs: paddle.Tensor,
|
||||
) -> Tuple[paddle.Tensor, Any]:
|
||||
"""Score new token (required).
|
||||
|
||||
Args:
|
||||
ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
|
||||
next_tokens (paddle.Tensor): paddle.int64 tokens to score (n_batch, n_token).
|
||||
states (List[Any]): Scorer states for prefix tokens.
|
||||
xs (paddle.Tensor):
|
||||
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, Any]:
|
||||
Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
|
||||
and next states for ys
|
||||
"""
|
||||
raise NotImplementedError
|
Loading…
Reference in new issue