format code

pull/890/head
Hui Zhang 3 years ago
parent 60e9790610
commit 10a2da6808

@@ -10,4 +10,4 @@
 * [Vectorized Beam Search for CTC-Attention-based Speech Recognition](https://www.isca-speech.org/archive/pdfs/interspeech_2019/seki19b_interspeech.pdf)
 ### Streaming Join CTC/ATT Beam Search
 * [STREAMING TRANSFORMER ASR WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH](https://arxiv.org/abs/2006.14941)

@ -1 +1,14 @@
from .ctcdecoder import swig_wrapper # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .ctcdecoder import swig_wrapper

@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

@@ -1,5 +1,17 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """ScorerInterface implementation for CTC."""
 import numpy as np
 import paddle
@@ -81,8 +93,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface):
         prev_score, state = state
         presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
-        tscore = paddle.to_tensor(
-            presub_score - prev_score, place=x.place, dtype=x.dtype
-        )
+        tscore = paddle.to_tensor(
+            presub_score - prev_score, place=x.place, dtype=x.dtype)
         return tscore, (presub_score, new_st)

     def batch_init_state(self, x: paddle.Tensor):
@@ -115,15 +126,9 @@ class CTCPrefixScorer(BatchPartialScorerInterface):
         """
         batch_state = (
-            (
-                paddle.stack([s[0] for s in state], axis=2),
-                paddle.stack([s[1] for s in state]),
-                state[0][2],
-                state[0][3],
-            )
-            if state[0] is not None
-            else None
-        )
+            (paddle.stack([s[0] for s in state], axis=2),
+             paddle.stack([s[1] for s in state]), state[0][2], state[0][3], )
+            if state[0] is not None else None)
         return self.impl(y, batch_state, ids)

     def extend_prob(self, x: paddle.Tensor):
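Reviewer note on the `tscore` hunk above: the underlying prefix scorer returns cumulative log p(prefix | X), so the per-step contribution fed to beam search has to be the difference against the previous cumulative score. A minimal sketch of that pattern (function and values here are illustrative, not from the diff):

```python
import numpy as np

def incremental_score(cum_prefix_score, prev_score):
    # cum_prefix_score: log p(g + c | X) for each candidate token c
    # prev_score:       log p(g | X) of the current prefix g
    # Beam search sums per-step scores, so return the difference,
    # exactly what `tscore = presub_score - prev_score` does above.
    return cum_prefix_score - prev_score

prev = -4.0                          # log p(g | X)
cum = np.array([-5.1, -4.7, -6.3])   # log p(g + c | X) for 3 candidates
print(incremental_score(cum, prev))  # [-1.1 -0.7 -2.3]
```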

@@ -1,11 +1,8 @@
 #!/usr/bin/env python3
 # Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-import paddle
 import numpy as np
+import paddle
 import six
@@ -49,9 +46,10 @@ class CTCPrefixScorePD():
                 x[i, l:, blank] = 0
         # Reshape input x
         xn = x.transpose([1, 0, 2])  # (B, T, O) -> (T, B, O)
-        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)  # (T,B,O)
+        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1,
+                                                      self.odim)  # (T,B,O)
         self.x = paddle.stack([xn, xb])  # (2, T, B, O)
         self.end_frames = paddle.to_tensor(xlens) - 1  # (B,)
         # Setup CTC windowing
         self.margin = margin
@@ -59,7 +57,7 @@ class CTCPrefixScorePD():
             self.frame_ids = paddle.arange(self.input_length, dtype=self.dtype)
         # Base indices for index conversion
         # B idx, hyp idx. shape (B*W, 1)
         self.idx_bh = None
         # B idx. shape (B,)
         self.idx_b = paddle.arange(self.batch)
         # B idx, O idx. shape (B, 1)
@@ -78,56 +76,59 @@ class CTCPrefixScorePD():
         last_ids = [yi[-1] for yi in y]  # last output label ids
         n_bh = len(last_ids)  # batch * hyps
         n_hyps = n_bh // self.batch  # assuming each utterance has the same # of hyps
-        self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
+        self.scoring_num = scoring_ids.size(
+            -1) if scoring_ids is not None else 0
         # prepare state info
         if state is None:
             r_prev = paddle.full(
                 (self.input_length, 2, self.batch, n_hyps),
                 self.logzero,
-                dtype=self.dtype,
-            )  # (T, 2, B, W)
-            r_prev[:, 1] = paddle.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
+                dtype=self.dtype, )  # (T, 2, B, W)
+            r_prev[:, 1] = paddle.cumsum(self.x[0, :, :, self.blank],
+                                         0).unsqueeze(2)
             r_prev = r_prev.view(-1, 2, n_bh)  # (T, 2, BW)
             s_prev = 0.0  # score
             f_min_prev = 0  # eq. 22-23
             f_max_prev = 1  # eq. 22-23
         else:
             r_prev, s_prev, f_min_prev, f_max_prev = state

         # select input dimensions for scoring
         if self.scoring_num > 0:
             # (BW, O)
-            scoring_idmap = paddle.full((n_bh, self.odim), -1, dtype=paddle.long)
+            scoring_idmap = paddle.full(
+                (n_bh, self.odim), -1, dtype=paddle.long)
             snum = self.scoring_num
             if self.idx_bh is None or n_bh > len(self.idx_bh):
                 self.idx_bh = paddle.arange(n_bh).view(-1, 1)  # (BW, 1)
             scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = paddle.arange(snum)
             scoring_idx = (
-                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)  # (BW,1)
-            ).view(-1)  # (BWO)
+                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1,
+                                                                 1)  # (BW,1)
+            ).view(-1)  # (BWO)
             # x_ shape (2, T, B*W, O)
             x_ = paddle.index_select(
-                self.x.view(2, -1, self.batch * self.odim), scoring_idx, 2
-            ).view(2, -1, n_bh, snum)
+                self.x.view(2, -1, self.batch * self.odim), scoring_idx,
+                2).view(2, -1, n_bh, snum)
         else:
             scoring_ids = None
             scoring_idmap = None
             snum = self.odim
             # x_ shape (2, T, B*W, O)
-            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)
+            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1,
+                                                                     n_bh, snum)

         # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
         # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
         r = paddle.full(
             (self.input_length, 2, n_bh, snum),
             self.logzero,
-            dtype=self.dtype,
-        )
+            dtype=self.dtype, )
         if output_length == 0:
             r[0, 0] = x_[0, 0]
         r_sum = paddle.logsumexp(r_prev, 1)  #(T,BW)
         log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)  # (T, BW, O)
         if scoring_ids is not None:
             for idx in range(n_bh):
                 pos = scoring_idmap[idx, last_ids[idx]]
@@ -152,27 +153,30 @@ class CTCPrefixScorePD():
         # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
         for t in range(start, end):
             rp = r[t - 1]  # (2 x BW x O')
-            rr = paddle.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
-                2, 2, n_bh, snum
-            )  # (2,2,BW,O')
+            rr = paddle.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
+                2, 2, n_bh, snum)  # (2,2,BW,O')
             r[t] = paddle.logsumexp(rr, 1) + x_[:, t]

         # compute log prefix probabilities log(psi)
-        log_phi_x = paddle.concat((log_phi[0].unsqueeze(0), log_phi[:-1]), axis=0) + x_[0]
+        log_phi_x = paddle.concat(
+            (log_phi[0].unsqueeze(0), log_phi[:-1]), axis=0) + x_[0]
         if scoring_ids is not None:
-            log_psi = paddle.full((n_bh, self.odim), self.logzero, dtype=self.dtype)
+            log_psi = paddle.full(
+                (n_bh, self.odim), self.logzero, dtype=self.dtype)
             log_psi_ = paddle.logsumexp(
-                paddle.concat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), axis=0),
-                axis=0,
-            )
+                paddle.concat(
+                    (log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)),
+                    axis=0),
+                axis=0, )
             for si in range(n_bh):
                 log_psi[si, scoring_ids[si]] = log_psi_[si]
         else:
             log_psi = paddle.logsumexp(
-                paddle.concat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), axis=0),
-                axis=0,
-            )
+                paddle.concat(
+                    (log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)),
+                    axis=0),
+                axis=0, )

         for si in range(n_bh):
             log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]
@@ -193,16 +197,16 @@ class CTCPrefixScorePD():
         # convert ids to BHO space
         n_bh = len(s)
         n_hyps = n_bh // self.batch
-        vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
+        vidx = (best_ids + (self.idx_b *
+                            (n_hyps * self.odim)).view(-1, 1)).view(-1)
         # select hypothesis scores
         s_new = paddle.index_select(s.view(-1), vidx, 0)
         s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
         # convert ids to BHS space (S: scoring_num)
         if scoring_idmap is not None:
             snum = self.scoring_num
-            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
-                -1
-            )
+            hyp_idx = (best_ids // self.odim +
+                       (self.idx_b * n_hyps).view(-1, 1)).view(-1)
             label_ids = paddle.fmod(best_ids, self.odim).view(-1)
             score_idx = scoring_idmap[hyp_idx, label_ids]
             score_idx[score_idx == -1] = 0
@@ -211,8 +215,7 @@ class CTCPrefixScorePD():
             snum = self.odim
         # select forward probabilities
         r_new = paddle.index_select(r.view(-1, 2, n_bh * snum), vidx, 2).view(
-            -1, 2, n_bh
-        )
+            -1, 2, n_bh)
         return r_new, s_new, f_min, f_max

     def extend_prob(self, x):
@@ -233,7 +236,7 @@ class CTCPrefixScorePD():
             xn = x.transpose([1, 0, 2])  # (B, T, O) -> (T, B, O)
             xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
             self.x = paddle.stack([xn, xb])  # (2, T, B, O)
-            self.x[:, : tmp_x.shape[1], :, :] = tmp_x
+            self.x[:, :tmp_x.shape[1], :, :] = tmp_x
             self.input_length = x.size(1)
             self.end_frames = paddle.to_tensor(xlens) - 1
@@ -254,12 +257,12 @@ class CTCPrefixScorePD():
             r_prev_new = paddle.full(
                 (self.input_length, 2),
                 self.logzero,
-                dtype=self.dtype,
-            )
+                dtype=self.dtype, )
             start = max(r_prev.shape[0], 1)
             r_prev_new[0:start] = r_prev
             for t in range(start, self.input_length):
-                r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]
+                r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :,
+                                                                 self.blank]

             return (r_prev_new, s_prev, f_min_prev, f_max_prev)
@@ -279,7 +282,7 @@ class CTCPrefixScore():
         self.blank = blank
         self.eos = eos
         self.input_length = len(x)
         self.x = x  # (T, O)

     def initial_state(self):
         """Obtain an initial CTC state
@@ -318,12 +321,12 @@ class CTCPrefixScore():
             r[output_length - 1] = self.logzero

         # prepare forward probabilities for the last label
-        r_sum = self.xp.logaddexp(
-            r_prev[:, 0], r_prev[:, 1]
-        )  # log(r_t^n(g) + r_t^b(g))
+        r_sum = self.xp.logaddexp(r_prev[:, 0],
+                                  r_prev[:, 1])  # log(r_t^n(g) + r_t^b(g))
         last = y[-1]
         if output_length > 0 and last in cs:
-            log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
+            log_phi = self.xp.ndarray(
+                (self.input_length, len(cs)), dtype=np.float32)
             for i in six.moves.range(len(cs)):
                 log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
         else:
@@ -335,9 +338,8 @@ class CTCPrefixScore():
         log_psi = r[start - 1, 0]
         for t in six.moves.range(start, self.input_length):
             r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
-            r[t, 1] = (
-                self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
-            )
+            r[t, 1] = (self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) +
+                       self.x[t, self.blank])
             log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])

         # get P(...eos|X) that ends with the prefix itself
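For anyone tracing the recursion in this hunk: below is a self-contained NumPy sketch of the same forward variables (r_t^n, r_t^b) for scoring a single prefix extension, without the windowing and batching that `CTCPrefixScorePD` adds. It is a simplified reading of the code above, not a drop-in replacement:

```python
import numpy as np

def ctc_prefix_score(x, g_last, r_prev, c, blank=0):
    """Score extending prefix g (with forward state r_prev) by label c.

    x:      (T, O) frame-level log-posteriors
    g_last: last label of g, or None if g is empty
    r_prev: (T, 2) log-probs of g; [:, 0] ends non-blank, [:, 1] ends blank
    Returns log(psi) = log p(g + c + ... | X) and the new (T, 2) state.
    """
    T = x.shape[0]
    logzero = -1e10
    r = np.full((T, 2), logzero)
    # phi_t: mass of g still extendable by c at frame t+1; if c repeats the
    # last label, only blank-terminated paths may precede it (CTC collapse)
    r_sum = np.logaddexp(r_prev[:, 0], r_prev[:, 1])
    log_phi = r_prev[:, 1] if c == g_last else r_sum
    r[0, 0] = x[0, c] if g_last is None else logzero
    log_psi = r[0, 0]
    for t in range(1, T):
        r[t, 0] = np.logaddexp(r[t - 1, 0], log_phi[t - 1]) + x[t, c]
        r[t, 1] = np.logaddexp(r[t - 1, 0], r[t - 1, 1]) + x[t, blank]
        log_psi = np.logaddexp(log_psi, log_phi[t - 1] + x[t, c])
    return log_psi, r
```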

@@ -1,3 +1,16 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Length bonus module."""
 from typing import Any
 from typing import List
@@ -34,11 +47,13 @@ class LengthBonus(BatchScorerInterface):
             and None

         """
-        return paddle.to_tensor([1.0], place=x.place, dtype=x.dtype).expand(self.n), None
+        return paddle.to_tensor(
+            [1.0], place=x.place, dtype=x.dtype).expand(self.n), None

-    def batch_score(
-        self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor
-    ) -> Tuple[paddle.Tensor, List[Any]]:
+    def batch_score(self,
+                    ys: paddle.Tensor,
+                    states: List[Any],
+                    xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
         """Score new token batch.

         Args:
@@ -53,9 +68,5 @@ class LengthBonus(BatchScorerInterface):
             and next state list for ys.

         """
-        return (
-            paddle.to_tensor([1.0], place=xs.place, dtype=xs.dtype).expand(
-                ys.shape[0], self.n
-            ),
-            None,
-        )
+        return (paddle.to_tensor([1.0], place=xs.place, dtype=xs.dtype).expand(
+            ys.shape[0], self.n), None, )
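Context for `LengthBonus`: it returns a constant 1.0 for every candidate token, so whatever weight beam search assigns to it becomes a per-token length reward that counteracts the usual bias toward short hypotheses. A toy combination (the weights and scores are illustrative, not values from this commit):

```python
import paddle

# fabricated per-token log-scores from three partial scorers
decoder_score = paddle.to_tensor([-1.2, -0.7, -2.5])
ctc_score = paddle.to_tensor([-1.0, -1.4, -2.0])
length_bonus = paddle.to_tensor([1.0, 1.0, 1.0])  # what LengthBonus emits

# hypothetical joint weights; the length_bonus weight is the length reward
weights = {"decoder": 0.7, "ctc": 0.3, "length_bonus": 0.1}
combined = (weights["decoder"] * decoder_score + weights["ctc"] * ctc_score +
            weights["length_bonus"] * length_bonus)
print(combined.argmax().item())  # index of the best-scoring token
```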

@@ -1,5 +1,17 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Ngram lm implement."""
 from abc import ABC

 import kenlm
@@ -51,9 +63,8 @@ class Ngrambase(ABC):
         self.lm.BaseScore(state, ys, out_state)
         scores = paddle.empty_like(next_token, dtype=x.dtype)
         for i, j in enumerate(next_token):
-            scores[i] = self.lm.BaseScore(
-                out_state, self.chardict[j], self.tmpkenlmstate
-            )
+            scores[i] = self.lm.BaseScore(out_state, self.chardict[j],
+                                          self.tmpkenlmstate)
         return scores, out_state
@@ -74,7 +85,8 @@ class NgramFullScorer(Ngrambase, BatchScorerInterface):
             and next state list for ys.

         """
-        return self.score_partial_(y, paddle.to_tensor(range(self.charlen)), state, x)
+        return self.score_partial_(
+            y, paddle.to_tensor(range(self.charlen)), state, x)


 class NgramPartScorer(Ngrambase, PartialScorerInterface):
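Background for the `BaseScore` calls in this hunk: kenlm scores one token at a time and threads an explicit state pair, which is why `Ngrambase` keeps `tmpkenlmstate` around. A stand-alone sketch of that API (the model path is hypothetical):

```python
import kenlm

lm = kenlm.Model("lm.arpa")  # hypothetical path to an ARPA/binary LM

state, out_state = kenlm.State(), kenlm.State()
lm.BeginSentenceWrite(state)  # seed with start-of-sentence context

total = 0.0
for word in ["hello", "world"]:
    # BaseScore returns log10 p(word | state) and writes the new context
    total += lm.BaseScore(state, word, out_state)
    state, out_state = out_state, state  # advance the context
print(total)
```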

@@ -1,11 +1,23 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Scorer interface module."""
+import warnings
 from typing import Any
 from typing import List
 from typing import Tuple

 import paddle
-import warnings


 class ScorerInterface:
@@ -37,7 +49,7 @@ class ScorerInterface:
         """
         return None

-    def select_state(self, state: Any, i: int, new_id: int = None) -> Any:
+    def select_state(self, state: Any, i: int, new_id: int=None) -> Any:
         """Select state with relative ids in the main beam search.

         Args:
@@ -51,9 +63,8 @@ class ScorerInterface:
         """
         return None if state is None else state[i]

-    def score(
-        self, y: paddle.Tensor, state: Any, x: paddle.Tensor
-    ) -> Tuple[paddle.Tensor, Any]:
+    def score(self, y: paddle.Tensor, state: Any,
+              x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]:
         """Score new token (required).

         Args:
@@ -96,9 +107,10 @@ class BatchScorerInterface(ScorerInterface):
         """
         return self.init_state(x)

-    def batch_score(
-        self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor
-    ) -> Tuple[paddle.Tensor, List[Any]]:
+    def batch_score(self,
+                    ys: paddle.Tensor,
+                    states: List[Any],
+                    xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
         """Score new token batch (required).

         Args:
@@ -114,10 +126,8 @@ class BatchScorerInterface(ScorerInterface):
         """
         warnings.warn(
-            "{} batch score is implemented through for loop not parallelized".format(
-                self.__class__.__name__
-            )
-        )
+            "{} batch score is implemented through for loop not parallelized".
+            format(self.__class__.__name__))
         scores = list()
         outstates = list()
         for i, (y, state, x) in enumerate(zip(ys, states, xs)):
@@ -141,9 +151,11 @@ class PartialScorerInterface(ScorerInterface):
     """

-    def score_partial(
-        self, y: paddle.Tensor, next_tokens: paddle.Tensor, state: Any, x: paddle.Tensor
-    ) -> Tuple[paddle.Tensor, Any]:
+    def score_partial(self,
+                      y: paddle.Tensor,
+                      next_tokens: paddle.Tensor,
+                      state: Any,
+                      x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]:
         """Score new token (required).

         Args:
@@ -165,12 +177,11 @@ class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface):
     """Batch partial scorer interface for beam search."""

     def batch_score_partial(
             self,
             ys: paddle.Tensor,
             next_tokens: paddle.Tensor,
             states: List[Any],
-            xs: paddle.Tensor,
-    ) -> Tuple[paddle.Tensor, Any]:
+            xs: paddle.Tensor, ) -> Tuple[paddle.Tensor, Any]:
         """Score new token (required).

         Args:
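To make the interface contract above concrete, here is a toy `ScorerInterface`-style implementation and one decoding step; the class and tensors are illustrative only:

```python
from typing import Any, Tuple

import paddle

class UniformScorer:
    """Toy scorer: every vocabulary entry gets the same log-score."""

    def __init__(self, vocab_size: int):
        self.n = vocab_size

    def init_state(self, x: paddle.Tensor) -> Any:
        return None  # stateless scorer

    def score(self, y: paddle.Tensor, state: Any,
              x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]:
        # returns (vocab_size,) next-token scores and the carried-over state
        return paddle.zeros([self.n], dtype=x.dtype), state

scorer = UniformScorer(vocab_size=10)
x = paddle.randn([23, 80])        # fake encoder output (T, D)
y = paddle.to_tensor([1, 5, 2])   # current hypothesis prefix
scores, state = scorer.score(y, scorer.init_state(x), x)
```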

@@ -1,6 +1,20 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 __all__ = ["end_detect"]


 def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
     """End detection.
@@ -20,11 +34,12 @@ def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
     for m in range(M):
         # get ended_hyps with their length is i - m
         hyp_length = i - m
-        hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
+        hyps_same_length = [
+            x for x in ended_hyps if len(x["yseq"]) == hyp_length
+        ]
         if len(hyps_same_length) > 0:
-            best_hyp_same_length = sorted(
-                hyps_same_length, key=lambda x: x["score"], reverse=True
-            )[0]
+            best_hyp_same_length = sorted(
+                hyps_same_length, key=lambda x: x["score"], reverse=True)[0]
             if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
                 count += 1
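Semantics of `end_detect` for reviewers: it returns True once the best hypotheses ending at each of the last M output lengths all score more than |D_end| below the global best ended hypothesis, i.e. letting hypotheses grow longer has stopped helping. A fabricated check (assumes `end_detect` from the hunk above is in scope):

```python
# fabricated ended hypotheses: longer outputs keep getting much worse
ended_hyps = [
    {"yseq": [1, 2, 3], "score": -5.0},
    {"yseq": [1, 2, 3, 4], "score": -30.0},
    {"yseq": [1, 2, 3, 4, 5], "score": -31.0},
    {"yseq": [1, 2, 3, 4, 5, 6], "score": -32.0},
]
# at output length i=6, lengths 6, 5 and 4 (M=3) all score more than
# |D_end| = 10 below the best score (-5.0), so decoding can stop
print(end_detect(ended_hyps, 6))  # -> True
```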

@@ -125,7 +125,7 @@ class CTCDecoderBase(nn.Layer):

 class CTCDecoder(CTCDecoderBase):
-    def __init__(self,*args, **kwargs):
+    def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

         # CTCDecoder LM Score handle
         self._ext_scorer = None
