You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/paddlespeech/s2t/models/u2_st/u2_st.py

686 lines
26 KiB

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""U2 ASR Model
Speech translation (ST) variant of the Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
(https://arxiv.org/pdf/2012.05481.pdf)
"""
import time
from typing import Dict
from typing import Optional
from typing import Tuple
import paddle
from paddle import jit
from paddle import nn
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.modules.cmvn import GlobalCMVN
from paddlespeech.s2t.modules.ctc import CTCDecoderBase
from paddlespeech.s2t.modules.decoder import TransformerDecoder
from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder
from paddlespeech.s2t.modules.loss import LabelSmoothingLoss
from paddlespeech.s2t.modules.mask import subsequent_mask
from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ["U2STModel", "U2STInferModel"]
logger = Log(__name__).getlog()
class U2STBaseModel(nn.Layer):
    """Speech translation (ST) encoder-decoder model with optional ASR branches.

    One encoder feeds an ST attention decoder. When ``asr_weight > 0`` the
    training loss additionally mixes in an ASR attention-decoder loss and/or
    an ASR CTC loss computed on the same encoder output.
    """

    def __init__(self,
                 vocab_size: int,
                 encoder: TransformerEncoder,
                 st_decoder: TransformerDecoder,
                 decoder: Optional[TransformerDecoder] = None,
                 ctc: Optional[CTCDecoderBase] = None,
                 ctc_weight: float = 0.0,
                 asr_weight: float = 0.0,
                 ignore_id: int = IGNORE_ID,
                 lsm_weight: float = 0.0,
                 length_normalized_loss: bool = False,
                 **kwargs):
        """Store the sub-modules and build the label-smoothing criterion.

        Args:
            vocab_size (int): output vocabulary size; the last id
                (``vocab_size - 1``) is used as both sos and eos.
            encoder: shared speech encoder.
            st_decoder: attention decoder used for the translation loss and
                by ``translate``.
            decoder: ASR attention decoder, used by the ASR attention loss
                and by ``forward_attention_decoder``. May be None.
            ctc: ASR CTC module, used by the ASR CTC loss and by
                ``ctc_activation``. May be None.
            ctc_weight (float): weight of the CTC loss inside the ASR loss,
                must lie in [0, 1].
            asr_weight (float): weight of the ASR loss inside the total loss.
            ignore_id (int): label id treated as padding by the loss.
            lsm_weight (float): label-smoothing factor.
            length_normalized_loss (bool): forwarded to the criterion as
                ``normalize_length``.
            **kwargs: accepted and ignored.
        """
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.asr_weight = asr_weight

        self.encoder = encoder
        self.st_decoder = st_decoder
        self.decoder = decoder
        self.ctc = ctc
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss, )

    def forward(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            asr_text: Optional[paddle.Tensor] = None,
            asr_text_lengths: Optional[paddle.Tensor] = None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, Optional[paddle.Tensor],
               Optional[paddle.Tensor]]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length), translation targets
            text_lengths: (Batch,)
            asr_text: ASR transcript targets, required when asr_weight > 0
            asr_text_lengths: lengths of asr_text, required when
                asr_weight > 0
        Raises:
            ValueError: when asr_weight > 0 but the ASR targets are missing.
        Returns:
            total_loss, st_loss, asr_attention_loss, asr_ctc_loss.
            The two ASR losses are None when their branch is disabled.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)

        # 1. Encoder
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)  # [B, 1, T] -> [B]

        # 2a. ST-decoder branch (the accuracy is not part of the result)
        loss_st, _ = self._calc_st_loss(encoder_out, encoder_mask, text,
                                        text_lengths)

        loss_asr_att = None
        loss_asr_ctc = None
        if self.asr_weight > 0.:
            if asr_text is None or asr_text_lengths is None:
                raise ValueError(
                    "asr_text and asr_text_lengths are required when "
                    "asr_weight > 0")

            # 2b. ASR Attention-decoder branch
            if self.ctc_weight != 1.0:
                loss_asr_att, _ = self._calc_att_loss(
                    encoder_out, encoder_mask, asr_text, asr_text_lengths)

            # 2c. CTC branch
            if self.ctc_weight != 0.0:
                loss_asr_ctc = self.ctc(encoder_out, encoder_out_lens,
                                        asr_text, asr_text_lengths)

            # Combine whichever ASR losses were computed.
            if loss_asr_ctc is None:
                loss_asr = loss_asr_att
            elif loss_asr_att is None:
                loss_asr = loss_asr_ctc
            else:
                loss_asr = (self.ctc_weight * loss_asr_ctc +
                            (1 - self.ctc_weight) * loss_asr_att)
            loss = self.asr_weight * loss_asr + (1 - self.asr_weight) * loss_st
        else:
            loss = loss_st
        return loss, loss_st, loss_asr_att, loss_asr_ctc

    def _calc_decoder_loss(
            self,
            decoder: TransformerDecoder,
            encoder_out: paddle.Tensor,
            encoder_mask: paddle.Tensor,
            ys_pad: paddle.Tensor,
            ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
        """Run `decoder` on sos/eos-augmented targets and score its output.

        Shared implementation behind `_calc_st_loss` and `_calc_att_loss`.

        Args:
            decoder (TransformerDecoder): decoder to run.
            encoder_out (paddle.Tensor): [B, Tmax, D]
            encoder_mask (paddle.Tensor): [B, 1, Tmax]
            ys_pad (paddle.Tensor): [B, Umax]
            ys_pad_lens (paddle.Tensor): [B]
        Returns:
            Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        # the decoder input is one token longer because of the leading sos
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = decoder(encoder_out, encoder_mask, ys_in_pad,
                                 ys_in_lens)

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id, )
        return loss_att, acc_att

    def _calc_st_loss(
            self,
            encoder_out: paddle.Tensor,
            encoder_mask: paddle.Tensor,
            ys_pad: paddle.Tensor,
            ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
        """Calc attention loss of the ST decoder.

        Args:
            encoder_out (paddle.Tensor): [B, Tmax, D]
            encoder_mask (paddle.Tensor): [B, 1, Tmax]
            ys_pad (paddle.Tensor): [B, Umax]
            ys_pad_lens (paddle.Tensor): [B]
        Returns:
            Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
        """
        return self._calc_decoder_loss(self.st_decoder, encoder_out,
                                       encoder_mask, ys_pad, ys_pad_lens)

    def _calc_att_loss(
            self,
            encoder_out: paddle.Tensor,
            encoder_mask: paddle.Tensor,
            ys_pad: paddle.Tensor,
            ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
        """Calc attention loss of the ASR decoder.

        Args:
            encoder_out (paddle.Tensor): [B, Tmax, D]
            encoder_mask (paddle.Tensor): [B, 1, Tmax]
            ys_pad (paddle.Tensor): [B, Umax]
            ys_pad_lens (paddle.Tensor): [B]
        Returns:
            Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
        """
        return self._calc_decoder_loss(self.decoder, encoder_out,
                                       encoder_mask, ys_pad, ys_pad_lens)

    def _forward_encoder(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            decoding_chunk_size: int = -1,
            num_decoding_left_chunks: int = -1,
            simulate_streaming: bool = False,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Encoder pass.

        Args:
            speech (paddle.Tensor): [B, Tmax, D]
            speech_lengths (paddle.Tensor): [B]
            decoding_chunk_size (int, optional): chunk size. Defaults to -1.
            num_decoding_left_chunks (int, optional): nums chunks. Defaults to -1.
            simulate_streaming (bool, optional): streaming or not. Defaults to False.
        Returns:
            Tuple[paddle.Tensor, paddle.Tensor]:
                encoder hiddens (B, Tmax, D),
                encoder hiddens mask (B, 1, Tmax).
        """
        # Let's assume B = batch_size
        # 1. Encoder
        if simulate_streaming and decoding_chunk_size > 0:
            # chunk-by-chunk pass; speech_lengths is not consumed here
            encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk(
                speech,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        else:
            encoder_out, encoder_mask = self.encoder(
                speech,
                speech_lengths,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        return encoder_out, encoder_mask

    def translate(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            beam_size: int = 10,
            word_reward: float = 0.0,
            maxlenratio: float = 0.5,
            decoding_chunk_size: int = -1,
            num_decoding_left_chunks: int = -1,
            simulate_streaming: bool = False, ) -> paddle.Tensor:
        """ Apply beam search on attention decoder with length penalty

        Only a batch of exactly one utterance is accepted.

        Args:
            speech (paddle.Tensor): (batch, max_len, feat_dim)
            speech_lengths (paddle.Tensor): (batch, )
            beam_size (int): beam size for beam search
            word_reward (float): word reward used in beam search
            maxlenratio (float): max length ratio to bound the length of
                translated text; at least 5 decoding steps are always allowed
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            num_decoding_left_chunks (int): number of left chunks forwarded
                to the encoder
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
        Returns:
            paddle.Tensor: decoding result, (batch, max_result_len); the
                leading sos is stripped from the best hypothesis.
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        assert speech.shape[0] == 1
        device = speech.place

        # Let's assume B = batch_size and N = beam_size
        # 1. Encoder and init hypothesis
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)

        maxlen = max(int(encoder_out.shape[1] * maxlenratio), 5)

        hyps = [{"score": 0.0, "yseq": [self.sos], "cache": None}]
        ended_hyps = []
        cur_best_score = -float("inf")
        cache = None

        # 2. Decoder forward step by step
        for i in range(1, maxlen + 1):
            num_hyps = len(hyps)
            # Every live hypothesis carries a cache after the first step.
            has_cache = hyps[0]["cache"] is not None

            ys = paddle.ones((num_hyps, i), dtype=paddle.long)
            if has_cache:
                # Batched per-layer cache buffers, filled per hypothesis below.
                cache = [
                    paddle.ones(
                        (num_hyps, i - 1, hyp_cache.shape[-1]),
                        dtype=paddle.float32)
                    for hyp_cache in hyps[0]["cache"]
                ]
            for j, hyp in enumerate(hyps):
                ys[j, :] = paddle.to_tensor(hyp["yseq"])
                if has_cache:
                    for k, layer_cache in enumerate(cache):
                        layer_cache[j] = hyp["cache"][k]
            ys_mask = subsequent_mask(i).unsqueeze(0).to(device)

            logp, cache = self.st_decoder.forward_one_step(
                encoder_out.repeat(num_hyps, 1, 1),
                encoder_mask.repeat(num_hyps, 1, 1), ys, ys_mask, cache)

            hyps_best_kept = []
            for j, hyp in enumerate(hyps):
                top_k_logp, top_k_index = logp[j:j + 1].topk(beam_size)

                for b in range(beam_size):
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append({
                        "score": hyp["score"] + float(top_k_logp[0, b]),
                        "yseq": hyp["yseq"] + [int(top_k_index[0, b])],
                        "cache": [cache_[j] for cache_ in cache],
                    })

                # sort and get nbest
                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: -x["score"])[:beam_size]

            hyps = hyps_best_kept
            if i == maxlen:
                # Out of steps: force every surviving hypothesis to end.
                for hyp in hyps:
                    hyp["yseq"].append(self.eos)

            # finalize the ended hypotheses with word reward (by length)
            remained_hyps = []
            for hyp in hyps:
                if hyp["yseq"][-1] == self.eos:
                    hyp["score"] += (i - 1) * word_reward
                    cur_best_score = max(cur_best_score, hyp["score"])
                    ended_hyps.append(hyp)
                elif hyp["score"] + maxlen * word_reward > cur_best_score:
                    # stop while guarantee the optimality
                    remained_hyps.append(hyp)

            # stop prediction when there is no unended hypothesis
            if not remained_hyps:
                break
            hyps = remained_hyps

        # 3. Select best of best
        best_hyp = max(ended_hyps, key=lambda x: x["score"])

        return paddle.to_tensor([best_hyp["yseq"][1:]])

    # @jit.to_static
    def subsampling_rate(self) -> int:
        """ Export interface for c++ call, return subsampling_rate of the
            model
        """
        return self.encoder.embed.subsampling_rate

    # @jit.to_static
    def right_context(self) -> int:
        """ Export interface for c++ call, return right_context of the model
        """
        return self.encoder.embed.right_context

    # @jit.to_static
    def sos_symbol(self) -> int:
        """ Export interface for c++ call, return sos symbol id of the model
        """
        return self.sos

    # @jit.to_static
    def eos_symbol(self) -> int:
        """ Export interface for c++ call, return eos symbol id of the model
        """
        return self.eos

    @jit.to_static
    def forward_encoder_chunk(
            self,
            xs: paddle.Tensor,
            offset: int,
            required_cache_size: int,
            att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
            cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """ Export interface for c++ call, give input chunk xs, and return
            output from time 0 to current chunk.

        Thin wrapper that delegates to ``self.encoder.forward_chunk``.

        Args:
            xs (paddle.Tensor): chunk input, with shape (b=1, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate +
                        subsample.right_context + 1`
            offset (int): current offset in encoder output time stamp
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            att_cache (paddle.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
                `d_k * 2` for att key & value.
            cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer,
                (elayers, b=1, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
        Returns:
            paddle.Tensor: output of current input xs,
                with shape (b=1, chunk_size, hidden-dim).
            paddle.Tensor: new attention cache required for next chunk, with
                dynamic shape (elayers, head, T(?), d_k * 2)
                depending on required_cache_size.
            paddle.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.
        """
        return self.encoder.forward_chunk(xs, offset, required_cache_size,
                                          att_cache, cnn_cache)

    # @jit.to_static
    def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
        """ Export interface for c++ call, apply linear transform and log
            softmax before ctc

        Args:
            xs (paddle.Tensor): encoder output
        Returns:
            paddle.Tensor: activation before ctc
        """
        return self.ctc.log_softmax(xs)

    @jit.to_static
    def forward_attention_decoder(
            self,
            hyps: paddle.Tensor,
            hyps_lens: paddle.Tensor,
            encoder_out: paddle.Tensor, ) -> paddle.Tensor:
        """ Export interface for c++ call, forward decoder with multiple
            hypothesis from ctc prefix beam search and one encoder output

        Uses the ASR decoder (``self.decoder``), not the ST decoder.

        Args:
            hyps (paddle.Tensor): hyps from ctc prefix beam search, already
                pad sos at the beginning, (B, T)
            hyps_lens (paddle.Tensor): length of each hyp in hyps, (B)
            encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D)
        Returns:
            paddle.Tensor: log-probabilities from the decoder,
                (num_hyps, max_hyps_len, vocab_size)
        """
        assert encoder_out.shape[0] == 1
        num_hyps = hyps.shape[0]
        assert hyps_lens.shape[0] == num_hyps
        # Share the single encoder output across all hypotheses.
        encoder_out = encoder_out.repeat(num_hyps, 1, 1)
        # (B, 1, T)
        encoder_mask = paddle.ones(
            [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
        # (num_hyps, max_hyps_len, vocab_size)
        decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
                                      hyps_lens)
        # Paddle names the reduction axis `axis`; `dim` is the torch spelling.
        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
        return decoder_out

    @paddle.no_grad()
    def decode(self,
               feats: paddle.Tensor,
               feats_lengths: paddle.Tensor,
               text_feature: Dict[str, int],
               decoding_method: str,
               beam_size: int,
               word_reward: float = 0.0,
               maxlenratio: float = 0.5,
               decoding_chunk_size: int = -1,
               num_decoding_left_chunks: int = -1,
               simulate_streaming: bool = False):
        """u2 decoding.

        Args:
            feats (Tensor): audio features, (B, T, D)
            feats_lengths (Tensor): (B)
            text_feature (TextFeaturizer): text feature object; its
                `defeaturize` method is applied to every hypothesis.
            decoding_method (str): decoding mode; only 'fullsentence' is
                implemented.
            beam_size (int): beam size for search
            word_reward (float, optional): word reward used in beam search.
                Defaults to 0.0.
            maxlenratio (float, optional): max length ratio bounding the
                translated text. Defaults to 0.5.
            decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
                    <0: for decoding, use full chunk.
                    >0: for decoding, use fixed chunk size as set.
                    0: used for training, it's prohibited here.
            num_decoding_left_chunks (int, optional):
                    number of left chunks for decoding. Defaults to -1.
            simulate_streaming (bool, optional): simulate streaming inference. Defaults to False.

        Raises:
            ValueError: when not support decoding_method.

        Returns:
            list: one `text_feature.defeaturize` result per hypothesis row.
        """
        if decoding_method != 'fullsentence':
            raise ValueError(f"Not support decoding method: {decoding_method}")

        hyps = self.translate(
            feats,
            feats_lengths,
            beam_size=beam_size,
            word_reward=word_reward,
            maxlenratio=maxlenratio,
            decoding_chunk_size=decoding_chunk_size,
            num_decoding_left_chunks=num_decoding_left_chunks,
            simulate_streaming=simulate_streaming)
        # Convert each row of the result tensor into a plain list of ids.
        hyps = [hyp.tolist() for hyp in hyps]
        return [text_feature.defeaturize(hyp) for hyp in hyps]
class U2STModel(U2STBaseModel):
    """U2 speech translation model whose sub-modules are built from a config."""

    def __init__(self, configs: dict):
        """Build the sub-modules from `configs` and initialise the base model.

        Args:
            configs (dict): config dict; `configs['model_conf']` is forwarded
                to the base class as keyword arguments.
        """
        vocab_size, encoder, decoder = U2STModel._init_from_config(configs)

        if isinstance(decoder, tuple):
            # ASR joint training: ST decoder + ASR decoder + CTC module.
            st_decoder, asr_decoder, ctc = decoder
            super().__init__(
                vocab_size=vocab_size,
                encoder=encoder,
                st_decoder=st_decoder,
                decoder=asr_decoder,
                ctc=ctc,
                **configs['model_conf'])
        else:
            # ST only: the third element is the ST decoder itself.
            super().__init__(
                vocab_size=vocab_size,
                encoder=encoder,
                st_decoder=decoder,
                **configs['model_conf'])

    @classmethod
    def _init_from_config(cls, configs: dict):
        """init sub module for model.

        Args:
            configs (dict): config dict.

        Raises:
            ValueError: raise when using not support encoder type.

        Returns:
            Tuple: (vocab size, encoder, decoders). `decoders` is the tuple
                (st_decoder, asr_decoder, ctc) when asr_weight > 0, otherwise
                just the st_decoder.
        """
        if configs['cmvn_file'] is not None:
            mean, istd = load_cmvn(configs['cmvn_file'],
                                   configs['cmvn_file_type'])
            global_cmvn = GlobalCMVN(
                paddle.to_tensor(mean, dtype=paddle.float),
                paddle.to_tensor(istd, dtype=paddle.float))
        else:
            global_cmvn = None

        input_dim = configs['input_dim']
        vocab_size = configs['output_dim']
        assert input_dim != 0, input_dim
        assert vocab_size != 0, vocab_size

        encoder_type = configs.get('encoder', 'transformer')
        logger.info(f"U2 Encoder type: {encoder_type}")
        if encoder_type == 'transformer':
            encoder = TransformerEncoder(
                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
        elif encoder_type == 'conformer':
            encoder = ConformerEncoder(
                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
        else:
            raise ValueError(f"not support encoder type:{encoder_type}")

        st_decoder = TransformerDecoder(vocab_size,
                                        encoder.output_size(),
                                        **configs['decoder_conf'])

        model_conf = configs['model_conf']
        # Same default as the base class, so ST-only configs may omit the key.
        asr_weight = model_conf.get('asr_weight', 0.0)
        logger.info(f"ASR Joint Training Weight: {asr_weight}")

        if asr_weight > 0.:
            decoder = TransformerDecoder(vocab_size,
                                         encoder.output_size(),
                                         **configs['decoder_conf'])
            # ctc decoder and ctc loss
            dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
            grad_norm_type = model_conf.get('ctc_grad_norm_type', None)
            ctc = CTCDecoderBase(
                odim=vocab_size,
                enc_n_units=encoder.output_size(),
                blank_id=0,
                dropout_rate=dropout_rate,
                reduction=True,  # sum
                batch_average=True,  # sum / batch_size
                grad_norm_type=grad_norm_type)

            return vocab_size, encoder, (st_decoder, decoder, ctc)
        else:
            return vocab_size, encoder, st_decoder

    @classmethod
    def from_config(cls, configs: dict):
        """init model.

        Args:
            configs (dict): config dict.

        Raises:
            ValueError: raise when using not support encoder type.

        Returns:
            nn.Layer: U2STModel
        """
        model = cls(configs)
        return model

    @classmethod
    def from_pretrained(cls, dataloader, config, checkpoint_path):
        """Build a U2STModel from a pretrained checkpoint.

        Args:
            dataloader (paddle.io.DataLoader): its `collate_fn.feature_size`
                and `collate_fn.vocab_size` set the model input/output dims.
            config (yacs.config.CfgNode): model configs; `input_dim` and
                `output_dim` are overwritten in place.
            checkpoint_path (Path or str): the path of pretrained model
                checkpoint, without extension name. Parameters are loaded
                only when this is truthy.

        Returns:
            U2STModel: The model built from pretrained result.
        """
        with UpdateConfig(config):
            config.input_dim = dataloader.collate_fn.feature_size
            config.output_dim = dataloader.collate_fn.vocab_size

        model = cls.from_config(config)

        if checkpoint_path:
            infos = checkpoint.load_parameters(
                model, checkpoint_path=checkpoint_path)
            logger.info(f"checkpoint info: {infos}")
        layer_tools.summary(model)
        return model
class U2STInferModel(U2STModel):
    """Inference wrapper whose forward pass runs `translate`."""

    def __init__(self, configs: dict):
        super().__init__(configs)

    def forward(self,
                feats,
                feats_lengths,
                decoding_chunk_size=-1,
                num_decoding_left_chunks=-1,
                simulate_streaming=False):
        """export model function

        Args:
            feats (Tensor): [B, T, D]
            feats_lengths (Tensor): [B]
            decoding_chunk_size (int, optional): forwarded to `translate`.
            num_decoding_left_chunks (int, optional): forwarded to
                `translate`.
            simulate_streaming (bool, optional): forwarded to `translate`.

        Returns:
            best path result, exactly as returned by `translate`
        """
        # Beam-search settings are left at `translate`'s own defaults.
        best_path = self.translate(
            feats,
            feats_lengths,
            decoding_chunk_size=decoding_chunk_size,
            num_decoding_left_chunks=num_decoding_left_chunks,
            simulate_streaming=simulate_streaming)
        return best_path