# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional
import numpy as np
import paddle
from numpy import float32
from yacs.config import CfgNode

from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.asr.infer import model_alias
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
from paddlespeech.s2t.utils.tensor_utils import pad_sequence
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import pcm2float
from paddlespeech.server.utils.paddle_predictor import init_predictor

__all__ = ['ASREngine']
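
# This module implements the online (streaming) ASR server engine:
#   - PaddleASRConnectionHanddler: per-websocket-connection state and
#     chunk-by-chunk decoding for the streaming conformer/transformer models.
#   - ASRServerExecutor: loads pretrained resources, extracts features and
#     decodes chunks for the server.
#   - ASREngine: the server-facing wrapper (init / preprocess / run /
#     postprocess / reset).
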
pretrained_models = {
    "deepspeech2online_aishell-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz',
        'md5':
        '23e16c69730a1cb5d735c98c83c21e16',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2_online/checkpoints/avg_1',
        'model':
        'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
        'params':
        'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
    "conformer2online_aishell-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.1.model.tar.gz',
        'md5':
        'b450d5dfaea0ac227c595ce58d18b637',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/chunk_conformer/checkpoints/multi_cn',
        'model':
        'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
        'params':
        'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
}
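
# NOTE: The keys above follow the tag scheme "{model_name}_{dataset}-{lang}-{sr}".
# _init_from_path() builds the tag as model_type + '-' + lang + '-' + sample_rate_str,
# and later strips the trailing "_{dataset}" part (rindex('_')) to recover the model
# name passed to dynamic_import.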


# ASR server connection process class
class PaddleASRConnectionHanddler:
    def __init__(self, asr_engine):
        super().__init__()
        logger.info(
            "create a paddle asr connection handler to process the websocket connection"
        )
        self.config = asr_engine.config
        self.model_config = asr_engine.executor.config
        self.model = asr_engine.executor.model
        self.asr_engine = asr_engine

        self.init()
        self.reset()

    def init(self):
        self.model_type = self.asr_engine.executor.model_type
        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            pass
        elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
            self.sample_rate = self.asr_engine.executor.sample_rate

            # acoustic model
            self.model = self.asr_engine.executor.model

            # tokens to text
            self.text_feature = self.asr_engine.executor.text_feature

            # ctc decoding
            self.ctc_decode_config = self.asr_engine.executor.config.decode
            self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)

            # extract fbank
            self.preprocess_conf = self.model_config.preprocess_config
            self.preprocess_args = {"train": False}
            self.preprocessing = Transformation(self.preprocess_conf)
            self.win_length = self.preprocess_conf.process[0]['win_length']
            self.n_shift = self.preprocess_conf.process[0]['n_shift']
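
    # win_length and n_shift come from the fbank preprocess config (they appear
    # to be expressed in audio samples, given how they are used below).
    # extract_feat() buffers incoming pcm until at least win_length samples are
    # available, runs the fbank transformation, appends the new frames to
    # cached_feat, and then drops the n_shift * num_frames samples that those
    # frames fully consumed.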

    def extract_feat(self, samples):
        if "deepspeech2online" in self.model_type:
            pass
        elif "conformer2online" in self.model_type:
            logger.info("Online ASR extract the feat")
            samples = np.frombuffer(samples, dtype=np.int16)
            assert samples.ndim == 1

            logger.info(f"This package receives {samples.shape[0]} pcm samples")
            self.num_samples += samples.shape[0]

            # self.remained_wav stores all the samples,
            # including the original remained_wav and this package's samples
            if self.remained_wav is None:
                self.remained_wav = samples
            else:
                assert self.remained_wav.ndim == 1
                self.remained_wav = np.concatenate([self.remained_wav, samples])
            logger.info(
                f"The connection remains the audio samples: {self.remained_wav.shape}"
            )

            # not enough samples for even one frame, wait for more pcm data
            if len(self.remained_wav) < self.win_length:
                return 0

            # fbank
            x_chunk = self.preprocessing(self.remained_wav,
                                         **self.preprocess_args)
            x_chunk = paddle.to_tensor(
                x_chunk, dtype="float32").unsqueeze(axis=0)
            if self.cached_feat is None:
                self.cached_feat = x_chunk
            else:
                assert (len(x_chunk.shape) == 3)
                assert (len(self.cached_feat.shape) == 3)
                self.cached_feat = paddle.concat(
                    [self.cached_feat, x_chunk], axis=1)

            # set the feat device
            if self.device is None:
                self.device = self.cached_feat.place

            num_frames = x_chunk.shape[1]
            self.num_frames += num_frames
            # drop the samples that have been fully consumed by the new frames
            self.remained_wav = self.remained_wav[self.n_shift * num_frames:]

            logger.info(
                f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
            )
            logger.info(
                f"After extract feat, the connection remains the audio samples: {self.remained_wav.shape}"
            )
            # logger.info(f"accumulate samples: {self.num_samples}")
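
    # Per-connection streaming state cleared by reset():
    #   - subsampling_cache / elayers_output_cache / conformer_cnn_cache:
    #     encoder caches carried between forward_chunk() calls.
    #   - encoder_out: concatenated encoder output kept for attention rescoring.
    #   - cached_feat: fbank frames not yet consumed by the decoder.
    #   - remained_wav: raw pcm samples not yet converted to fbank frames.
    #   - offset / num_samples / num_frames / chunk_num: running counters.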

    def reset(self):
        self.subsampling_cache = None
        self.elayers_output_cache = None
        self.conformer_cnn_cache = None
        self.encoder_out = None
        self.cached_feat = None
        self.remained_wav = None
        self.offset = 0
        self.num_samples = 0
        self.device = None
        self.hyps = []
        self.num_frames = 0
        self.chunk_num = 0
        self.global_frame_offset = 0
        self.result_transcripts = ['']

    def decode(self, is_finished=False):
        if "deepspeech2online" in self.model_type:
            pass
        elif "conformer" in self.model_type or "transformer" in self.model_type:
            try:
                logger.info(
                    f"we will use the transformer like model : {self.model_type}"
                )
                self.advance_decoding(is_finished)
                self.update_result()
            except Exception as e:
                logger.exception(e)
        else:
            raise Exception("invalid model name")

    def advance_decoding(self, is_finished=False):
        logger.info("start to decode with advance_decoding method")
        cfg = self.ctc_decode_config
        decoding_chunk_size = cfg.decoding_chunk_size
        num_decoding_left_chunks = cfg.num_decoding_left_chunks
        assert decoding_chunk_size > 0
        subsampling = self.model.encoder.embed.subsampling_rate
        context = self.model.encoder.embed.right_context + 1
        stride = subsampling * decoding_chunk_size

        # processed chunk feature cached for next chunk
        cached_feature_num = context - subsampling

        # decoding window for model
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
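        # A worked example of the chunk geometry (numbers are illustrative and
        # assume a conformer with conv2d subsampling: subsampling_rate=4,
        # right_context=6, decoding_chunk_size=16):
        #   context            = 6 + 1  = 7  input frames per output frame
        #   stride             = 4 * 16 = 64 input frames consumed per chunk
        #   decoding_window    = (16 - 1) * 4 + 7 = 67 input frames per chunk
        #   cached_feature_num = 7 - 4  = 3  frames kept as left context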
        if self.cached_feat is None:
            logger.info("no audio feat, please input more pcm data")
            return

        num_frames = self.cached_feat.shape[1]
        logger.info(
            f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
        )

        # the cached feat must be at least as large as the decoding window
        if num_frames < decoding_window and not is_finished:
            logger.info(
                f"frame feat num is less than {decoding_window}, please input more pcm data"
            )
            return None, None

        # even for the finished chunk we need at least `context` frames
        if num_frames < context:
            logger.info(
                f"last {num_frames} frames is less than context {context} frames, and we cannot do model forward"
            )
            return None, None

        logger.info("start to do model forward")
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
        outputs = []

        # num_frames - context + 1 ensures that the current frame can get a full context window
        if is_finished:
            # if we get the finished chunk, we need to process the last context
            left_frames = context
        else:
            # we only process decoding_window frames for one chunk
            left_frames = decoding_window

        # record the end for removing the processed feat
        end = None
        for cur in range(0, num_frames - left_frames + 1, stride):
            end = min(cur + decoding_window, num_frames)
            self.chunk_num += 1
            chunk_xs = self.cached_feat[:, cur:end, :]
            (y, self.subsampling_cache, self.elayers_output_cache,
             self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
                 chunk_xs, self.offset, required_cache_size,
                 self.subsampling_cache, self.elayers_output_cache,
                 self.conformer_cnn_cache)
            outputs.append(y)

            # update the offset
            self.offset += y.shape[1]

        logger.info(f"output size: {len(outputs)}")
        ys = paddle.concat(outputs, axis=1)
        if self.encoder_out is None:
            self.encoder_out = ys
        else:
            self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)

        # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
        # masks = masks.unsqueeze(1)

        # get the ctc probs
        ctc_probs = self.model.ctc.log_softmax(ys)  # (1, maxlen, vocab_size)
        ctc_probs = ctc_probs.squeeze(0)

        self.searcher.search(None, ctc_probs, self.cached_feat.place)
        self.hyps = self.searcher.get_one_best_hyps()

        # remove the processed feat, keeping cached_feature_num frames of
        # left context for the next chunk
        if end == num_frames:
            self.cached_feat = None
        else:
            assert self.cached_feat.shape[0] == 1
            assert end >= cached_feature_num
            self.cached_feat = self.cached_feat[
                0, end - cached_feature_num:, :].unsqueeze(0)
            assert len(
                self.cached_feat.shape
            ) == 3, f"current cache feat shape is: {self.cached_feat.shape}"

        # ys for rescoring
        # return ys, masks

    def update_result(self):
        logger.info("update the final result")
        hyps = self.hyps
        self.result_transcripts = [
            self.text_feature.defeaturize(hyp) for hyp in hyps
        ]
        self.result_tokenids = [hyp for hyp in hyps]

    def get_result(self):
        if len(self.result_transcripts) > 0:
            return self.result_transcripts[0]
        else:
            return ''

    def rescoring(self):
        logger.info("rescoring the final result")
        if "attention_rescoring" != self.ctc_decode_config.decoding_method:
            return

        self.searcher.finalize_search()
        self.update_result()

        beam_size = self.ctc_decode_config.beam_size
        hyps = self.searcher.get_hyps()
        if hyps is None or len(hyps) == 0:
            return

        # assert len(hyps) == beam_size
        # dump the encoder output for debugging
        paddle.save(self.encoder_out, "encoder.out")

        hyp_list = []
        for hyp in hyps:
            hyp_content = hyp[0]
            # prevent the hyp from being empty
            if len(hyp_content) == 0:
                hyp_content = (self.model.ctc.blank_id, )
            hyp_content = paddle.to_tensor(
                hyp_content, place=self.device, dtype=paddle.long)
            hyp_list.append(hyp_content)
        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
        hyps_lens = paddle.to_tensor(
            [len(hyp[0]) for hyp in hyps], place=self.device,
            dtype=paddle.long)  # (beam_size,)
        hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
                                  self.model.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at beginning

        encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
        encoder_mask = paddle.ones(
            (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
        decoder_out, _ = self.model.decoder(
            encoder_out, encoder_mask, hyps_pad,
            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
        # ctc score in ln domain
        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
        decoder_out = decoder_out.numpy()

        # score each hypothesis with the attention decoder output
        best_score = -float('inf')
        best_index = 0
        # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
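        # Combined score for each beam entry (everything in ln domain):
        #   score(hyp) = sum_j decoder_out[i][j][w_j]      # attention decoder
        #              + decoder_out[i][len(hyp)][eos]     # end-of-sentence term
        #              + ctc_weight * hyp_ctc_score        # from the prefix beam search
        # The hypothesis with the highest combined score becomes the one-best result.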
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp[0]):
                score += decoder_out[i][j][w]
            # the last decoder output token is `eos`, for the last decoder input token
            score += decoder_out[i][len(hyp[0])][self.model.eos]
            # add ctc score (which is in ln domain)
            score += hyp[1] * self.ctc_decode_config.ctc_weight
            if score > best_score:
                best_score = score
                best_index = i

        # update the one best result
        logger.info(f"best index: {best_index}")
        self.hyps = [hyps[best_index][0]]
        self.update_result()
        # return hyps[best_index][0]


class ASRServerExecutor(ASRExecutor):
    def __init__(self):
        super().__init__()

    def _get_pretrained_path(self, tag: str) -> os.PathLike:
        """
        Download and returns pretrained resources path of current task.
        """
        support_models = list(pretrained_models.keys())
        assert tag in pretrained_models, \
            'The model "{}" you want to use has not been supported, ' \
            'please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
                tag, '\n\t\t'.join(support_models))

        res_path = os.path.join(MODEL_HOME, tag)
        decompressed_path = download_and_decompress(pretrained_models[tag],
                                                    res_path)
        decompressed_path = os.path.abspath(decompressed_path)
        logger.info(
            'Use pretrained model stored in: {}'.format(decompressed_path))
        return decompressed_path

    def _init_from_path(self,
                        model_type: str='wenetspeech',
                        am_model: Optional[os.PathLike]=None,
                        am_params: Optional[os.PathLike]=None,
                        lang: str='zh',
                        sample_rate: int=16000,
                        cfg_path: Optional[os.PathLike]=None,
                        decode_method: str='attention_rescoring',
                        am_predictor_conf: dict=None):
        """
        Init model and other resources from a specific path.
        """
        self.model_type = model_type
        self.sample_rate = sample_rate
        if cfg_path is None or am_model is None or am_params is None:
            sample_rate_str = '16k' if sample_rate == 16000 else '8k'
            tag = model_type + '-' + lang + '-' + sample_rate_str
            logger.info(f"Load the pretrained model, tag = {tag}")
            res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
            self.res_path = res_path

            self.cfg_path = os.path.join(res_path,
                                         pretrained_models[tag]['cfg_path'])
            self.am_model = os.path.join(res_path,
                                         pretrained_models[tag]['model'])
            self.am_params = os.path.join(res_path,
                                          pretrained_models[tag]['params'])
            logger.info(res_path)
        else:
            self.cfg_path = os.path.abspath(cfg_path)
            self.am_model = os.path.abspath(am_model)
            self.am_params = os.path.abspath(am_params)
            self.res_path = os.path.dirname(
                os.path.dirname(os.path.abspath(self.cfg_path)))
        logger.info(self.cfg_path)
        logger.info(self.am_model)
        logger.info(self.am_params)

        # Init body.
        self.config = CfgNode(new_allowed=True)
        self.config.merge_from_file(self.cfg_path)
        with UpdateConfig(self.config):
            if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
                from paddlespeech.s2t.io.collator import SpeechCollator
                self.vocab = self.config.vocab_filepath
                self.config.decode.lang_model_path = os.path.join(
                    MODEL_HOME, 'language_model',
                    self.config.decode.lang_model_path)
                self.collate_fn_test = SpeechCollator.from_config(self.config)
                self.text_feature = TextFeaturizer(
                    unit_type=self.config.unit_type, vocab=self.vocab)

                lm_url = pretrained_models[tag]['lm_url']
                lm_md5 = pretrained_models[tag]['lm_md5']
                logger.info(f"Start to load language model {lm_url}")
                self.download_lm(
                    lm_url,
                    os.path.dirname(self.config.decode.lang_model_path), lm_md5)
            elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
                logger.info("start to create the stream conformer asr engine")
                if self.config.spm_model_prefix:
                    self.config.spm_model_prefix = os.path.join(
                        self.res_path, self.config.spm_model_prefix)
                self.vocab = self.config.vocab_filepath
                self.text_feature = TextFeaturizer(
                    unit_type=self.config.unit_type,
                    vocab=self.config.vocab_filepath,
                    spm_model_prefix=self.config.spm_model_prefix)
                # update the decoding method
                if decode_method:
                    self.config.decode.decoding_method = decode_method

                # we only support ctc_prefix_beam_search and attention_rescoring decoding methods
                # Generally we set the decoding_method to attention_rescoring
                if self.config.decode.decoding_method not in [
                        "ctc_prefix_beam_search", "attention_rescoring"
                ]:
                    logger.info(
                        "we set the decoding_method to attention_rescoring")
                    self.config.decode.decoding_method = "attention_rescoring"
                assert self.config.decode.decoding_method in [
                    "ctc_prefix_beam_search", "attention_rescoring"
                ], f"we only support ctc_prefix_beam_search and attention_rescoring decoding methods, current decoding method is {self.config.decode.decoding_method}"
            else:
                raise Exception("wrong type")

        if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
            # AM predictor
            logger.info("ASR engine start to init the am predictor")
            self.am_predictor_conf = am_predictor_conf
            self.am_predictor = init_predictor(
                model_file=self.am_model,
                params_file=self.am_params,
                predictor_conf=self.am_predictor_conf)

            # decoder
            logger.info("ASR engine start to create the ctc decoder instance")
            self.decoder = CTCDecoder(
                odim=self.config.output_dim,  # <blank> is in vocab
                enc_n_units=self.config.rnn_layer_size * 2,
                blank_id=self.config.blank_id,
                dropout_rate=0.0,
                reduction=True,  # sum
                batch_average=True,  # sum / batch_size
                grad_norm_type=self.config.get('ctc_grad_norm_type', None))

            # init decoder
            logger.info("ASR engine start to init the ctc decoder")
            cfg = self.config.decode
            decode_batch_size = 1  # for online
            self.decoder.init_decoder(
                decode_batch_size, self.text_feature.vocab_list,
                cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
                cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                cfg.num_proc_bsearch)
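
            # The h/c "state boxes" below carry the recurrent (hidden/cell)
            # state of the exported DeepSpeech2 online model across chunks,
            # shaped (num_rnn_layers, batch=1, rnn_layer_size); decode_one_chunk()
            # feeds them into the predictor and stores the updated states back.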
            # init state box
            self.chunk_state_h_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
            self.chunk_state_c_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
        elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
            # model_type: {model_name}_{dataset}
            model_name = model_type[:model_type.rindex('_')]
            logger.info(f"model name: {model_name}")
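            # model_alias (imported from paddlespeech.cli.asr.infer) maps the
            # short model name to the dotted import path of its model class;
            # dynamic_import resolves that path and returns the class.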
            model_class = dynamic_import(model_name, model_alias)
            model_conf = self.config
            model = model_class.from_config(model_conf)
            self.model = model
            self.model.eval()

            # load model
            model_dict = paddle.load(self.am_model)
            self.model.set_state_dict(model_dict)
            logger.info("create the transformer like model success")

            # update the ctc decoding
            self.searcher = CTCPrefixBeamSearch(self.config.decode)
            self.transformer_decode_reset()

    def reset_decoder_and_chunk(self):
        """reset decoder and chunk state for a new audio"""
        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            self.decoder.reset_decoder(batch_size=1)
            # init state box, for new audio request
            self.chunk_state_h_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
            self.chunk_state_c_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)
        elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
            self.transformer_decode_reset()

    def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
        """decode one chunk

        Args:
            x_chunk (numpy.array): shape [B, T, D]
            x_chunk_lens (numpy.array): shape [B]
            model_type (str): online model type

        Returns:
            str: one best transcription result
        """
        logger.info("start to decode chunk by chunk")
        if "deepspeech2online" in model_type:
            input_names = self.am_predictor.get_input_names()
            audio_handle = self.am_predictor.get_input_handle(input_names[0])
            audio_len_handle = self.am_predictor.get_input_handle(
                input_names[1])
            h_box_handle = self.am_predictor.get_input_handle(input_names[2])
            c_box_handle = self.am_predictor.get_input_handle(input_names[3])

            audio_handle.reshape(x_chunk.shape)
            audio_handle.copy_from_cpu(x_chunk)

            audio_len_handle.reshape(x_chunk_lens.shape)
            audio_len_handle.copy_from_cpu(x_chunk_lens)

            h_box_handle.reshape(self.chunk_state_h_box.shape)
            h_box_handle.copy_from_cpu(self.chunk_state_h_box)

            c_box_handle.reshape(self.chunk_state_c_box.shape)
            c_box_handle.copy_from_cpu(self.chunk_state_c_box)

            output_names = self.am_predictor.get_output_names()
            output_handle = self.am_predictor.get_output_handle(
                output_names[0])
            output_lens_handle = self.am_predictor.get_output_handle(
                output_names[1])
            output_state_h_handle = self.am_predictor.get_output_handle(
                output_names[2])
            output_state_c_handle = self.am_predictor.get_output_handle(
                output_names[3])

            self.am_predictor.run()

            output_chunk_probs = output_handle.copy_to_cpu()
            output_chunk_lens = output_lens_handle.copy_to_cpu()
            # carry the recurrent state over to the next chunk
            self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
            self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()

            self.decoder.next(output_chunk_probs, output_chunk_lens)
            trans_best, trans_beam = self.decoder.decode()
            logger.info(f"decode one best result: {trans_best[0]}")
            return trans_best[0]
        elif "conformer" in model_type or "transformer" in model_type:
            try:
                logger.info(
                    f"we will use the transformer like model : {self.model_type}"
                )
                self.advanced_decoding(x_chunk, x_chunk_lens)
                self.update_result()

                return self.result_transcripts[0]
            except Exception as e:
                logger.exception(e)
        else:
            raise Exception("invalid model name")

    def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
        logger.info("start to decode with advanced_decoding method")
        encoder_out, encoder_mask = self.decode_forward(xs)
        ctc_probs = self.model.ctc.log_softmax(
            encoder_out)  # (1, maxlen, vocab_size)
        ctc_probs = ctc_probs.squeeze(0)
        self.searcher.search(xs, ctc_probs, xs.place)
        # update the one best result
        self.hyps = self.searcher.get_one_best_hyps()

        # now we support ctc_prefix_beam_search and attention_rescoring
        if "attention_rescoring" in self.config.decode.decoding_method:
            self.rescoring(encoder_out, xs.place)
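
    # decode_forward() runs the streaming encoder over the whole feature tensor
    # chunk by chunk, carrying the subsampling / encoder-layer / cnn caches
    # between forward_chunk() calls, and returns the concatenated encoder
    # output together with an all-ones mask over its frames.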

    def decode_forward(self, xs):
        logger.info("get the model out from the feat")
        cfg = self.config.decode
        decoding_chunk_size = cfg.decoding_chunk_size
        num_decoding_left_chunks = cfg.num_decoding_left_chunks

        assert decoding_chunk_size > 0
        subsampling = self.model.encoder.embed.subsampling_rate
        context = self.model.encoder.embed.right_context + 1
        stride = subsampling * decoding_chunk_size

        # decoding window for model
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.shape[1]
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks

        logger.info("start to do model forward")
        outputs = []

        # num_frames - context + 1 ensures that the current frame can get a full context window
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, self.subsampling_cache, self.elayers_output_cache,
             self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
                 chunk_xs, self.offset, required_cache_size,
                 self.subsampling_cache, self.elayers_output_cache,
                 self.conformer_cnn_cache)
            outputs.append(y)
            self.offset += y.shape[1]
        ys = paddle.concat(outputs, axis=1)
        masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
        masks = masks.unsqueeze(1)
        return ys, masks

    def rescoring(self, encoder_out, device):
        logger.info("start to rescoring the hyps")
        beam_size = self.config.decode.beam_size
        hyps = self.searcher.get_hyps()
        assert len(hyps) == beam_size

        hyp_list = []
        for hyp in hyps:
            hyp_content = hyp[0]
            # prevent the hyp from being empty
            if len(hyp_content) == 0:
                hyp_content = (self.model.ctc.blank_id, )
            hyp_content = paddle.to_tensor(
                hyp_content, place=device, dtype=paddle.long)
            hyp_list.append(hyp_content)
        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
        hyps_lens = paddle.to_tensor(
            [len(hyp[0]) for hyp in hyps], place=device,
            dtype=paddle.long)  # (beam_size,)
        hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
                                  self.model.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at beginning

        encoder_out = encoder_out.repeat(beam_size, 1, 1)
        encoder_mask = paddle.ones(
            (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
        decoder_out, _ = self.model.decoder(
            encoder_out, encoder_mask, hyps_pad,
            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
        # ctc score in ln domain
        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
        decoder_out = decoder_out.numpy()

        # score each hypothesis with the attention decoder output
        best_score = -float('inf')
        best_index = 0
        # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp[0]):
                score += decoder_out[i][j][w]
            # the last decoder output token is `eos`, for the last decoder input token
            score += decoder_out[i][len(hyp[0])][self.model.eos]
            # add ctc score (which is in ln domain)
            score += hyp[1] * self.config.decode.ctc_weight
            if score > best_score:
                best_score = score
                best_index = i

        # update the one best result
        self.hyps = [hyps[best_index][0]]
        return hyps[best_index][0]

    def transformer_decode_reset(self):
        self.subsampling_cache = None
        self.elayers_output_cache = None
        self.conformer_cnn_cache = None
        self.offset = 0
        # decoding reset
        self.searcher.reset()

    def update_result(self):
        logger.info("update the final result")
        hyps = self.hyps
        self.result_transcripts = [
            self.text_feature.defeaturize(hyp) for hyp in hyps
        ]
        self.result_tokenids = [hyp for hyp in hyps]

    def extract_feat(self, samples, sample_rate):
        """extract feat

        Args:
            samples (numpy.array): numpy.float32
            sample_rate (int): sample rate

        Returns:
            x_chunk (numpy.array): shape [B, T, D]
            x_chunk_lens (numpy.array): shape [B]
        """
        if "deepspeech2online" in self.model_type:
            # pcm16 -> pcm 32
            samples = pcm2float(samples)
            # read audio
            speech_segment = SpeechSegment.from_pcm(
                samples, sample_rate, transcript=" ")
            # audio augment
            self.collate_fn_test.augmentation.transform_audio(speech_segment)

            # extract speech feature
            spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
                speech_segment, self.collate_fn_test.keep_transcription_text)
            # CMVN spectrum
            if self.collate_fn_test._normalizer:
                spectrum = self.collate_fn_test._normalizer.apply(spectrum)

            # spectrum augment
            audio = self.collate_fn_test.augmentation.transform_feature(
                spectrum)

            audio_len = audio.shape[0]
            audio = paddle.to_tensor(audio, dtype='float32')
            # audio_len = paddle.to_tensor(audio_len)
            audio = paddle.unsqueeze(audio, axis=0)

            x_chunk = audio.numpy()
            x_chunk_lens = np.array([audio_len])

            return x_chunk, x_chunk_lens
        elif "conformer2online" in self.model_type:
            if sample_rate != self.sample_rate:
                logger.info(f"audio sample rate {sample_rate} does not match, "
                            f"the model sample_rate is {self.sample_rate}")
            logger.info(f"ASR Engine use the {self.model_type} to process")
            logger.info("Create the preprocess instance")
            preprocess_conf = self.config.preprocess_config
            preprocess_args = {"train": False}
            preprocessing = Transformation(preprocess_conf)

            logger.info("Read the audio file")
            logger.info(f"audio shape: {samples.shape}")
            # fbank
            x_chunk = preprocessing(samples, **preprocess_args)
            x_chunk_lens = paddle.to_tensor(x_chunk.shape[0])
            x_chunk = paddle.to_tensor(
                x_chunk, dtype="float32").unsqueeze(axis=0)
            logger.info(
                f"process the audio feature success, feat shape: {x_chunk.shape}"
            )
            return x_chunk, x_chunk_lens


class ASREngine(BaseEngine):
    """ASR server engine

    Args:
        metaclass: Defaults to Singleton.
    """

    def __init__(self):
        super(ASREngine, self).__init__()
        logger.info("create the online asr engine instance")

    def init(self, config: dict) -> bool:
        """init engine resource

        Args:
            config (dict): engine config

        Returns:
            bool: init failed or success
        """
        self.input = None
        self.output = ""
        self.executor = ASRServerExecutor()
        self.config = config

        self.executor._init_from_path(
            model_type=self.config.model_type,
            am_model=self.config.am_model,
            am_params=self.config.am_params,
            lang=self.config.lang,
            sample_rate=self.config.sample_rate,
            cfg_path=self.config.cfg_path,
            decode_method=self.config.decode_method,
            am_predictor_conf=self.config.am_predictor_conf)

        logger.info("Initialize ASR server engine successfully.")
        return True
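
    # Typical per-request flow (deepspeech2online path): preprocess() extracts
    # the feature chunk, run() decodes it while carrying the RNN state between
    # chunks, postprocess() returns the current transcript, and reset() clears
    # the decoder and chunk state before the next utterance.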

    def preprocess(self,
                   samples,
                   sample_rate,
                   model_type="deepspeech2online_aishell-zh-16k"):
        """preprocess

        Args:
            samples (numpy.array): numpy.float32
            sample_rate (int): sample rate

        Returns:
            x_chunk (numpy.array): shape [B, T, D]
            x_chunk_lens (numpy.array): shape [B]
        """
        # if "deepspeech" in model_type:
        x_chunk, x_chunk_lens = self.executor.extract_feat(samples,
                                                           sample_rate)
        return x_chunk, x_chunk_lens

    def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
        """run online engine

        Args:
            x_chunk (numpy.array): shape [B, T, D]
            x_chunk_lens (numpy.array): shape [B]
            decoder_chunk_size (int)
        """
        self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens,
                                                     self.config.model_type)

    def postprocess(self):
        """postprocess"""
        return self.output

    def reset(self):
        """reset engine decoder and inference state"""
        self.executor.reset_decoder_and_chunk()
        self.output = ""