@@ -53,22 +53,21 @@ from deepspeech.utils.ctc_utils import remove_duplicates_and_blank
 
 logger = logging.getLogger(__name__)
 
-__all__ = ['U2Model']
+__all__ = ['U2TransformerModel', "U2ConformerModel"]
 
 
 class U2Model(nn.Module):
     """CTC-Attention hybrid Encoder-Decoder model"""
 
-    def __init__(
-            self,
-            vocab_size: int,
-            encoder: TransformerEncoder,
-            decoder: TransformerDecoder,
-            ctc: CTCDecoder,
-            ctc_weight: float=0.5,
-            ignore_id: int=IGNORE_ID,
-            lsm_weight: float=0.0,
-            length_normalized_loss: bool=False, ):
+    def __init__(self,
+                 vocab_size: int,
+                 encoder: TransformerEncoder,
+                 decoder: TransformerDecoder,
+                 ctc: CTCDecoder,
+                 ctc_weight: float=0.5,
+                 ignore_id: int=IGNORE_ID,
+                 lsm_weight: float=0.0,
+                 length_normalized_loss: bool=False):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
 
         super().__init__()
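Note: `ctc_weight` interpolates the CTC and attention losses of the hybrid model, which is why the constructor asserts it lies in [0, 1]. A schematic sketch of the combination (the `hybrid_loss` helper below is illustrative, not part of this file):

    def hybrid_loss(loss_ctc: float, loss_att: float, ctc_weight: float) -> float:
        """Schematic CTC/attention interpolation used by U2-style hybrids."""
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        return ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_att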
@@ -263,51 +262,54 @@ class U2Model(nn.Module):
             # Stop if all batch and all beam produce eos
             if end_flag.sum() == running_size:
                 break
 
             # 2.1 Forward decoder step
             hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
                 running_size, 1, 1).to(device)  # (B*N, i, i)
             # logp: (B*N, vocab)
             logp, cache = self.decoder.forward_one_step(
                 encoder_out, encoder_mask, hyps, hyps_mask, cache)
 
             # 2.2 First beam prune: select topk best prob at current time
             top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
             top_k_logp = mask_finished_scores(top_k_logp, end_flag)
             top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
 
             # 2.3 Second beam prune: select topk score with history
             scores = scores + top_k_logp  # (B*N, N), broadcast add
             scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
             scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
             scores = scores.view(-1, 1)  # (B*N, 1)
 
             # 2.4. Compute base index in top_k_index,
             # regard top_k_index as (B*N*N), regard offset_k_index as (B*N),
             # then find offset_k_index in top_k_index
-            base_k_index = torch.arange(
-                batch_size,
-                device=device).view(-1, 1).repeat([1, beam_size])  # (B, N)
+            base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
+                [1, beam_size])  # (B, N)
             base_k_index = base_k_index * beam_size * beam_size
             best_k_index = base_k_index.view(-1) + offset_k_index.view(
                 -1)  # (B*N)
 
             # 2.5 Update best hyps
-            best_k_pred = torch.index_select(
-                top_k_index.view(-1), dim=-1, index=best_k_index)  # (B*N)
+            best_k_pred = paddle.index_select(
+                top_k_index.view(-1), index=best_k_index, axis=0)  # (B*N)
             best_hyps_index = best_k_index // beam_size
-            last_best_k_hyps = torch.index_select(
-                hyps, dim=0, index=best_hyps_index)  # (B*N, i)
-            hyps = torch.cat(
+            last_best_k_hyps = paddle.index_select(
+                hyps, index=best_hyps_index, axis=0)  # (B*N, i)
+            hyps = paddle.cat(
                 (last_best_k_hyps, best_k_pred.view(-1, 1)),
                 dim=1)  # (B*N, i+1)
 
             # 2.6 Update end flag
-            end_flag = torch.eq(hyps[:, -1], self.eos).view(-1, 1)
+            end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1)
 
         # 3. Select best of best
         scores = scores.view(batch_size, beam_size)
         # TODO: length normalization
-        best_index = torch.argmax(scores, dim=-1).long()
-        best_hyps_index = best_index + torch.arange(
-            batch_size, dtype=torch.long, device=device) * beam_size
-        best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index)
+        best_index = paddle.argmax(scores, axis=-1).long()  # (B)
+        best_hyps_index = best_index + paddle.arange(
+            batch_size, dtype=paddle.long) * beam_size
+        best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
         best_hyps = best_hyps[:, 1:]
         return best_hyps
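Note: the beam bookkeeping in this hunk leans on `index_select`, where the two frameworks disagree on argument order and naming: torch spells the axis `dim` and takes `index` as a keyword, while paddle takes `index` second and spells the axis `axis`. A minimal sketch, assuming paddle is installed:

    import paddle

    x = paddle.to_tensor([[1, 2], [3, 4], [5, 6]])
    idx = paddle.to_tensor([2, 0])
    rows = paddle.index_select(x, index=idx, axis=0)  # [[5, 6], [1, 2]]
    # the torch equivalent would be torch.index_select(x, dim=0, index=idx)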
@@ -346,8 +348,8 @@ class U2Model(nn.Module):
         ctc_probs = self.ctc.log_softmax(encoder_out)  # (B, maxlen, vocab_size)
         topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
         topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
-        mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
-        topk_index = topk_index.masked_fill_(mask, self.eos)  # (B, maxlen)
+        pad_mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
+        topk_index = topk_index.masked_fill_(pad_mask, self.eos)  # (B, maxlen)
         hyps = [hyp.tolist() for hyp in topk_index]
         hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
         return hyps
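For context, the greedy path above takes the per-frame argmax, overwrites padded frames with `eos`, and then collapses each frame sequence with `remove_duplicates_and_blank`. A minimal pure-Python sketch of the collapse rule this code assumes (blank id 0; the real helper lives in deepspeech.utils.ctc_utils):

    def remove_duplicates_and_blank_sketch(hyp, blank_id=0):
        """Merge repeated labels, then drop blanks (the standard CTC rule)."""
        out, prev = [], None
        for token in hyp:
            if token != prev and token != blank_id:
                out.append(token)
            prev = token
        return out

    assert remove_duplicates_and_blank_sketch([0, 3, 3, 0, 3, 5, 5]) == [3, 3, 5]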
@@ -360,7 +362,7 @@ class U2Model(nn.Module):
             decoding_chunk_size: int=-1,
             num_decoding_left_chunks: int=-1,
             simulate_streaming: bool=False,
-    ) -> Tuple[List[List[int]], paddle.Tensor]:
+            blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
         """ CTC prefix beam search inner implementation
         Args:
             speech (paddle.Tensor): (batch, max_len, feat_dim)
@@ -374,7 +376,7 @@ class U2Model(nn.Module):
             simulate_streaming (bool): whether to do encoder forward in a
                 streaming fashion
         Returns:
-            List[List[int]]: nbest results
+            List[Tuple[int, float]]: nbest results, (N,1), (text, likelihood)
             paddle.Tensor: encoder output, (1, max_len, encoder_dim),
                 it will be used for rescoring in attention rescoring mode
         """
@@ -406,7 +408,7 @@ class U2Model(nn.Module):
                 ps = logp[s].item()
                 for prefix, (pb, pnb) in cur_hyps:
                     last = prefix[-1] if len(prefix) > 0 else None
-                    if s == 0:  # blank
+                    if s == blank_id:  # blank
                         n_pb, n_pnb = next_hyps[prefix]
                         n_pb = log_add([n_pb, pb + ps, pnb + ps])
                         next_hyps[prefix] = (n_pb, n_pnb)
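Note: `pb`/`pnb` here are the blank-ending and non-blank-ending prefix scores kept in log space, so `log_add` is a log-sum-exp combination. A minimal sketch of the helper this code assumes (the repo imports the real one alongside its other utils):

    import math

    def log_add_sketch(args):
        """log(sum(exp(a) for a in args)), stabilized by factoring out the max."""
        if all(a == -float('inf') for a in args):
            return -float('inf')
        a_max = max(args)
        return a_max + math.log(sum(math.exp(a - a_max) for a in args))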
@@ -491,7 +493,7 @@ class U2Model(nn.Module):
         """
         assert speech.shape[0] == speech_lengths.shape[0]
         assert decoding_chunk_size != 0
-        device = speech.device
+        device = speech.place
         batch_size = speech.shape[0]
         # For attention rescoring we only support batch_size=1
         assert batch_size == 1
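Note: paddle tensors expose their device through `.place` (a CPUPlace/CUDAPlace object) rather than torch's `.device`, which is all this hunk changes. A minimal sketch, assuming paddle is installed:

    import paddle

    speech = paddle.zeros([1, 10, 80])  # placeholder (batch, time, feat_dim)
    device = speech.place               # e.g. CPUPlace or CUDAPlace(0)
    print(device)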
@@ -502,22 +504,22 @@ class U2Model(nn.Module):
 
         assert len(hyps) == beam_size
         hyps_pad = pad_sequence([
-            paddle.tensor(hyp[0], device=device, dtype=torch.long)
+            paddle.to_tensor(hyp[0], place=device, dtype=paddle.long)
             for hyp in hyps
         ], True, self.ignore_id)  # (beam_size, max_hyps_len)
-        hyps_lens = paddle.tensor(
-            [len(hyp[0]) for hyp in hyps], device=device,
-            dtype=torch.long)  # (beam_size,)
+        hyps_lens = paddle.to_tensor(
+            [len(hyp[0]) for hyp in hyps], place=device,
+            dtype=paddle.long)  # (beam_size,)
         hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
         hyps_lens = hyps_lens + 1  # Add <sos> at beginning
         encoder_out = encoder_out.repeat(beam_size, 1, 1)
-        encoder_mask = torch.ones(
-            beam_size, 1, encoder_out.size(1), dtype=torch.bool, device=device)
+        encoder_mask = paddle.ones(
+            [beam_size, 1, encoder_out.size(1)], dtype=paddle.bool)
         decoder_out, _ = self.decoder(
             encoder_out, encoder_mask, hyps_pad,
             hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
-        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
-        decoder_out = decoder_out.cpu().numpy()
+        decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
+        decoder_out = decoder_out.numpy()
         # Only use decoder score for rescoring
         best_score = -float('inf')
         best_index = 0
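Note: the rescoring hunk applies one mapping throughout: torch.tensor(data, device=..., dtype=torch.long) becomes paddle.to_tensor(data, place=..., dtype=paddle.long), where the `paddle.long` spelling appears to rely on this repo's torch-compatibility aliases (stock paddle spells the dtype 'int64'). A minimal sketch in stock paddle, with hypothetical token ids:

    import paddle

    hyp = [12, 7, 7, 42]                                  # hypothetical token ids
    hyps_t = paddle.to_tensor(hyp, dtype='int64')         # (len(hyp),)
    lens_t = paddle.to_tensor([len(hyp)], dtype='int64')  # (1,)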
@@ -609,56 +611,83 @@ class U2Model(nn.Module):
             hypothesis from ctc prefix beam search and one encoder output
         Args:
             hyps (paddle.Tensor): hyps from ctc prefix beam search, already
-                pad sos at the begining
-            hyps_lens (paddle.Tensor): length of each hyp in hyps
-            encoder_out (paddle.Tensor): corresponding encoder output
+                pad sos at the beginning, (B, T)
+            hyps_lens (paddle.Tensor): length of each hyp in hyps, (B)
+            encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D)
         Returns:
-            paddle.Tensor: decoder output
+            paddle.Tensor: decoder output, (B, L)
         """
         assert encoder_out.size(0) == 1
         num_hyps = hyps.size(0)
         assert hyps_lens.size(0) == num_hyps
         encoder_out = encoder_out.repeat(num_hyps, 1, 1)
-        encoder_mask = torch.ones(
-            num_hyps,
-            1,
-            encoder_out.size(1),
-            dtype=torch.bool,
-            device=encoder_out.device)
-        decoder_out, _ = self.decoder(
-            encoder_out, encoder_mask, hyps,
-            hyps_lens)  # (num_hyps, max_hyps_len, vocab_size)
-        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
+        # (B, 1, T)
+        encoder_mask = paddle.ones(
+            [num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool)
+        # (num_hyps, max_hyps_len, vocab_size)
+        decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
+                                      hyps_lens)
+        decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
         return decoder_out
 
 
-def init_asr_model(configs):
-    if configs['cmvn_file'] is not None:
-        mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn'])
-        global_cmvn = GlobalCMVN(
-            torch.from_numpy(mean).float(), torch.from_numpy(istd).float())
-    else:
-        global_cmvn = None
-
-    input_dim = configs['input_dim']
-    vocab_size = configs['output_dim']
-
-    encoder_type = configs.get('encoder', 'conformer')
-    if encoder_type == 'conformer':
-        encoder = ConformerEncoder(
-            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
-    else:
-        encoder = TransformerEncoder(
-            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
-
-    decoder = TransformerDecoder(vocab_size,
-                                 encoder.output_size(),
-                                 **configs['decoder_conf'])
-    ctc = CTCDecoder(vocab_size, encoder.output_size())
-    model = U2Model(
-        vocab_size=vocab_size,
-        encoder=encoder,
-        decoder=decoder,
-        ctc=ctc,
-        **configs['model_conf'], )
-    return model
+class U2TransformerModel(U2Model):
+    def __init__(self, configs: dict):
+        if configs['cmvn_file'] is not None:
+            mean, istd = load_cmvn(configs['cmvn_file'],
+                                   configs['is_json_cmvn'])
+            global_cmvn = GlobalCMVN(
+                paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
+        else:
+            global_cmvn = None
+
+        input_dim = configs['input_dim']
+        vocab_size = configs['output_dim']
+
+        encoder_type = configs.get('encoder', 'transformer')
+        assert encoder_type == 'transformer'
+        encoder = TransformerEncoder(
+            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
+
+        decoder = TransformerDecoder(vocab_size,
+                                     encoder.output_size(),
+                                     **configs['decoder_conf'])
+        ctc = CTCDecoder(vocab_size, encoder.output_size())
+
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder=encoder,
+            decoder=decoder,
+            ctc=ctc,
+            **configs['model_conf'])
+
+
+class U2ConformerModel(U2Model):
+    def __init__(self, configs: dict):
+        if configs['cmvn_file'] is not None:
+            mean, istd = load_cmvn(configs['cmvn_file'],
+                                   configs['is_json_cmvn'])
+            global_cmvn = GlobalCMVN(
+                paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
+        else:
+            global_cmvn = None
+
+        input_dim = configs['input_dim']
+        vocab_size = configs['output_dim']
+
+        encoder_type = configs.get('encoder', 'conformer')
+        assert encoder_type == 'conformer'
+        encoder = ConformerEncoder(
+            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
+
+        decoder = TransformerDecoder(vocab_size,
+                                     encoder.output_size(),
+                                     **configs['decoder_conf'])
+        ctc = CTCDecoder(vocab_size, encoder.output_size())
+
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder=encoder,
+            decoder=decoder,
+            ctc=ctc,
+            **configs['model_conf'])
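Usage-wise, the two new classes replace the removed init_asr_model factory: each validates its encoder type, builds the same encoder/decoder/CTC triple, and delegates to U2Model. A hypothetical config (keys taken from the code above, values illustrative) would be consumed like this:

    configs = {
        'cmvn_file': None,       # or a path to CMVN stats
        'is_json_cmvn': True,
        'input_dim': 80,         # feature dimension (illustrative)
        'output_dim': 4233,      # vocabulary size (illustrative)
        'encoder': 'conformer',
        'encoder_conf': {},      # ConformerEncoder kwargs
        'decoder_conf': {},      # TransformerDecoder kwargs
        'model_conf': {'ctc_weight': 0.3},
    }
    model = U2ConformerModel(configs)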