@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from paddlespeech.s2t.utils.utility import log_add
 from typing import Optional
+from collections import defaultdict
 
 import numpy as np
 import paddle
 from numpy import float32
@@ -23,10 +24,14 @@ from paddlespeech.cli.asr.infer import ASRExecutor
 from paddlespeech.cli.asr.infer import model_alias
 from paddlespeech.cli.asr.infer import pretrained_models
 from paddlespeech.cli.log import logger
+from paddlespeech.cli.utils import download_and_decompress
 from paddlespeech.cli.utils import MODEL_HOME
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.frontend.speech import SpeechSegment
 from paddlespeech.s2t.modules.ctc import CTCDecoder
+from paddlespeech.s2t.modules.mask import mask_finished_preds
+from paddlespeech.s2t.modules.mask import mask_finished_scores
+from paddlespeech.s2t.modules.mask import subsequent_mask
 from paddlespeech.s2t.transform.transformation import Transformation
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.s2t.utils.utility import UpdateConfig
@ -57,17 +62,17 @@ pretrained_models = {
} ,
} ,
" conformer2online_aishell-zh-16k " : {
" conformer2online_aishell-zh-16k " : {
' url ' :
' url ' :
' https://paddlespeech.bj.bcebos.com/s2t/ aishell/asr0/asr1_chunk_conformer_aishell_ckpt_0.1.2 .model.tar.gz' ,
' https://paddlespeech.bj.bcebos.com/s2t/ multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0 .model.tar.gz' ,
' md5 ' :
' md5 ' :
' 4814e52e0fc2fd48899373f95c84b0c9 ' ,
' 7989b3248c898070904cf042fd656003 ' ,
' cfg_path ' :
' cfg_path ' :
' exp/chunk_conformer//conf/config .yaml' ,
' model .yaml' ,
' ckpt_path ' :
' ckpt_path ' :
' exp/chunk_conformer/checkpoints/ avg_30/ ' ,
' exp/chunk_conformer/checkpoints/ multi_cn ' ,
' model ' :
' model ' :
' exp/chunk_conformer/checkpoints/ avg_30 .pdparams' ,
' exp/chunk_conformer/checkpoints/ multi_cn .pdparams' ,
' params ' :
' params ' :
' exp/chunk_conformer/checkpoints/ avg_30 .pdparams' ,
' exp/chunk_conformer/checkpoints/ multi_cn .pdparams' ,
' lm_url ' :
' lm_url ' :
' https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm ' ,
' https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm ' ,
' lm_md5 ' :
' lm_md5 ' :
@@ -81,6 +86,23 @@ class ASRServerExecutor(ASRExecutor):
         super().__init__()
         pass
 
+    def _get_pretrained_path(self, tag: str) -> os.PathLike:
+        """
+        Download and return the path of the pretrained resources for the current task.
+        """
+        support_models = list(pretrained_models.keys())
+        assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
+            tag, '\n\t\t'.join(support_models))
+
+        res_path = os.path.join(MODEL_HOME, tag)
+        decompressed_path = download_and_decompress(pretrained_models[tag],
+                                                    res_path)
+        decompressed_path = os.path.abspath(decompressed_path)
+        logger.info(
+            'Use pretrained model stored in: {}'.format(decompressed_path))
+        return decompressed_path
+
     def _init_from_path(self,
                         model_type: str='wenetspeech',
                         am_model: Optional[os.PathLike]=None,
@@ -101,7 +123,7 @@ class ASRServerExecutor(ASRExecutor):
             logger.info(f"Load the pretrained model, tag = {tag}")
             res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
             self.res_path = res_path
-            self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/paddlespeech/server/tests/asr/online/conf/config.yaml"
+            self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
             # self.cfg_path = os.path.join(res_path,
             #                              pretrained_models[tag]['cfg_path'])
@@ -147,8 +169,7 @@ class ASRServerExecutor(ASRExecutor):
             if self.config.spm_model_prefix:
                 self.config.spm_model_prefix = os.path.join(
                     self.res_path, self.config.spm_model_prefix)
-            self.config.vocab_filepath = os.path.join(
-                self.res_path, self.config.vocab_filepath)
+            self.vocab = self.config.vocab_filepath
             self.text_feature = TextFeaturizer(
                 unit_type=self.config.unit_type,
                 vocab=self.config.vocab_filepath,
@@ -203,19 +224,31 @@ class ASRServerExecutor(ASRExecutor):
             model_conf = self.config
             model = model_class.from_config(model_conf)
             self.model = model
+            self.model.eval()
+
+            # load model
+            model_dict = paddle.load(self.am_model)
+            self.model.set_state_dict(model_dict)
             logger.info("create the transformer like model success")
+
+            # update the ctc decoding
+            self.searcher = None
+            self.transformer_decode_reset()
 
     def reset_decoder_and_chunk(self):
         """reset decoder and chunk state for a new audio
         """
-        self.decoder.reset_decoder(batch_size=1)
-        # init state box, for new audio request
-        self.chunk_state_h_box = np.zeros(
-            (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
-            dtype=float32)
-        self.chunk_state_c_box = np.zeros(
-            (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
-            dtype=float32)
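+        # dispatch on model type: the deepspeech2 models keep RNN h/c state
+        # boxes, while the conformer/transformer models keep encoder caches
+        # and beam-search state (cleared in transformer_decode_reset)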
+        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+            self.decoder.reset_decoder(batch_size=1)
+            # init state box, for new audio request
+            self.chunk_state_h_box = np.zeros(
+                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
+                dtype=float32)
+            self.chunk_state_c_box = np.zeros(
+                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
+                dtype=float32)
+        elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
+            self.transformer_decode_reset()
 
     def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
         """decode one chunk
@@ -275,24 +308,137 @@ class ASRServerExecutor(ASRExecutor):
                 logger.info(
                     f"we will use the transformer like model : {self.model_type}"
                 )
-                cfg = self.config.decode
-                result_transcripts = self.model.decode(
-                    x_chunk,
-                    x_chunk_lens,
-                    text_feature=self.text_feature,
-                    decoding_method=cfg.decoding_method,
-                    beam_size=cfg.beam_size,
-                    ctc_weight=cfg.ctc_weight,
-                    decoding_chunk_size=cfg.decoding_chunk_size,
-                    num_decoding_left_chunks=cfg.num_decoding_left_chunks,
-                    simulate_streaming=cfg.simulate_streaming)
-                return result_transcripts[0][0]
+                self.advanced_decoding(x_chunk, x_chunk_lens)
+                self.update_result()
+
+                return self.result_transcripts[0]
             except Exception as e:
                 logger.exception(e)
         else:
             raise Exception("invalid model name")
 
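+    # The helpers below implement streaming decoding for the conformer- and
+    # transformer-style models: a chunk-by-chunk encoder forward that carries
+    # caches across calls, plus an incremental CTC prefix beam search over
+    # the newly produced frames.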
+    def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
+        logger.info("start to decode with advanced_decoding method")
+        encoder_out, encoder_mask = self.decode_forward(xs)
+        self.ctc_prefix_beam_search(xs, encoder_out, encoder_mask)
+
+    def decode_forward(self, xs):
+        logger.info("get the model out from the feat")
+        cfg = self.config.decode
+        decoding_chunk_size = cfg.decoding_chunk_size
+        num_decoding_left_chunks = cfg.num_decoding_left_chunks
+
+        assert decoding_chunk_size > 0
+        subsampling = self.model.encoder.embed.subsampling_rate
+        context = self.model.encoder.embed.right_context + 1
+        stride = subsampling * decoding_chunk_size
+
+        # decoding window for model
+        decoding_window = (decoding_chunk_size - 1) * subsampling + context
+        num_frames = xs.shape[1]
+        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
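+        # window arithmetic: each step consumes `decoding_window` input
+        # frames and advances by `stride` frames, so consecutive windows
+        # overlap by `context - subsampling` frames and every output frame
+        # sees its full receptive field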
logger . info ( " start to do model forward " )
outputs = [ ]
# num_frames - context + 1 ensure that current frame can get context window
for cur in range ( 0 , num_frames - context + 1 , stride ) :
end = min ( cur + decoding_window , num_frames )
chunk_xs = xs [ : , cur : end , : ]
( y , self . subsampling_cache , self . elayers_output_cache ,
self . conformer_cnn_cache ) = self . model . encoder . forward_chunk (
chunk_xs , self . offset , required_cache_size ,
self . subsampling_cache , self . elayers_output_cache ,
self . conformer_cnn_cache )
outputs . append ( y )
self . offset + = y . shape [ 1 ]
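+        # concatenate the per-chunk encoder outputs on the time axis
+        # (paddle.cat is the concat alias PaddleSpeech patches onto paddle)
+        # and mark every produced frame as valid in the mask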
+        ys = paddle.cat(outputs, 1)
+        masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
+        masks = masks.unsqueeze(1)
+        return ys, masks
+
+    def transformer_decode_reset(self):
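+        # drop all streaming state (encoder caches, frame offset and beam
+        # hypotheses) so the next utterance starts from a clean slate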
+        self.subsampling_cache = None
+        self.elayers_output_cache = None
+        self.conformer_cnn_cache = None
+        self.offset = 0
+        self.cur_hyps = None
+        self.hyps = None
+
+    def ctc_prefix_beam_search(self, xs, encoder_out, encoder_mask,
+                               blank_id=0):
+        # decode
+        logger.info("start to ctc prefix search")
+
+        device = xs.place
+        cfg = self.config.decode
+        batch_size = xs.shape[0]
+        beam_size = cfg.beam_size
+        maxlen = encoder_out.shape[1]
+
+        ctc_probs = self.model.ctc.log_softmax(encoder_out)  # (1, maxlen, vocab_size)
+        ctc_probs = ctc_probs.squeeze(0)
+
+        # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
+        # blank_ending_score and none_blank_ending_score in ln domain
+        if self.cur_hyps is None:
+            self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
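+        # the empty prefix ends in blank with probability 1 (ln 1 = 0.0) and
+        # cannot end in a non-blank token (ln 0 = -inf); cur_hyps persists
+        # across chunks, which is what makes this search incremental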
+
+        # 2. CTC beam search step by step
+        for t in range(0, maxlen):
+            logp = ctc_probs[t]  # (vocab_size,)
+            # key: prefix, value (pb, pnb), default value(-inf, -inf)
+            next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
+
+            # 2.1 First beam prune: select topk best
+            # do token passing process
+            top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
+            for s in top_k_index:
+                s = s.item()
+                ps = logp[s].item()
+                for prefix, (pb, pnb) in self.cur_hyps:
+                    last = prefix[-1] if len(prefix) > 0 else None
+                    if s == blank_id:  # blank
+                        n_pb, n_pnb = next_hyps[prefix]
+                        n_pb = log_add([n_pb, pb + ps, pnb + ps])
+                        next_hyps[prefix] = (n_pb, n_pnb)
+                    elif s == last:
+                        # Update *ss -> *s;
+                        n_pb, n_pnb = next_hyps[prefix]
+                        n_pnb = log_add([n_pnb, pnb + ps])
+                        next_hyps[prefix] = (n_pb, n_pnb)
+                        # Update *s-s -> *ss, - is for blank
+                        n_prefix = prefix + (s, )
+                        n_pb, n_pnb = next_hyps[n_prefix]
+                        n_pnb = log_add([n_pnb, pb + ps])
+                        next_hyps[n_prefix] = (n_pb, n_pnb)
+                    else:
+                        n_prefix = prefix + (s, )
+                        n_pb, n_pnb = next_hyps[n_prefix]
+                        n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
+                        next_hyps[n_prefix] = (n_pb, n_pnb)
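+            # the three branches above are the standard CTC prefix beam
+            # search recursion: a blank extends the blank-ending score of the
+            # same prefix; a repeated token either collapses into the prefix
+            # or, when separated by a blank, starts a new occurrence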
+
+            # 2.2 Second beam prune
+            next_hyps = sorted(
+                next_hyps.items(),
+                key=lambda x: log_add(list(x[1])),
+                reverse=True)
+            self.cur_hyps = next_hyps[:beam_size]
+
+        hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
+        self.hyps = [hyps[0][0]]
+        logger.info("ctc prefix search success")
+        return hyps, encoder_out
+
+    def update_result(self):
+        logger.info("update the final result")
+        self.result_transcripts = [
+            self.text_feature.defeaturize(hyp) for hyp in self.hyps
+        ]
+        self.result_tokenids = [hyp for hyp in self.hyps]
 
     def extract_feat(self, samples, sample_rate):
         """extract feat
@@ -304,9 +450,10 @@ class ASRServerExecutor(ASRExecutor):
             x_chunk (numpy.array): shape[B, T, D]
             x_chunk_lens (numpy.array): shape[B]
         """
-        # pcm16 -> pcm 32
-        samples = pcm2float(samples)
         if "deepspeech2online" in self.model_type:
+            # pcm16 -> pcm 32
+            samples = pcm2float(samples)
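+            # NOTE: the int16 -> float32 conversion now happens only on the
+            # deepspeech2 path; the conformer branch (not shown in this hunk)
+            # is assumed to handle the raw samples itself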
             # read audio
             speech_segment = SpeechSegment.from_pcm(
                 samples, sample_rate, transcript="")