@ -16,10 +16,11 @@ Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recogni
(https://arxiv.org/pdf/2012.05481.pdf)
"""
import sys
from collections import defaultdict
import logging
from yacs . config import CfgNode
from typing import List , Optional , Tuple
from typing import List , Optional , Tuple , Dict
import paddle
from paddle import jit
@ -132,6 +133,7 @@ class U2BaseModel(nn.Module):
smoothing = lsm_weight ,
normalize_length = length_normalized_loss , )
@jit.export
def forward (
self ,
speech : paddle . Tensor ,
@ -158,7 +160,7 @@ class U2BaseModel(nn.Module):
encoder_out , encoder_mask = self . encoder ( speech , speech_lengths )
#TODO(Hui Zhang): sum not support bool type
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens = encoder_mask . squeeze ( 1 ) . astype ( paddle . int64 ) . sum (
encoder_out_lens = encoder_mask . squeeze ( 1 ) . c ast( paddle . int64 ) . sum (
1 ) #[B, 1, T] -> [B]
# 2a. Attention-decoder branch
@ -301,14 +303,15 @@ class U2BaseModel(nn.Module):
# log scale score
scores = paddle . to_tensor (
[ 0.0 ] + [ - float ( ' inf ' ) ] * ( beam_size - 1 ) , dtype = paddle . float )
scores = scores . to ( device ) . repeat ( [ batch_size ] ) . unsqueeze ( 1 ) . to (
scores = scores . to ( device ) . repeat ( batch_size ) . unsqueeze ( 1 ) . to (
device ) # (B*N, 1)
end_flag = paddle . zeros_like ( scores , dtype = paddle . bool ) # (B*N, 1)
cache : Optional [ List [ paddle . Tensor ] ] = None
# 2. Decoder forward step by step
for i in range ( 1 , maxlen + 1 ) :
# Stop if all batch and all beam produce eos
if end_flag . sum ( ) == running_size :
# TODO(Hui Zhang): if end_flag.sum() == running_size:
if end_flag . cast ( paddle . int64 ) . sum ( ) == running_size :
break
# 2.1 Forward decoder step
@ -333,7 +336,7 @@ class U2BaseModel(nn.Module):
# regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
# then find offset_k_index in top_k_index
base_k_index = paddle . arange ( batch_size ) . view ( - 1 , 1 ) . repeat (
[ 1 , beam_size ] ) # (B, N)
1 , beam_size ) # (B, N)
base_k_index = base_k_index * beam_size * beam_size
best_k_index = base_k_index . view ( - 1 ) + offset_k_index . view (
- 1 ) # (B*N)
@ -678,6 +681,108 @@ class U2BaseModel(nn.Module):
decoder_out = paddle . nn . functional . log_softmax ( decoder_out , dim = - 1 )
return decoder_out
@paddle.no_grad ( )
def decode ( self ,
feats : paddle . Tensor ,
feats_lengths : paddle . Tensor ,
text_feature : Dict [ str , int ] ,
decoding_method : str ,
lang_model_path : str ,
beam_alpha : float ,
beam_beta : float ,
beam_size : int ,
cutoff_prob : float ,
cutoff_top_n : int ,
num_processes : int ,
ctc_weight : float = 0.0 ,
decoding_chunk_size : int = - 1 ,
num_decoding_left_chunks : int = - 1 ,
simulate_streaming : bool = False ) :
"""Decode a batch of features with the selected method and defeaturize the result.

Dispatches on ``decoding_method`` to one of ``self.recognize``,
``self.ctc_greedy_search``, ``self.ctc_prefix_beam_search`` or
``self.attention_rescoring``, then maps every hypothesis through
``text_feature.defeaturize``.

Args:
    feats (Tensor): audio features, (B, T, D).
    feats_lengths (Tensor): feature lengths, (B).
    text_feature: object providing ``defeaturize(hyp)``. The
        ``Dict[str, int]`` annotation does not match this usage
        (the previous docstring called it a TextFeaturizer) --
        TODO confirm the intended type.
    decoding_method (str): decoding mode; one of the four literals
        compared against in the body (attention, ctc_greedy_search,
        ctc_prefix_beam_search, attention_rescoring).
    lang_model_path (str): lm path. Not read by this method body.
    beam_alpha (float): lm weight. Not read by this method body.
    beam_beta (float): length penalty. Not read by this method body.
    beam_size (int): beam size for search; forwarded to every mode
        except ctc_greedy_search.
    cutoff_prob (float): for pruning. Not read by this method body.
    cutoff_top_n (int): for pruning. Not read by this method body.
    num_processes (int): Not read by this method body.
    ctc_weight (float, optional): ctc weight, forwarded only to
        ``self.attention_rescoring``. Defaults to 0.0.
    decoding_chunk_size (int, optional): decoding chunk size.
        Defaults to -1.
        <0: for decoding, use full chunk.
        >0: for decoding, use fixed chunk size as set.
        0: used for training, it's prohibited here.
    num_decoding_left_chunks (int, optional): number of left chunks
        for decoding. Defaults to -1.
    simulate_streaming (bool, optional): simulate streaming
        inference. Defaults to False.

Raises:
    ValueError: when ``decoding_method`` matches none of the four
        supported literals.
    SystemExit: via ``sys.exit(1)`` when ctc_prefix_beam_search or
        attention_rescoring is requested with batch_size > 1.

Returns:
    list: one ``text_feature.defeaturize(hyp)`` result per
    hypothesis (presumably the transcript text -- verify against
    the featurizer; the previous docstring said List[List[int]]).
"""
batch_size = feats . size ( 0 )
# The two single-utterance modes cannot run on a batch: log and terminate
# the process instead of raising.
# NOTE(review): the string literals in this method carry padding spaces in
# this view (likely an extraction artifact) -- confirm the exact literals
# against the repository.
if decoding_method in [ ' ctc_prefix_beam_search ' ,
' attention_rescoring ' ] and batch_size > 1 :
# NOTE(review): in this view the `f` prefix is separated from the string
# literal by a space, which is not valid Python -- presumably the same
# extraction artifact; confirm.
logger . fatal (
f ' decoding mode { decoding_method } must be running with batch_size == 1 '
)
sys . exit ( 1 )
if decoding_method == ' attention ' :
hyps = self . recognize (
feats ,
feats_lengths ,
beam_size = beam_size ,
decoding_chunk_size = decoding_chunk_size ,
num_decoding_left_chunks = num_decoding_left_chunks ,
simulate_streaming = simulate_streaming )
# Convert each hypothesis returned by recognize() to a plain list.
hyps = [ hyp . tolist ( ) for hyp in hyps ]
elif decoding_method == ' ctc_greedy_search ' :
hyps = self . ctc_greedy_search (
feats ,
feats_lengths ,
decoding_chunk_size = decoding_chunk_size ,
num_decoding_left_chunks = num_decoding_left_chunks ,
simulate_streaming = simulate_streaming )
# ctc_prefix_beam_search and attention_rescoring return a single result
# (List[int]); wrap it into List[List[int]] to stay compatible with the
# batch decoding modes above.
elif decoding_method == ' ctc_prefix_beam_search ' :
# Re-checks the single-utterance requirement (the guard above only exits
# for batch_size > 1).
assert feats . size ( 0 ) == 1
hyp = self . ctc_prefix_beam_search (
feats ,
feats_lengths ,
beam_size ,
decoding_chunk_size = decoding_chunk_size ,
num_decoding_left_chunks = num_decoding_left_chunks ,
simulate_streaming = simulate_streaming )
hyps = [ hyp ]
elif decoding_method == ' attention_rescoring ' :
# Re-checks the single-utterance requirement (the guard above only exits
# for batch_size > 1).
assert feats . size ( 0 ) == 1
hyp = self . attention_rescoring (
feats ,
feats_lengths ,
beam_size ,
decoding_chunk_size = decoding_chunk_size ,
num_decoding_left_chunks = num_decoding_left_chunks ,
ctc_weight = ctc_weight ,
simulate_streaming = simulate_streaming )
hyps = [ hyp ]
else :
raise ValueError ( f " Not support decoding method: { decoding_method } " )
# Map every hypothesis through the featurizer's inverse transform.
res = [ text_feature . defeaturize ( hyp ) for hyp in hyps ]
return res
class U2Model ( U2BaseModel ) :
def __init__ ( self , configs : dict ) :
@ -779,14 +884,24 @@ class U2InferModel(U2Model):
def __init__ ( self , configs : dict ) :
"""Initialize the model by passing ``configs`` unchanged to the parent initializer.

Args:
    configs (dict): model configuration, forwarded to ``super().__init__``.
"""
super ( ) . __init__ ( configs )
# NOTE(review): the next line is a `forward` header with no body of its own,
# immediately followed by a second `forward` header. It looks like the
# superseded signature left over from a diff -- confirm against the repository.
def forward ( self , audio , audio_len ) :
def forward ( self ,
feats ,
feats_lengths ,
decoding_chunk_size = - 1 ,
num_decoding_left_chunks = - 1 ,
simulate_streaming = False ) :
"""Export entry point: forwards its arguments to ``self.ctc_greedy_search``.

Args:
    feats (Tensor): [B, T, D]
    feats_lengths (Tensor): [B]
    decoding_chunk_size: forwarded to ``ctc_greedy_search``. Defaults to -1.
    num_decoding_left_chunks: forwarded to ``ctc_greedy_search``.
        Defaults to -1.
    simulate_streaming: forwarded to ``ctc_greedy_search``.
        Defaults to False.

Returns:
    The value returned by ``self.ctc_greedy_search`` (documented
    previously as List[List[int]]: best path result).

Note:
    The earlier docstring also listed ``audio`` / ``audio_len`` and
    "probs after softmax"; those names belong to the two-argument
    header above, not to this signature.
"""
# NOTE(review): as the lines stand in this view, this raise runs first and the
# return below is never reached. It is presumably the removed side of the same
# diff -- confirm which statement is live.
raise NotImplementedError ( " U2Model infer " )
return self . ctc_greedy_search (
feats ,
feats_lengths ,
decoding_chunk_size = decoding_chunk_size ,
num_decoding_left_chunks = num_decoding_left_chunks ,
simulate_streaming = simulate_streaming )