diff --git a/deepspeech/decoders/README.md b/deepspeech/decoders/README.md
new file mode 100644
index 000000000..11eb56063
--- /dev/null
+++ b/deepspeech/decoders/README.md
@@ -0,0 +1,13 @@
+# Decoders
+
+## Reference
+### CTC Prefix Beam Search
+* [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
+* [First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs](https://arxiv.org/pdf/1408.2873.pdf)
+
+### CTC Prefix Score & Joint CTC/ATT One-Pass Decoding
+* [Hybrid CTC/Attention Architecture for End-to-End Speech Recognition](http://www.ifp.illinois.edu/speech/speech_web_lg/slides/2019/watanabe_hybridCTCAttention_2017.pdf)
+* [Vectorized Beam Search for CTC-Attention-based Speech Recognition](https://www.isca-speech.org/archive/pdfs/interspeech_2019/seki19b_interspeech.pdf)
+
+### Streaming Joint CTC/ATT Beam Search
+* [STREAMING TRANSFORMER ASR WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH](https://arxiv.org/abs/2006.14941)
\ No newline at end of file
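For orientation before the code: the first two references describe plain CTC prefix beam search, which tracks, for every candidate prefix, the probability of ending in blank and in non-blank. A minimal NumPy sketch of that algorithm (illustrative only; the function name and API are ours, not part of this diff) could look like:

```python
import numpy as np


def ctc_prefix_beam_search(log_probs, beam_size=8, blank=0):
    """Decode a (T, V) log-posterior matrix with CTC prefix beam search.

    For every prefix we track log p_b (prefix ends in blank) and
    log p_nb (prefix ends in non-blank), as in Hannun et al. (2014).
    """
    T, V = log_probs.shape
    beam = {(): (0.0, -np.inf)}  # empty prefix: p_b = 1, p_nb = 0
    for t in range(T):
        new_beam = {}

        def acc(prefix, pb_add, pnb_add):
            # accumulate probability mass for `prefix` in the new beam
            pb, pnb = new_beam.get(prefix, (-np.inf, -np.inf))
            new_beam[prefix] = (np.logaddexp(pb, pb_add),
                                np.logaddexp(pnb, pnb_add))

        for prefix, (p_b, p_nb) in beam.items():
            for v in range(V):
                p = log_probs[t, v]
                total = np.logaddexp(p_b, p_nb)
                if v == blank:
                    acc(prefix, total + p, -np.inf)       # same prefix via blank
                elif prefix and v == prefix[-1]:
                    acc(prefix, -np.inf, p_nb + p)        # collapse the repeat
                    acc(prefix + (v,), -np.inf, p_b + p)  # new label after a blank
                else:
                    acc(prefix + (v,), -np.inf, total + p)
        # prune to the `beam_size` most probable prefixes
        beam = dict(sorted(new_beam.items(),
                           key=lambda kv: -np.logaddexp(*kv[1]))[:beam_size])
    best_prefix, (p_b, p_nb) = max(beam.items(),
                                   key=lambda kv: np.logaddexp(*kv[1]))
    return list(best_prefix), np.logaddexp(p_b, p_nb)
```

The per-prefix `(p_b, p_nb)` pair is exactly the `(r_t^b, r_t^n)` state that the `CTCPrefixScore` classes below maintain in batched, windowed form.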
diff --git a/deepspeech/decoders/scores/__init__.py b/deepspeech/decoders/scores/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/deepspeech/decoders/scores/ctc.py b/deepspeech/decoders/scores/ctc.py
new file mode 100644
index 000000000..6e5a4c538
--- /dev/null
+++ b/deepspeech/decoders/scores/ctc.py
@@ -0,0 +1,158 @@
+"""ScorerInterface implementation for CTC."""
+
+import numpy as np
+import paddle
+
+from .ctc_prefix_score import CTCPrefixScore
+from .ctc_prefix_score import CTCPrefixScoreTH
+from .scorer_interface import BatchPartialScorerInterface
+
+
+class CTCPrefixScorer(BatchPartialScorerInterface):
+    """Decoder interface wrapper for CTCPrefixScore."""
+
+    def __init__(self, ctc: paddle.nn.Layer, eos: int):
+        """Initialize class.
+
+        Args:
+            ctc (paddle.nn.Layer): The CTC implementation.
+                For example, :class:`deepspeech.modules.ctc.CTC`
+            eos (int): The end-of-sequence id.
+
+        """
+        self.ctc = ctc
+        self.eos = eos
+        self.impl = None
+
+    def init_state(self, x: paddle.Tensor):
+        """Get an initial state for decoding.
+
+        Args:
+            x (paddle.Tensor): The encoded feature tensor
+
+        Returns: initial state
+
+        """
+        logp = self.ctc.log_softmax(x.unsqueeze(0)).squeeze(0).numpy()
+        # TODO(karita): use CTCPrefixScoreTH
+        self.impl = CTCPrefixScore(logp, 0, self.eos, np)
+        return 0, self.impl.initial_state()
+
+    def select_state(self, state, i, new_id=None):
+        """Select state with relative ids in the main beam search.
+
+        Args:
+            state: Decoder state for prefix tokens
+            i (int): Index to select a state in the main beam search
+            new_id (int): New label id to select a state if necessary
+
+        Returns:
+            state: pruned state
+
+        """
+        if isinstance(state, tuple):
+            if len(state) == 2:  # for CTCPrefixScore
+                sc, st = state
+                return sc[i], st[i]
+            else:  # for CTCPrefixScoreTH (need new_id > 0)
+                r, log_psi, f_min, f_max, scoring_idmap = state
+                s = log_psi[i, new_id].expand(log_psi.shape[1])
+                if scoring_idmap is not None:
+                    return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
+                else:
+                    return r[:, :, i, new_id], s, f_min, f_max
+        return None if state is None else state[i]
+
+    def score_partial(self, y, ids, state, x):
+        """Score new tokens.
+
+        Args:
+            y (paddle.Tensor): 1D prefix token
+            ids (paddle.Tensor): paddle.int64 next tokens to score
+            state: decoder state for prefix tokens
+            x (paddle.Tensor): 2D encoder feature that generates ys
+
+        Returns:
+            tuple[paddle.Tensor, Any]:
+                Tuple of a score tensor for y that has a shape `(len(ids),)`
+                and next state for ys
+
+        """
+        prev_score, state = state
+        presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
+        tscore = paddle.to_tensor(
+            presub_score - prev_score, place=x.place, dtype=x.dtype
+        )
+        return tscore, (presub_score, new_st)
+
+    def batch_init_state(self, x: paddle.Tensor):
+        """Get an initial state for decoding.
+
+        Args:
+            x (paddle.Tensor): The encoded feature tensor
+
+        Returns: initial state
+
+        """
+        logp = self.ctc.log_softmax(x.unsqueeze(0))  # assuming batch_size = 1
+        xlen = paddle.to_tensor([logp.shape[1]])
+        self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
+        return None
+
+    def batch_score_partial(self, y, ids, state, x):
+        """Score new tokens.
+
+        Args:
+            y (paddle.Tensor): 1D prefix token
+            ids (paddle.Tensor): paddle.int64 next tokens to score
+            state: decoder state for prefix tokens
+            x (paddle.Tensor): 2D encoder feature that generates ys
+
+        Returns:
+            tuple[paddle.Tensor, Any]:
+                Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
+                and next state for ys
+
+        """
+        batch_state = (
+            (
+                paddle.stack([s[0] for s in state], axis=2),
+                paddle.stack([s[1] for s in state]),
+                state[0][2],
+                state[0][3],
+            )
+            if state[0] is not None
+            else None
+        )
+        return self.impl(y, batch_state, ids)
+
+    def extend_prob(self, x: paddle.Tensor):
+        """Extend probs for decoding.
+
+        This extension is for streaming decoding
+        as in Eq (14) in https://arxiv.org/abs/2006.14941
+
+        Args:
+            x (paddle.Tensor): The encoded feature tensor
+
+        """
+        logp = self.ctc.log_softmax(x.unsqueeze(0))
+        self.impl.extend_prob(logp)
+
+    def extend_state(self, state):
+        """Extend state for decoding.
+
+        This extension is for streaming decoding
+        as in Eq (14) in https://arxiv.org/abs/2006.14941
+
+        Args:
+            state: The states of hyps
+
+        Returns: extended state
+
+        """
+        new_state = []
+        for s in state:
+            new_state.append(self.impl.extend_state(s))
+
+        return new_state
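The wrapper above follows the `PartialScorerInterface` contract: `init_state` builds the NumPy scorer, `score_partial` returns score deltas for pre-pruned candidate ids, and `select_state` carries a pruned hypothesis state forward. A sketch of one decoding step, where `ctc_model`, `enc_out`, and the token ids are hypothetical placeholders:

```python
import paddle

# `ctc_model` is assumed to expose log_softmax(), e.g. deepspeech.modules.ctc.CTC;
# `enc_out` is an assumed (T, D) encoder output for one utterance.
scorer = CTCPrefixScorer(ctc=ctc_model, eos=eos_id)
state = scorer.init_state(enc_out)      # -> (prev_score=0, initial CTC prefix state)

y = paddle.to_tensor([sos_id])          # current 1D prefix
top_ids = paddle.to_tensor([3, 7, 42])  # candidates pre-pruned by the full scorers
scores, state = scorer.score_partial(y, top_ids, state, enc_out)

# scores[k] is the CTC score delta for appending top_ids[k] to y; after pruning,
# keep hypothesis k alive with scorer.select_state(state, k).
kept = scorer.select_state(state, 0)
```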
+ """ + + def __init__(self, x, xlens, blank, eos, margin=0): + """Construct CTC prefix scorer + + :param torch.Tensor x: input label posterior sequences (B, T, O) + :param torch.Tensor xlens: input lengths (B,) + :param int blank: blank label id + :param int eos: end-of-sequence id + :param int margin: margin parameter for windowing (0 means no windowing) + """ + # In the comment lines, + # we assume T: input_length, B: batch size, W: beam width, O: output dim. + self.logzero = -10000000000.0 + self.blank = blank + self.eos = eos + self.batch = x.size(0) + self.input_length = x.size(1) + self.odim = x.size(2) + self.dtype = x.dtype + self.device = ( + torch.device("cuda:%d" % x.get_device()) + if x.is_cuda + else torch.device("cpu") + ) + # Pad the rest of posteriors in the batch + # TODO(takaaki-hori): need a better way without for-loops + for i, l in enumerate(xlens): + if l < self.input_length: + x[i, l:, :] = self.logzero + x[i, l:, blank] = 0 + # Reshape input x + xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O) + xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim) + self.x = torch.stack([xn, xb]) # (2, T, B, O) + self.end_frames = torch.as_tensor(xlens) - 1 + + # Setup CTC windowing + self.margin = margin + if margin > 0: + self.frame_ids = torch.arange( + self.input_length, dtype=self.dtype, device=self.device + ) + # Base indices for index conversion + self.idx_bh = None + self.idx_b = torch.arange(self.batch, device=self.device) + self.idx_bo = (self.idx_b * self.odim).unsqueeze(1) + + def __call__(self, y, state, scoring_ids=None, att_w=None): + """Compute CTC prefix scores for next labels + + :param list y: prefix label sequences + :param tuple state: previous CTC state + :param torch.Tensor pre_scores: scores for pre-selection of hypotheses (BW, O) + :param torch.Tensor att_w: attention weights to decide CTC window + :return new_state, ctc_local_scores (BW, O) + """ + output_length = len(y[0]) - 1 # ignore sos + last_ids = [yi[-1] for yi in y] # last output label ids + n_bh = len(last_ids) # batch * hyps + n_hyps = n_bh // self.batch # assuming each utterance has the same # of hyps + self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0 + # prepare state info + if state is None: + r_prev = torch.full( + (self.input_length, 2, self.batch, n_hyps), + self.logzero, + dtype=self.dtype, + device=self.device, + ) + r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2) + r_prev = r_prev.view(-1, 2, n_bh) + s_prev = 0.0 + f_min_prev = 0 + f_max_prev = 1 + else: + r_prev, s_prev, f_min_prev, f_max_prev = state + + # select input dimensions for scoring + if self.scoring_num > 0: + scoring_idmap = torch.full( + (n_bh, self.odim), -1, dtype=torch.long, device=self.device + ) + snum = self.scoring_num + if self.idx_bh is None or n_bh > len(self.idx_bh): + self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1) + scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange( + snum, device=self.device + ) + scoring_idx = ( + scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1) + ).view(-1) + x_ = torch.index_select( + self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx + ).view(2, -1, n_bh, snum) + else: + scoring_ids = None + scoring_idmap = None + snum = self.odim + x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum) + + # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor + # that corresponds to r_t^n(h) and r_t^b(h) in a batch. 
+        r = torch.full(
+            (self.input_length, 2, n_bh, snum),
+            self.logzero,
+            dtype=self.dtype,
+            device=self.device,
+        )
+        if output_length == 0:
+            r[0, 0] = x_[0, 0]
+
+        r_sum = torch.logsumexp(r_prev, 1)
+        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
+        if scoring_ids is not None:
+            for idx in range(n_bh):
+                pos = scoring_idmap[idx, last_ids[idx]]
+                if pos >= 0:
+                    log_phi[:, idx, pos] = r_prev[:, 1, idx]
+        else:
+            for idx in range(n_bh):
+                log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]
+
+        # decide start and end frames based on attention weights
+        if att_w is not None and self.margin > 0:
+            f_arg = torch.matmul(att_w, self.frame_ids)
+            f_min = max(int(f_arg.min().cpu()), f_min_prev)
+            f_max = max(int(f_arg.max().cpu()), f_max_prev)
+            start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
+            end = min(f_max + self.margin, self.input_length)
+        else:
+            f_min = f_max = 0
+            start = max(output_length, 1)
+            end = self.input_length
+
+        # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
+        for t in range(start, end):
+            rp = r[t - 1]
+            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
+                2, 2, n_bh, snum
+            )
+            r[t] = torch.logsumexp(rr, 1) + x_[:, t]
+
+        # compute log prefix probabilities log(psi)
+        log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
+        if scoring_ids is not None:
+            log_psi = torch.full(
+                (n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device
+            )
+            log_psi_ = torch.logsumexp(
+                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
+                dim=0,
+            )
+            for si in range(n_bh):
+                log_psi[si, scoring_ids[si]] = log_psi_[si]
+        else:
+            log_psi = torch.logsumexp(
+                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
+                dim=0,
+            )
+
+        for si in range(n_bh):
+            log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]
+
+        # exclude blank probs
+        log_psi[:, self.blank] = self.logzero
+
+        return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)
+
+    def index_select_state(self, state, best_ids):
+        """Select CTC states according to best ids.
+
+        :param state : CTC state
+        :param best_ids : index numbers selected by beam pruning (B, W)
+        :return selected_state
+        """
+        r, s, f_min, f_max, scoring_idmap = state
+        # convert ids to BHO space
+        n_bh = len(s)
+        n_hyps = n_bh // self.batch
+        vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
+        # select hypothesis scores
+        s_new = torch.index_select(s.view(-1), 0, vidx)
+        s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
+        # convert ids to BHS space (S: scoring_num)
+        if scoring_idmap is not None:
+            snum = self.scoring_num
+            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
+                -1
+            )
+            label_ids = torch.fmod(best_ids, self.odim).view(-1)
+            score_idx = scoring_idmap[hyp_idx, label_ids]
+            score_idx[score_idx == -1] = 0
+            vidx = score_idx + hyp_idx * snum
+        else:
+            snum = self.odim
+        # select forward probabilities
+        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
+            -1, 2, n_bh
+        )
+        return r_new, s_new, f_min, f_max
+    def extend_prob(self, x):
+        """Extend CTC prob.
+
+        :param torch.Tensor x: input label posterior sequences (B, T, O)
+        """
+        if self.x.shape[1] < x.shape[1]:  # self.x (2, T, B, O); x (B, T, O)
+            # Pad the rest of posteriors in the batch
+            # TODO(takaaki-hori): need a better way without for-loops
+            xlens = [x.size(1)]
+            for i, l in enumerate(xlens):
+                if l < self.input_length:
+                    x[i, l:, :] = self.logzero
+                    x[i, l:, self.blank] = 0
+            tmp_x = self.x
+            xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
+            xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
+            self.x = torch.stack([xn, xb])  # (2, T, B, O)
+            self.x[:, : tmp_x.shape[1], :, :] = tmp_x
+            self.input_length = x.size(1)
+            self.end_frames = torch.as_tensor(xlens) - 1
+
+    def extend_state(self, state):
+        """Compute CTC prefix state.
+
+        :param state : CTC state
+        :return ctc_state
+        """
+        if state is None:
+            # nothing to do
+            return state
+        else:
+            r_prev, s_prev, f_min_prev, f_max_prev = state
+
+            r_prev_new = torch.full(
+                (self.input_length, 2),
+                self.logzero,
+                dtype=self.dtype,
+                device=self.device,
+            )
+            start = max(r_prev.shape[0], 1)
+            r_prev_new[0:start] = r_prev
+            for t in six.moves.range(start, self.input_length):
+                r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]
+
+            return (r_prev_new, s_prev, f_min_prev, f_max_prev)
+
+
+class CTCPrefixScore:
+    """Compute CTC label sequence scores,
+
+    which is based on Algorithm 2 in WATANABE et al.
+    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
+    but extended to efficiently compute the probabilities of multiple labels
+    simultaneously.
+    """
+
+    def __init__(self, x, blank, eos, xp):
+        """Construct CTC prefix scorer.
+
+        :param np.ndarray x: input label posteriors (T, O)
+        :param int blank: blank label id
+        :param int eos: end-of-sequence id
+        :param module xp: array module (numpy, or a numpy-compatible module)
+        """
+        self.xp = xp
+        self.logzero = -10000000000.0
+        self.blank = blank
+        self.eos = eos
+        self.input_length = len(x)
+        self.x = x
+
+    def initial_state(self):
+        """Obtain an initial CTC state.
+
+        :return: CTC state
+        """
+        # initial CTC state is made of a frame x 2 tensor that corresponds to
+        # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
+        # superscripts n and b (non-blank and blank), respectively.
+        r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
+        r[0, 1] = self.x[0, self.blank]
+        for i in six.moves.range(1, self.input_length):
+            r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
+        return r
+
+    def __call__(self, y, cs, r_prev):
+        """Compute CTC prefix scores for next labels.
+
+        :param y: prefix label sequence
+        :param cs: array of next labels
+        :param r_prev: previous CTC state
+        :return ctc_scores, ctc_states
+        """
+        # initialize CTC states
+        output_length = len(y) - 1  # ignore sos
+        # new CTC states are prepared as a frame x (n or b) x n_labels tensor
+        # that corresponds to r_t^n(h) and r_t^b(h).
+        r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
+        xs = self.x[:, cs]
+        if output_length == 0:
+            r[0, 0] = xs[0]
+            r[0, 1] = self.logzero
+        else:
+            r[output_length - 1] = self.logzero
+
+        # prepare forward probabilities for the last label
+        r_sum = self.xp.logaddexp(
+            r_prev[:, 0], r_prev[:, 1]
+        )  # log(r_t^n(g) + r_t^b(g))
+        last = y[-1]
+        if output_length > 0 and last in cs:
+            log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
+            for i in six.moves.range(len(cs)):
+                log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
+        else:
+            log_phi = r_sum
+
+        # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
+        # and log prefix probabilities log(psi)
+        start = max(output_length, 1)
+        log_psi = r[start - 1, 0]
+        for t in six.moves.range(start, self.input_length):
+            r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
+            r[t, 1] = (
+                self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
+            )
+            log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
+
+        # get P(...eos|X) that ends with the prefix itself
+        eos_pos = self.xp.where(cs == self.eos)[0]
+        if len(eos_pos) > 0:
+            log_psi[eos_pos] = r_sum[-1]  # log(r_T^n(g) + r_T^b(g))
+
+        # exclude blank probs
+        blank_pos = self.xp.where(cs == self.blank)[0]
+        if len(blank_pos) > 0:
+            log_psi[blank_pos] = self.logzero
+
+        # return the log prefix probability and CTC states, where the label axis
+        # of the CTC states is moved to the first axis to slice it easily
+        return log_psi, self.xp.rollaxis(r, 2)
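Both classes implement the same forward recursion from the hybrid CTC/attention paper. For a hypothesis $h = g \oplus c$ (prefix $g$ extended by label $c$) and frame posteriors $x_t[\cdot]$, the non-blank/blank forward variables and the prefix score are, written in the linear domain:

```latex
\begin{aligned}
r_t^{n}(h) &= \bigl(r_{t-1}^{n}(h) + \varphi_{t-1}(g)\bigr)\, x_t[c],
&
r_t^{b}(h) &= \bigl(r_{t-1}^{n}(h) + r_{t-1}^{b}(h)\bigr)\, x_t[\mathrm{blank}],
\\
\varphi_t(g) &=
\begin{cases}
r_t^{b}(g) & \text{if } \mathrm{last}(g) = c\\
r_t^{b}(g) + r_t^{n}(g) & \text{otherwise,}
\end{cases}
&
\psi(h) &\mathrel{+}= \varphi_{t-1}(g)\, x_t[c] \quad \text{for each frame } t.
\end{aligned}
```

The code evaluates all of this in log space: `r[t, 0]` and `r[t, 1]` hold $\log r_t^{n}$ and $\log r_t^{b}$, `log_phi` holds $\log \varphi$, and `log_psi` is seeded with `r[start - 1, 0]` before the per-frame `logaddexp`/`logsumexp` accumulation.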
+ + """ + return ( + paddle.to_tensor([1.0], place=xs.place, dtype=xs.dtype).expand( + ys.shape[0], self.n + ), + None, + ) diff --git a/deepspeech/decoders/scores/ngram.py b/deepspeech/decoders/scores/ngram.py new file mode 100644 index 000000000..9809427bd --- /dev/null +++ b/deepspeech/decoders/scores/ngram.py @@ -0,0 +1,102 @@ +"""Ngram lm implement.""" + +from abc import ABC + +import kenlm +import paddle + +from .scorer_interface import BatchScorerInterface +from .scorer_interface import PartialScorerInterface + + +class Ngrambase(ABC): + """Ngram base implemented through ScorerInterface.""" + + def __init__(self, ngram_model, token_list): + """Initialize Ngrambase. + + Args: + ngram_model: ngram model path + token_list: token list from dict or model.json + + """ + self.chardict = [x if x != "" else "" for x in token_list] + self.charlen = len(self.chardict) + self.lm = kenlm.LanguageModel(ngram_model) + self.tmpkenlmstate = kenlm.State() + + def init_state(self, x): + """Initialize tmp state.""" + state = kenlm.State() + self.lm.NullContextWrite(state) + return state + + def score_partial_(self, y, next_token, state, x): + """Score interface for both full and partial scorer. + + Args: + y: previous char + next_token: next token need to be score + state: previous state + x: encoded feature + + Returns: + tuple[paddle.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. + + """ + out_state = kenlm.State() + ys = self.chardict[y[-1]] if y.shape[0] > 1 else "" + self.lm.BaseScore(state, ys, out_state) + scores = paddle.empty_like(next_token, dtype=x.dtype) + for i, j in enumerate(next_token): + scores[i] = self.lm.BaseScore( + out_state, self.chardict[j], self.tmpkenlmstate + ) + return scores, out_state + + +class NgramFullScorer(Ngrambase, BatchScorerInterface): + """Fullscorer for ngram.""" + + def score(self, y, state, x): + """Score interface for both full and partial scorer. + + Args: + y: previous char + state: previous state + x: encoded feature + + Returns: + tuple[paddle.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. + + """ + return self.score_partial_(y, paddle.to_tensor(range(self.charlen)), state, x) + + +class NgramPartScorer(Ngrambase, PartialScorerInterface): + """Partialscorer for ngram.""" + + def score_partial(self, y, next_token, state, x): + """Score interface for both full and partial scorer. + + Args: + y: previous char + next_token: next token need to be score + state: previous state + x: encoded feature + + Returns: + tuple[paddle.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. + + """ + return self.score_partial_(y, next_token, state, x) + + def select_state(self, state, i): + """Empty select state for scorer interface.""" + return state diff --git a/deepspeech/decoders/scores/score_interface.py b/deepspeech/decoders/scores/score_interface.py new file mode 100644 index 000000000..c52f8d190 --- /dev/null +++ b/deepspeech/decoders/scores/score_interface.py @@ -0,0 +1,188 @@ +"""Scorer interface module.""" + +from typing import Any +from typing import List +from typing import Tuple + +import paddle +import warnings + + +class ScorerInterface: + """Scorer interface for beam search. + + The scorer performs scoring of the all tokens in vocabulary. 
diff --git a/deepspeech/decoders/scores/scorer_interface.py b/deepspeech/decoders/scores/scorer_interface.py
new file mode 100644
index 000000000..c52f8d190
--- /dev/null
+++ b/deepspeech/decoders/scores/scorer_interface.py
@@ -0,0 +1,188 @@
+"""Scorer interface module."""
+
+import warnings
+from typing import Any
+from typing import List
+from typing import Tuple
+
+import paddle
+
+
+class ScorerInterface:
+    """Scorer interface for beam search.
+
+    The scorer performs scoring of all tokens in the vocabulary.
+
+    Examples:
+        * Search heuristics
+            * :class:`scorers.length_bonus.LengthBonus`
+        * Decoder networks of the sequence-to-sequence models
+            * :class:`transformer.decoder.Decoder`
+            * :class:`rnn.decoders.Decoder`
+        * Neural language models
+            * :class:`lm.transformer.TransformerLM`
+            * :class:`lm.default.DefaultRNNLM`
+            * :class:`lm.seq_rnn.SequentialRNNLM`
+
+    """
+
+    def init_state(self, x: paddle.Tensor) -> Any:
+        """Get an initial state for decoding (optional).
+
+        Args:
+            x (paddle.Tensor): The encoded feature tensor
+
+        Returns: initial state
+
+        """
+        return None
+
+    def select_state(self, state: Any, i: int, new_id: int = None) -> Any:
+        """Select state with relative ids in the main beam search.
+
+        Args:
+            state: Decoder state for prefix tokens
+            i (int): Index to select a state in the main beam search
+            new_id (int): New label index to select a state if necessary
+
+        Returns:
+            state: pruned state
+
+        """
+        return None if state is None else state[i]
+
+    def score(
+        self, y: paddle.Tensor, state: Any, x: paddle.Tensor
+    ) -> Tuple[paddle.Tensor, Any]:
+        """Score new token (required).
+
+        Args:
+            y (paddle.Tensor): 1D paddle.int64 prefix tokens.
+            state: Scorer state for prefix tokens
+            x (paddle.Tensor): The encoder feature that generates ys.
+
+        Returns:
+            tuple[paddle.Tensor, Any]: Tuple of
+                scores for next token that has a shape of `(n_vocab)`
+                and next state for ys
+
+        """
+        raise NotImplementedError
+
+    def final_score(self, state: Any) -> float:
+        """Score eos (optional).
+
+        Args:
+            state: Scorer state for prefix tokens
+
+        Returns:
+            float: final score
+
+        """
+        return 0.0
+
+
+class BatchScorerInterface(ScorerInterface):
+    """Batch scorer interface."""
+
+    def batch_init_state(self, x: paddle.Tensor) -> Any:
+        """Get an initial state for decoding (optional).
+
+        Args:
+            x (paddle.Tensor): The encoded feature tensor
+
+        Returns: initial state
+
+        """
+        return self.init_state(x)
+
+    def batch_score(
+        self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor
+    ) -> Tuple[paddle.Tensor, List[Any]]:
+        """Score new token batch (required).
+
+        Args:
+            ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
+            states (List[Any]): Scorer states for prefix tokens.
+            xs (paddle.Tensor):
+                The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+        Returns:
+            tuple[paddle.Tensor, List[Any]]: Tuple of
+                batchified scores for next token with shape of `(n_batch, n_vocab)`
+                and next state list for ys.
+
+        """
+        warnings.warn(
+            "{}.batch_score is implemented through a for-loop, not parallelized".format(
+                self.__class__.__name__
+            )
+        )
+        scores = list()
+        outstates = list()
+        for y, state, x in zip(ys, states, xs):
+            score, outstate = self.score(y, state, x)
+            outstates.append(outstate)
+            scores.append(score)
+        scores = paddle.concat(scores, axis=0).reshape([ys.shape[0], -1])
+        return scores, outstates
+
+
+class PartialScorerInterface(ScorerInterface):
+    """Partial scorer interface for beam search.
+
+    The partial scorer performs scoring after the non-partial (full) scorers
+    have finished, and receives pre-pruned next tokens to score, because
+    scoring all the tokens would be too heavy.
+
+    Examples:
+        * Prefix search for connectionist-temporal-classification models
+            * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`
+
+    """
+
+    def score_partial(
+        self, y: paddle.Tensor, next_tokens: paddle.Tensor, state: Any, x: paddle.Tensor
+    ) -> Tuple[paddle.Tensor, Any]:
+        """Score new token (required).
+
+        Args:
+            y (paddle.Tensor): 1D prefix token
+            next_tokens (paddle.Tensor): paddle.int64 next tokens to score
+            state: decoder state for prefix tokens
+            x (paddle.Tensor): The encoder feature that generates ys
+
+        Returns:
+            tuple[paddle.Tensor, Any]:
+                Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
+                and next state for ys
+
+        """
+        raise NotImplementedError
+
+
+class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface):
+    """Batch partial scorer interface for beam search."""
+
+    def batch_score_partial(
+        self,
+        ys: paddle.Tensor,
+        next_tokens: paddle.Tensor,
+        states: List[Any],
+        xs: paddle.Tensor,
+    ) -> Tuple[paddle.Tensor, Any]:
+        """Score new token (required).
+
+        Args:
+            ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
+            next_tokens (paddle.Tensor): paddle.int64 tokens to score (n_batch, n_token).
+            states (List[Any]): Scorer states for prefix tokens.
+            xs (paddle.Tensor):
+                The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+        Returns:
+            tuple[paddle.Tensor, Any]:
+                Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
+                and next states for ys
+
+        """
+        raise NotImplementedError
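To show how a new scorer plugs into these interfaces, here is a minimal hypothetical example: a stateless repetition penalty. The class name and penalty scheme are our own illustration, not part of this diff:

```python
from typing import Any, Tuple

import paddle

from .scorer_interface import BatchScorerInterface


class RepetitionPenalty(BatchScorerInterface):
    """Penalize tokens that already occur in the prefix (illustrative)."""

    def __init__(self, n_vocab: int, penalty: float = 1.0):
        self.n = n_vocab
        self.penalty = penalty

    def score(self, y: paddle.Tensor, state: Any,
              x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]:
        scores = paddle.zeros([self.n], dtype=x.dtype)
        for token in y.tolist():        # every token already emitted
            scores[token] = -self.penalty
        return scores, None             # stateless: no state to carry
```

Because it subclasses `BatchScorerInterface` without overriding `batch_score`, batching falls back to the per-hypothesis loop above and emits the non-parallelized warning; a production scorer would vectorize `batch_score` instead.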