|
|
@ -31,6 +31,8 @@ from paddle import nn
|
|
|
|
|
|
|
|
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import pad_sequence
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import pad_sequence
|
|
|
|
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import reverse_pad_list
|
|
|
|
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import th_accuracy
|
|
|
|
from paddlespeech.audio.utils.tensor_utils import th_accuracy
|
|
|
|
from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
|
|
|
|
from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
|
|
|
|
from paddlespeech.s2t.frontend.utility import IGNORE_ID
|
|
|
|
from paddlespeech.s2t.frontend.utility import IGNORE_ID
|
|
|
@ -38,6 +40,7 @@ from paddlespeech.s2t.frontend.utility import load_cmvn
|
|
|
|
from paddlespeech.s2t.models.asr_interface import ASRInterface
|
|
|
|
from paddlespeech.s2t.models.asr_interface import ASRInterface
|
|
|
|
from paddlespeech.s2t.modules.cmvn import GlobalCMVN
|
|
|
|
from paddlespeech.s2t.modules.cmvn import GlobalCMVN
|
|
|
|
from paddlespeech.s2t.modules.ctc import CTCDecoderBase
|
|
|
|
from paddlespeech.s2t.modules.ctc import CTCDecoderBase
|
|
|
|
|
|
|
|
from paddlespeech.s2t.modules.decoder import BiTransformerDecoder
|
|
|
|
from paddlespeech.s2t.modules.decoder import TransformerDecoder
|
|
|
|
from paddlespeech.s2t.modules.decoder import TransformerDecoder
|
|
|
|
from paddlespeech.s2t.modules.encoder import ConformerEncoder
|
|
|
|
from paddlespeech.s2t.modules.encoder import ConformerEncoder
|
|
|
|
from paddlespeech.s2t.modules.encoder import TransformerEncoder
|
|
|
|
from paddlespeech.s2t.modules.encoder import TransformerEncoder
|
|
|
@ -69,6 +72,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
ctc: CTCDecoderBase,
|
|
|
|
ctc: CTCDecoderBase,
|
|
|
|
ctc_weight: float=0.5,
|
|
|
|
ctc_weight: float=0.5,
|
|
|
|
ignore_id: int=IGNORE_ID,
|
|
|
|
ignore_id: int=IGNORE_ID,
|
|
|
|
|
|
|
|
reverse_weight: float=0.0,
|
|
|
|
lsm_weight: float=0.0,
|
|
|
|
lsm_weight: float=0.0,
|
|
|
|
length_normalized_loss: bool=False,
|
|
|
|
length_normalized_loss: bool=False,
|
|
|
|
**kwargs):
|
|
|
|
**kwargs):
|
|
|
@ -82,6 +86,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
self.vocab_size = vocab_size
|
|
|
|
self.vocab_size = vocab_size
|
|
|
|
self.ignore_id = ignore_id
|
|
|
|
self.ignore_id = ignore_id
|
|
|
|
self.ctc_weight = ctc_weight
|
|
|
|
self.ctc_weight = ctc_weight
|
|
|
|
|
|
|
|
self.reverse_weight = reverse_weight
|
|
|
|
|
|
|
|
|
|
|
|
self.encoder = encoder
|
|
|
|
self.encoder = encoder
|
|
|
|
self.decoder = decoder
|
|
|
|
self.decoder = decoder
|
|
|
@ -171,12 +176,21 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
self.ignore_id)
|
|
|
|
self.ignore_id)
|
|
|
|
ys_in_lens = ys_pad_lens + 1
|
|
|
|
ys_in_lens = ys_pad_lens + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id))
|
|
|
|
|
|
|
|
r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos,
|
|
|
|
|
|
|
|
self.ignore_id)
|
|
|
|
# 1. Forward decoder
|
|
|
|
# 1. Forward decoder
|
|
|
|
decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
|
|
|
|
decoder_out, r_decoder_out, _ = self.decoder(
|
|
|
|
ys_in_lens)
|
|
|
|
encoder_out, encoder_mask, ys_in_pad, ys_in_lens, r_ys_in_pad,
|
|
|
|
|
|
|
|
self.reverse_weight)
|
|
|
|
|
|
|
|
|
|
|
|
# 2. Compute attention loss
|
|
|
|
# 2. Compute attention loss
|
|
|
|
loss_att = self.criterion_att(decoder_out, ys_out_pad)
|
|
|
|
loss_att = self.criterion_att(decoder_out, ys_out_pad)
|
|
|
|
|
|
|
|
r_loss_att = paddle.to_tensor(0.0)
|
|
|
|
|
|
|
|
if self.reverse_weight > 0.0:
|
|
|
|
|
|
|
|
r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
|
|
|
|
|
|
|
|
loss_att = loss_att * (1 - self.reverse_weight
|
|
|
|
|
|
|
|
) + r_loss_att * self.reverse_weight
|
|
|
|
acc_att = th_accuracy(
|
|
|
|
acc_att = th_accuracy(
|
|
|
|
decoder_out.view(-1, self.vocab_size),
|
|
|
|
decoder_out.view(-1, self.vocab_size),
|
|
|
|
ys_out_pad,
|
|
|
|
ys_out_pad,
|
|
|
@ -359,6 +373,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
# Let's assume B = batch_size
|
|
|
|
# Let's assume B = batch_size
|
|
|
|
# encoder_out: (B, maxlen, encoder_dim)
|
|
|
|
# encoder_out: (B, maxlen, encoder_dim)
|
|
|
|
# encoder_mask: (B, 1, Tmax)
|
|
|
|
# encoder_mask: (B, 1, Tmax)
|
|
|
|
|
|
|
|
|
|
|
|
encoder_out, encoder_mask = self._forward_encoder(
|
|
|
|
encoder_out, encoder_mask = self._forward_encoder(
|
|
|
|
speech, speech_lengths, decoding_chunk_size,
|
|
|
|
speech, speech_lengths, decoding_chunk_size,
|
|
|
|
num_decoding_left_chunks, simulate_streaming)
|
|
|
|
num_decoding_left_chunks, simulate_streaming)
|
|
|
@ -500,7 +515,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
decoding_chunk_size: int=-1,
|
|
|
|
decoding_chunk_size: int=-1,
|
|
|
|
num_decoding_left_chunks: int=-1,
|
|
|
|
num_decoding_left_chunks: int=-1,
|
|
|
|
ctc_weight: float=0.0,
|
|
|
|
ctc_weight: float=0.0,
|
|
|
|
simulate_streaming: bool=False, ) -> List[int]:
|
|
|
|
simulate_streaming: bool=False,
|
|
|
|
|
|
|
|
reverse_weight: float=0.0, ) -> List[int]:
|
|
|
|
""" Apply attention rescoring decoding, CTC prefix beam search
|
|
|
|
""" Apply attention rescoring decoding, CTC prefix beam search
|
|
|
|
is applied first to get nbest, then we resoring the nbest on
|
|
|
|
is applied first to get nbest, then we resoring the nbest on
|
|
|
|
attention decoder with corresponding encoder out
|
|
|
|
attention decoder with corresponding encoder out
|
|
|
@ -520,6 +536,9 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
assert speech.shape[0] == speech_lengths.shape[0]
|
|
|
|
assert speech.shape[0] == speech_lengths.shape[0]
|
|
|
|
assert decoding_chunk_size != 0
|
|
|
|
assert decoding_chunk_size != 0
|
|
|
|
|
|
|
|
if reverse_weight > 0.0:
|
|
|
|
|
|
|
|
# decoder should be a bitransformer decoder if reverse_weight > 0.0
|
|
|
|
|
|
|
|
assert hasattr(self.decoder, 'right_decoder')
|
|
|
|
device = speech.place
|
|
|
|
device = speech.place
|
|
|
|
batch_size = speech.shape[0]
|
|
|
|
batch_size = speech.shape[0]
|
|
|
|
# For attention rescoring we only support batch_size=1
|
|
|
|
# For attention rescoring we only support batch_size=1
|
|
|
@ -541,6 +560,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
hyp_content, place=device, dtype=paddle.long)
|
|
|
|
hyp_content, place=device, dtype=paddle.long)
|
|
|
|
hyp_list.append(hyp_content)
|
|
|
|
hyp_list.append(hyp_content)
|
|
|
|
hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
|
|
|
|
hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
|
|
|
|
|
|
|
|
ori_hyps_pad = hyps_pad
|
|
|
|
hyps_lens = paddle.to_tensor(
|
|
|
|
hyps_lens = paddle.to_tensor(
|
|
|
|
[len(hyp[0]) for hyp in hyps], place=device,
|
|
|
|
[len(hyp[0]) for hyp in hyps], place=device,
|
|
|
|
dtype=paddle.long) # (beam_size,)
|
|
|
|
dtype=paddle.long) # (beam_size,)
|
|
|
@ -550,8 +570,14 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
hyps_lens = hyps_lens + 1 # Add <sos> at begining
|
|
|
|
hyps_lens = hyps_lens + 1 # Add <sos> at begining
|
|
|
|
|
|
|
|
|
|
|
|
# ctc score in ln domain
|
|
|
|
# ctc score in ln domain
|
|
|
|
decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens,
|
|
|
|
# (beam_size, max_hyps_len, vocab_size)
|
|
|
|
encoder_out)
|
|
|
|
decoder_out, r_decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens,
|
|
|
|
|
|
|
|
encoder_out,reverse_weight )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
|
|
|
|
|
|
|
|
# conventional transformer decoder.
|
|
|
|
|
|
|
|
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
|
|
|
|
|
|
|
|
r_decoder_out = r_decoder_out.numpy()
|
|
|
|
|
|
|
|
|
|
|
|
# Only use decoder score for rescoring
|
|
|
|
# Only use decoder score for rescoring
|
|
|
|
best_score = -float('inf')
|
|
|
|
best_score = -float('inf')
|
|
|
@ -563,10 +589,18 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
score += decoder_out[i][j][w]
|
|
|
|
score += decoder_out[i][j][w]
|
|
|
|
# last decoder output token is `eos`, for laste decoder input token.
|
|
|
|
# last decoder output token is `eos`, for laste decoder input token.
|
|
|
|
score += decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
score += decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug(
|
|
|
|
logger.debug(
|
|
|
|
f"hyp {i} len {len(hyp[0])} l2r rescore_score: {score} ctc_score: {hyp[1]}"
|
|
|
|
f"hyp {i} len {len(hyp[0])} l2r rescore_score: {score} ctc_score: {hyp[1]}"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if reverse_weight > 0:
|
|
|
|
|
|
|
|
r_score = 0.0
|
|
|
|
|
|
|
|
for j, w in enumerate(hyp[0]):
|
|
|
|
|
|
|
|
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
|
|
|
|
|
|
|
|
r_score += r_decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
|
|
|
|
score = score * (1 - reverse_weight) + r_score * reverse_weight
|
|
|
|
|
|
|
|
|
|
|
|
# add ctc score (which in ln domain)
|
|
|
|
# add ctc score (which in ln domain)
|
|
|
|
score += hyp[1] * ctc_weight
|
|
|
|
score += hyp[1] * ctc_weight
|
|
|
|
if score > best_score:
|
|
|
|
if score > best_score:
|
|
|
@ -601,6 +635,17 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
return self.eos
|
|
|
|
return self.eos
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@jit.to_static(property=True)
|
|
|
|
|
|
|
|
def is_bidirectional_decoder(self) -> bool:
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
paddle.Tensor: decoder output
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
if hasattr(self.decoder, 'right_decoder'):
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
# @jit.to_static
|
|
|
|
# @jit.to_static
|
|
|
|
def forward_encoder_chunk(
|
|
|
|
def forward_encoder_chunk(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
@ -660,7 +705,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
hyps: paddle.Tensor,
|
|
|
|
hyps: paddle.Tensor,
|
|
|
|
hyps_lens: paddle.Tensor,
|
|
|
|
hyps_lens: paddle.Tensor,
|
|
|
|
encoder_out: paddle.Tensor, ) -> paddle.Tensor:
|
|
|
|
encoder_out: paddle.Tensor,
|
|
|
|
|
|
|
|
reverse_weight: float=0.0, ) -> paddle.Tensor:
|
|
|
|
""" Export interface for c++ call, forward decoder with multiple
|
|
|
|
""" Export interface for c++ call, forward decoder with multiple
|
|
|
|
hypothesis from ctc prefix beam search and one encoder output
|
|
|
|
hypothesis from ctc prefix beam search and one encoder output
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
@ -678,11 +724,22 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
# (B, 1, T)
|
|
|
|
# (B, 1, T)
|
|
|
|
encoder_mask = paddle.ones(
|
|
|
|
encoder_mask = paddle.ones(
|
|
|
|
[num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
|
|
|
|
[num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# input for right to left decoder
|
|
|
|
|
|
|
|
# this hyps_lens has count <sos> token, we need minus it.
|
|
|
|
|
|
|
|
r_hyps_lens = hyps_lens - 1
|
|
|
|
|
|
|
|
# this hyps has included <sos> token, so it should be
|
|
|
|
|
|
|
|
# convert the original hyps.
|
|
|
|
|
|
|
|
r_hyps = hyps[:, 1:]
|
|
|
|
# (num_hyps, max_hyps_len, vocab_size)
|
|
|
|
# (num_hyps, max_hyps_len, vocab_size)
|
|
|
|
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
|
|
|
|
|
|
|
|
hyps_lens)
|
|
|
|
r_hyps = st_reverse_pad_list(r_hyps, r_hyps_lens, self.sos, self.eos)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
decoder_out, r_decoder_out, _ = self.decoder(
|
|
|
|
|
|
|
|
encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight)
|
|
|
|
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
|
|
|
|
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
|
|
|
|
return decoder_out
|
|
|
|
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
|
|
|
|
|
|
|
|
return decoder_out, r_decoder_out
|
|
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
@paddle.no_grad()
|
|
|
|
def decode(self,
|
|
|
|
def decode(self,
|
|
|
@ -694,7 +751,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
ctc_weight: float=0.0,
|
|
|
|
ctc_weight: float=0.0,
|
|
|
|
decoding_chunk_size: int=-1,
|
|
|
|
decoding_chunk_size: int=-1,
|
|
|
|
num_decoding_left_chunks: int=-1,
|
|
|
|
num_decoding_left_chunks: int=-1,
|
|
|
|
simulate_streaming: bool=False):
|
|
|
|
simulate_streaming: bool=False,
|
|
|
|
|
|
|
|
reverse_weight: float=0.0):
|
|
|
|
"""u2 decoding.
|
|
|
|
"""u2 decoding.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
@ -766,7 +824,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
decoding_chunk_size=decoding_chunk_size,
|
|
|
|
decoding_chunk_size=decoding_chunk_size,
|
|
|
|
num_decoding_left_chunks=num_decoding_left_chunks,
|
|
|
|
num_decoding_left_chunks=num_decoding_left_chunks,
|
|
|
|
ctc_weight=ctc_weight,
|
|
|
|
ctc_weight=ctc_weight,
|
|
|
|
simulate_streaming=simulate_streaming)
|
|
|
|
simulate_streaming=simulate_streaming,
|
|
|
|
|
|
|
|
reverse_weight=reverse_weight)
|
|
|
|
hyps = [hyp]
|
|
|
|
hyps = [hyp]
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
raise ValueError(f"Not support decoding method: {decoding_method}")
|
|
|
|
raise ValueError(f"Not support decoding method: {decoding_method}")
|
|
|
@ -803,7 +862,6 @@ class U2Model(U2DecodeModel):
|
|
|
|
with DefaultInitializerContext(init_type):
|
|
|
|
with DefaultInitializerContext(init_type):
|
|
|
|
vocab_size, encoder, decoder, ctc = U2Model._init_from_config(
|
|
|
|
vocab_size, encoder, decoder, ctc = U2Model._init_from_config(
|
|
|
|
configs)
|
|
|
|
configs)
|
|
|
|
|
|
|
|
|
|
|
|
super().__init__(
|
|
|
|
super().__init__(
|
|
|
|
vocab_size=vocab_size,
|
|
|
|
vocab_size=vocab_size,
|
|
|
|
encoder=encoder,
|
|
|
|
encoder=encoder,
|
|
|
@ -853,10 +911,20 @@ class U2Model(U2DecodeModel):
|
|
|
|
raise ValueError(f"not support encoder type:{encoder_type}")
|
|
|
|
raise ValueError(f"not support encoder type:{encoder_type}")
|
|
|
|
|
|
|
|
|
|
|
|
# decoder
|
|
|
|
# decoder
|
|
|
|
|
|
|
|
decoder_type = configs.get('decoder', 'transformer')
|
|
|
|
|
|
|
|
logger.debug(f"U2 Decoder type: {decoder_type}")
|
|
|
|
|
|
|
|
if decoder_type == 'transformer':
|
|
|
|
decoder = TransformerDecoder(vocab_size,
|
|
|
|
decoder = TransformerDecoder(vocab_size,
|
|
|
|
encoder.output_size(),
|
|
|
|
encoder.output_size(),
|
|
|
|
**configs['decoder_conf'])
|
|
|
|
**configs['decoder_conf'])
|
|
|
|
|
|
|
|
elif decoder_type == 'bitransformer':
|
|
|
|
|
|
|
|
assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0
|
|
|
|
|
|
|
|
assert configs['decoder_conf']['r_num_blocks'] > 0
|
|
|
|
|
|
|
|
decoder = BiTransformerDecoder(vocab_size,
|
|
|
|
|
|
|
|
encoder.output_size(),
|
|
|
|
|
|
|
|
**configs['decoder_conf'])
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise ValueError(f"not support decoder type:{decoder_type}")
|
|
|
|
# ctc decoder and ctc loss
|
|
|
|
# ctc decoder and ctc loss
|
|
|
|
model_conf = configs.get('model_conf', dict())
|
|
|
|
model_conf = configs.get('model_conf', dict())
|
|
|
|
dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
|
|
|
|
dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
|
|
|
|