|
|
@ -12,7 +12,7 @@
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
"""U2 ASR Model
|
|
|
|
"""U2 ASR Model
|
|
|
|
Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
|
|
|
|
Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
|
|
|
|
(https://arxiv.org/pdf/2012.05481.pdf)
|
|
|
|
(https://arxiv.org/pdf/2012.05481.pdf)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
import sys
|
|
|
|
import sys
|
|
|
@ -83,7 +83,7 @@ class U2BaseModel(nn.Module):
|
|
|
|
# cnn_module_kernel=15,
|
|
|
|
# cnn_module_kernel=15,
|
|
|
|
# activation_type='swish',
|
|
|
|
# activation_type='swish',
|
|
|
|
# pos_enc_layer_type='rel_pos',
|
|
|
|
# pos_enc_layer_type='rel_pos',
|
|
|
|
# selfattention_layer_type='rel_selfattn',
|
|
|
|
# selfattention_layer_type='rel_selfattn',
|
|
|
|
))
|
|
|
|
))
|
|
|
|
# decoder related
|
|
|
|
# decoder related
|
|
|
|
default.decoder = 'transformer'
|
|
|
|
default.decoder = 'transformer'
|
|
|
@ -244,8 +244,8 @@ class U2BaseModel(nn.Module):
|
|
|
|
simulate_streaming (bool, optional): streaming or not. Defaults to False.
|
|
|
|
simulate_streaming (bool, optional): streaming or not. Defaults to False.
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
Tuple[paddle.Tensor, paddle.Tensor]:
|
|
|
|
Tuple[paddle.Tensor, paddle.Tensor]:
|
|
|
|
encoder hiddens (B, Tmax, D),
|
|
|
|
encoder hiddens (B, Tmax, D),
|
|
|
|
encoder hiddens mask (B, 1, Tmax).
|
|
|
|
encoder hiddens mask (B, 1, Tmax).
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
# Let's assume B = batch_size
|
|
|
|
# Let's assume B = batch_size
|
|
|
@ -399,6 +399,7 @@ class U2BaseModel(nn.Module):
|
|
|
|
assert speech.shape[0] == speech_lengths.shape[0]
|
|
|
|
assert speech.shape[0] == speech_lengths.shape[0]
|
|
|
|
assert decoding_chunk_size != 0
|
|
|
|
assert decoding_chunk_size != 0
|
|
|
|
batch_size = speech.shape[0]
|
|
|
|
batch_size = speech.shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
# Let's assume B = batch_size
|
|
|
|
# Let's assume B = batch_size
|
|
|
|
# encoder_out: (B, maxlen, encoder_dim)
|
|
|
|
# encoder_out: (B, maxlen, encoder_dim)
|
|
|
|
# encoder_mask: (B, 1, Tmax)
|
|
|
|
# encoder_mask: (B, 1, Tmax)
|
|
|
@ -410,10 +411,12 @@ class U2BaseModel(nn.Module):
|
|
|
|
# encoder_out_lens = encoder_mask.squeeze(1).sum(1)
|
|
|
|
# encoder_out_lens = encoder_mask.squeeze(1).sum(1)
|
|
|
|
encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
|
|
|
|
encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
|
|
|
|
ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)
|
|
|
|
ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)
|
|
|
|
|
|
|
|
|
|
|
|
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
|
|
|
|
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
|
|
|
|
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
|
|
|
|
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
|
|
|
|
pad_mask = make_pad_mask(encoder_out_lens) # (B, maxlen)
|
|
|
|
pad_mask = make_pad_mask(encoder_out_lens) # (B, maxlen)
|
|
|
|
topk_index = topk_index.masked_fill_(pad_mask, self.eos) # (B, maxlen)
|
|
|
|
topk_index = topk_index.masked_fill_(pad_mask, self.eos) # (B, maxlen)
|
|
|
|
|
|
|
|
|
|
|
|
hyps = [hyp.tolist() for hyp in topk_index]
|
|
|
|
hyps = [hyp.tolist() for hyp in topk_index]
|
|
|
|
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
|
|
|
|
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
|
|
|
|
return hyps
|
|
|
|
return hyps
|
|
|
@ -449,6 +452,7 @@ class U2BaseModel(nn.Module):
|
|
|
|
batch_size = speech.shape[0]
|
|
|
|
batch_size = speech.shape[0]
|
|
|
|
# For CTC prefix beam search, we only support batch_size=1
|
|
|
|
# For CTC prefix beam search, we only support batch_size=1
|
|
|
|
assert batch_size == 1
|
|
|
|
assert batch_size == 1
|
|
|
|
|
|
|
|
|
|
|
|
# Let's assume B = batch_size and N = beam_size
|
|
|
|
# Let's assume B = batch_size and N = beam_size
|
|
|
|
# 1. Encoder forward and get CTC score
|
|
|
|
# 1. Encoder forward and get CTC score
|
|
|
|
encoder_out, encoder_mask = self._forward_encoder(
|
|
|
|
encoder_out, encoder_mask = self._forward_encoder(
|
|
|
@ -458,7 +462,9 @@ class U2BaseModel(nn.Module):
|
|
|
|
maxlen = encoder_out.size(1)
|
|
|
|
maxlen = encoder_out.size(1)
|
|
|
|
ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size)
|
|
|
|
ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size)
|
|
|
|
ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
|
|
|
|
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
|
|
|
|
|
|
|
|
# blank_ending_score and none_blank_ending_score in ln domain
|
|
|
|
cur_hyps = [(tuple(), (0.0, -float('inf')))]
|
|
|
|
cur_hyps = [(tuple(), (0.0, -float('inf')))]
|
|
|
|
# 2. CTC beam search step by step
|
|
|
|
# 2. CTC beam search step by step
|
|
|
|
for t in range(0, maxlen):
|
|
|
|
for t in range(0, maxlen):
|
|
|
@ -498,6 +504,7 @@ class U2BaseModel(nn.Module):
|
|
|
|
key=lambda x: log_add(list(x[1])),
|
|
|
|
key=lambda x: log_add(list(x[1])),
|
|
|
|
reverse=True)
|
|
|
|
reverse=True)
|
|
|
|
cur_hyps = next_hyps[:beam_size]
|
|
|
|
cur_hyps = next_hyps[:beam_size]
|
|
|
|
|
|
|
|
|
|
|
|
hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
|
|
|
|
hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
|
|
|
|
return hyps, encoder_out
|
|
|
|
return hyps, encoder_out
|
|
|
|
|
|
|
|
|
|
|
@ -561,12 +568,13 @@ class U2BaseModel(nn.Module):
|
|
|
|
batch_size = speech.shape[0]
|
|
|
|
batch_size = speech.shape[0]
|
|
|
|
# For attention rescoring we only support batch_size=1
|
|
|
|
# For attention rescoring we only support batch_size=1
|
|
|
|
assert batch_size == 1
|
|
|
|
assert batch_size == 1
|
|
|
|
# encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
|
|
|
|
|
|
|
|
|
|
|
|
# len(hyps) = beam_size, encoder_out: (1, maxlen, encoder_dim)
|
|
|
|
hyps, encoder_out = self._ctc_prefix_beam_search(
|
|
|
|
hyps, encoder_out = self._ctc_prefix_beam_search(
|
|
|
|
speech, speech_lengths, beam_size, decoding_chunk_size,
|
|
|
|
speech, speech_lengths, beam_size, decoding_chunk_size,
|
|
|
|
num_decoding_left_chunks, simulate_streaming)
|
|
|
|
num_decoding_left_chunks, simulate_streaming)
|
|
|
|
|
|
|
|
|
|
|
|
assert len(hyps) == beam_size
|
|
|
|
assert len(hyps) == beam_size
|
|
|
|
|
|
|
|
|
|
|
|
hyps_pad = pad_sequence([
|
|
|
|
hyps_pad = pad_sequence([
|
|
|
|
paddle.to_tensor(hyp[0], place=device, dtype=paddle.long)
|
|
|
|
paddle.to_tensor(hyp[0], place=device, dtype=paddle.long)
|
|
|
|
for hyp in hyps
|
|
|
|
for hyp in hyps
|
|
|
@ -576,23 +584,28 @@ class U2BaseModel(nn.Module):
|
|
|
|
dtype=paddle.long) # (beam_size,)
|
|
|
|
dtype=paddle.long) # (beam_size,)
|
|
|
|
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
|
|
|
|
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
|
|
|
|
hyps_lens = hyps_lens + 1 # Add <sos> at begining
|
|
|
|
hyps_lens = hyps_lens + 1 # Add <sos> at begining
|
|
|
|
|
|
|
|
|
|
|
|
encoder_out = encoder_out.repeat(beam_size, 1, 1)
|
|
|
|
encoder_out = encoder_out.repeat(beam_size, 1, 1)
|
|
|
|
encoder_mask = paddle.ones(
|
|
|
|
encoder_mask = paddle.ones(
|
|
|
|
(beam_size, 1, encoder_out.size(1)), dtype=paddle.bool)
|
|
|
|
(beam_size, 1, encoder_out.size(1)), dtype=paddle.bool)
|
|
|
|
decoder_out, _ = self.decoder(
|
|
|
|
decoder_out, _ = self.decoder(
|
|
|
|
encoder_out, encoder_mask, hyps_pad,
|
|
|
|
encoder_out, encoder_mask, hyps_pad,
|
|
|
|
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
|
|
|
|
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
|
|
|
|
|
|
|
|
# ctc score in ln domain
|
|
|
|
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
|
|
|
|
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
|
|
|
|
decoder_out = decoder_out.numpy()
|
|
|
|
decoder_out = decoder_out.numpy()
|
|
|
|
|
|
|
|
|
|
|
|
# Only use decoder score for rescoring
|
|
|
|
# Only use decoder score for rescoring
|
|
|
|
best_score = -float('inf')
|
|
|
|
best_score = -float('inf')
|
|
|
|
best_index = 0
|
|
|
|
best_index = 0
|
|
|
|
|
|
|
|
# hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
|
|
|
|
for i, hyp in enumerate(hyps):
|
|
|
|
for i, hyp in enumerate(hyps):
|
|
|
|
score = 0.0
|
|
|
|
score = 0.0
|
|
|
|
for j, w in enumerate(hyp[0]):
|
|
|
|
for j, w in enumerate(hyp[0]):
|
|
|
|
score += decoder_out[i][j][w]
|
|
|
|
score += decoder_out[i][j][w]
|
|
|
|
|
|
|
|
# last decoder output token is `eos`, for laste decoder input token.
|
|
|
|
score += decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
score += decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
# add ctc score
|
|
|
|
# add ctc score (which in ln domain)
|
|
|
|
score += hyp[1] * ctc_weight
|
|
|
|
score += hyp[1] * ctc_weight
|
|
|
|
if score > best_score:
|
|
|
|
if score > best_score:
|
|
|
|
best_score = score
|
|
|
|
best_score = score
|
|
|
@ -719,8 +732,8 @@ class U2BaseModel(nn.Module):
|
|
|
|
feats (Tenosr): audio features, (B, T, D)
|
|
|
|
feats (Tenosr): audio features, (B, T, D)
|
|
|
|
feats_lengths (Tenosr): (B)
|
|
|
|
feats_lengths (Tenosr): (B)
|
|
|
|
text_feature (TextFeaturizer): text feature object.
|
|
|
|
text_feature (TextFeaturizer): text feature object.
|
|
|
|
decoding_method (str): decoding mode, e.g.
|
|
|
|
decoding_method (str): decoding mode, e.g.
|
|
|
|
'attention', 'ctc_greedy_search',
|
|
|
|
'attention', 'ctc_greedy_search',
|
|
|
|
'ctc_prefix_beam_search', 'attention_rescoring'
|
|
|
|
'ctc_prefix_beam_search', 'attention_rescoring'
|
|
|
|
lang_model_path (str): lm path.
|
|
|
|
lang_model_path (str): lm path.
|
|
|
|
beam_alpha (float): lm weight.
|
|
|
|
beam_alpha (float): lm weight.
|
|
|
@ -728,19 +741,19 @@ class U2BaseModel(nn.Module):
|
|
|
|
beam_size (int): beam size for search
|
|
|
|
beam_size (int): beam size for search
|
|
|
|
cutoff_prob (float): for prune.
|
|
|
|
cutoff_prob (float): for prune.
|
|
|
|
cutoff_top_n (int): for prune.
|
|
|
|
cutoff_top_n (int): for prune.
|
|
|
|
num_processes (int):
|
|
|
|
num_processes (int):
|
|
|
|
ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0.
|
|
|
|
ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0.
|
|
|
|
decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
|
|
|
|
decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
|
|
|
|
<0: for decoding, use full chunk.
|
|
|
|
<0: for decoding, use full chunk.
|
|
|
|
>0: for decoding, use fixed chunk size as set.
|
|
|
|
>0: for decoding, use fixed chunk size as set.
|
|
|
|
0: used for training, it's prohibited here.
|
|
|
|
0: used for training, it's prohibited here.
|
|
|
|
num_decoding_left_chunks (int, optional):
|
|
|
|
num_decoding_left_chunks (int, optional):
|
|
|
|
number of left chunks for decoding. Defaults to -1.
|
|
|
|
number of left chunks for decoding. Defaults to -1.
|
|
|
|
simulate_streaming (bool, optional): simulate streaming inference. Defaults to False.
|
|
|
|
simulate_streaming (bool, optional): simulate streaming inference. Defaults to False.
|
|
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
Raises:
|
|
|
|
ValueError: when not support decoding_method.
|
|
|
|
ValueError: when not support decoding_method.
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
List[List[int]]: transcripts.
|
|
|
|
List[List[int]]: transcripts.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -821,7 +834,7 @@ class U2Model(U2BaseModel):
|
|
|
|
ValueError: raise when using not support encoder type.
|
|
|
|
ValueError: raise when using not support encoder type.
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
|
|
|
|
int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if configs['cmvn_file'] is not None:
|
|
|
|
if configs['cmvn_file'] is not None:
|
|
|
|
mean, istd = load_cmvn(configs['cmvn_file'],
|
|
|
|
mean, istd = load_cmvn(configs['cmvn_file'],
|
|
|
|