fix the code format, test=doc

pull/1704/head
xiongxinlei 2 years ago
parent 380afbbc5d
commit f56dba0ca7

@@ -129,7 +129,7 @@ model_alias = {
     "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
     "conformer":
     "paddlespeech.s2t.models.u2:U2Model",
-    "conformer2online":
+    "conformer_online":
     "paddlespeech.s2t.models.u2:U2Model",
     "transformer":
     "paddlespeech.s2t.models.u2:U2Model",

@@ -21,7 +21,7 @@ engine_list: ['asr_online']
 ################################### ASR #########################################
 ################### speech task: asr; engine_type: online #######################
 asr_online:
-    model_type: 'conformer2online_aishell'
+    model_type: 'conformer_online_multi-cn'
     am_model: # the pdmodel file of am static model [optional]
     am_params: # the pdiparams file of am static model [optional]
     lang: 'zh'
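The server reads this asr_online section from the application YAML at startup. As an illustrative sketch only (the file path and plain-yaml loader here are assumptions; the real server loader may use a yacs CfgNode instead):

# Illustrative sketch: load the application config and read the streaming
# ASR section. Path and loader are placeholders.
import yaml

with open("application.yaml") as f:    # hypothetical config path
    config = yaml.safe_load(f)

asr_conf = config["asr_online"]
print(asr_conf["model_type"])   # 'conformer_online_multi-cn'
print(asr_conf["lang"])         # 'zh'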

@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import os
 from typing import Optional
-import copy
+
 import numpy as np
 import paddle
 from numpy import float32
@@ -58,7 +59,7 @@ pretrained_models = {
         'lm_md5':
         '29e02312deb2e59b3c8686c7966d4fe3'
     },
-    "conformer2online_aishell-zh-16k": {
+    "conformer_online_multi-cn-zh-16k": {
         'url':
         'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz',
         'md5':
@@ -93,19 +94,22 @@ class PaddleASRConnectionHanddler:
         )
         self.config = asr_engine.config
         self.model_config = asr_engine.executor.config
-        # self.model = asr_engine.executor.model
         self.asr_engine = asr_engine
         self.init()
         self.reset()

     def init(self):
+        # model_type, sample_rate and text_feature is shared for deepspeech2 and conformer
         self.model_type = self.asr_engine.executor.model_type
+        self.sample_rate = self.asr_engine.executor.sample_rate
+        # tokens to text
+        self.text_feature = self.asr_engine.executor.text_feature
+
         if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
             from paddlespeech.s2t.io.collator import SpeechCollator
-            self.sample_rate = self.asr_engine.executor.sample_rate
             self.am_predictor = self.asr_engine.executor.am_predictor
-            self.text_feature = self.asr_engine.executor.text_feature
             self.collate_fn_test = SpeechCollator.from_config(self.model_config)
             self.decoder = CTCDecoder(
                 odim=self.model_config.output_dim,  # <blank> is in vocab
@@ -114,7 +118,8 @@ class PaddleASRConnectionHanddler:
                 dropout_rate=0.0,
                 reduction=True,  # sum
                 batch_average=True,  # sum / batch_size
-                grad_norm_type=self.model_config.get('ctc_grad_norm_type', None))
+                grad_norm_type=self.model_config.get('ctc_grad_norm_type',
+                                                     None))

             cfg = self.model_config.decode
             decode_batch_size = 1  # for online
@@ -123,20 +128,16 @@ class PaddleASRConnectionHanddler:
                 cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
                 cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                 cfg.num_proc_bsearch)

             # frame window samples length and frame shift samples length
-            self.win_length = int(self.model_config.window_ms * self.sample_rate)
-            self.n_shift = int(self.model_config.stride_ms * self.sample_rate)
-        elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
-            self.sample_rate = self.asr_engine.executor.sample_rate
+            self.win_length = int(self.model_config.window_ms *
+                                  self.sample_rate)
+            self.n_shift = int(self.model_config.stride_ms * self.sample_rate)
+        elif "conformer" in self.model_type or "transformer" in self.model_type:
             # acoustic model
             self.model = self.asr_engine.executor.model
-            # tokens to text
-            self.text_feature = self.asr_engine.executor.text_feature
             # ctc decoding config
             self.ctc_decode_config = self.asr_engine.executor.config.decode
             self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)
@@ -189,7 +190,7 @@ class PaddleASRConnectionHanddler:
         audio = paddle.to_tensor(audio, dtype='float32')
         # audio_len = paddle.to_tensor(audio_len)
         audio = paddle.unsqueeze(audio, axis=0)

         if self.cached_feat is None:
             self.cached_feat = audio
         else:
@@ -211,7 +212,7 @@ class PaddleASRConnectionHanddler:
             logger.info(
                 f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
             )
-        elif "conformer2online" in self.model_type:
+        elif "conformer_online" in self.model_type:
             logger.info("Online ASR extract the feat")
             samples = np.frombuffer(samples, dtype=np.int16)
             assert samples.ndim == 1
@@ -264,41 +265,43 @@ class PaddleASRConnectionHanddler:
     def reset(self):
         if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
             # for deepspeech2
-            self.chunk_state_h_box = copy.deepcopy(self.asr_engine.executor.chunk_state_h_box)
-            self.chunk_state_c_box = copy.deepcopy(self.asr_engine.executor.chunk_state_c_box)
+            self.chunk_state_h_box = copy.deepcopy(
+                self.asr_engine.executor.chunk_state_h_box)
+            self.chunk_state_c_box = copy.deepcopy(
+                self.asr_engine.executor.chunk_state_c_box)
             self.decoder.reset_decoder(batch_size=1)
-        elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
-            # for conformer online
-            self.subsampling_cache = None
-            self.elayers_output_cache = None
-            self.conformer_cnn_cache = None
-            self.encoder_out = None
-            self.cached_feat = None
-            self.remained_wav = None
-            self.offset = 0
-            self.num_samples = 0
-            self.device = None
-            self.hyps = []
-            self.num_frames = 0
-            self.chunk_num = 0
-            self.global_frame_offset = 0
-            self.result_transcripts = ['']
+
+        # for conformer online
+        self.subsampling_cache = None
+        self.elayers_output_cache = None
+        self.conformer_cnn_cache = None
+        self.encoder_out = None
+        self.cached_feat = None
+        self.remained_wav = None
+        self.offset = 0
+        self.num_samples = 0
+        self.device = None
+        self.hyps = []
+        self.num_frames = 0
+        self.chunk_num = 0
+        self.global_frame_offset = 0
+        self.result_transcripts = ['']

     def decode(self, is_finished=False):
         if "deepspeech2online" in self.model_type:
             # x_chunk is the feature data
             decoding_chunk_size = 1  # decoding_chunk_size=1 in deepspeech2 model
             context = 7  # context=7 in deepspeech2 model
             subsampling = 4  # subsampling=4 in deepspeech2 model
             stride = subsampling * decoding_chunk_size
             cached_feature_num = context - subsampling
             # decoding window for model
             decoding_window = (decoding_chunk_size - 1) * subsampling + context

             if self.cached_feat is None:
                 logger.info("no audio feat, please input more pcm data")
                 return

             num_frames = self.cached_feat.shape[1]
             logger.info(
                 f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
@@ -306,14 +309,14 @@ class PaddleASRConnectionHanddler:
             # the cached feat must be larger decoding_window
             if num_frames < decoding_window and not is_finished:
                 logger.info(
                     f"frame feat num is less than {decoding_window}, please input more pcm data"
                 )
                 return None, None

             # if is_finished=True, we need at least context frames
             if num_frames < context:
                 logger.info(
                     "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
                 )
                 return None, None

             logger.info("start to do model forward")
@@ -334,8 +337,7 @@ class PaddleASRConnectionHanddler:
             self.result_transcripts = [trans_best]
-            self.cached_feat = self.cached_feat[:, end -
-                                                cached_feature_num:, :]
+            self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]

             # return trans_best[0]
         elif "conformer" in self.model_type or "transformer" in self.model_type:
             try:
@@ -354,8 +356,7 @@ class PaddleASRConnectionHanddler:
         logger.info("start to decoce one chunk with deepspeech2 model")
         input_names = self.am_predictor.get_input_names()
         audio_handle = self.am_predictor.get_input_handle(input_names[0])
-        audio_len_handle = self.am_predictor.get_input_handle(
-            input_names[1])
+        audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
         h_box_handle = self.am_predictor.get_input_handle(input_names[2])
         c_box_handle = self.am_predictor.get_input_handle(input_names[3])
@@ -374,11 +375,11 @@ class PaddleASRConnectionHanddler:
         output_names = self.am_predictor.get_output_names()
         output_handle = self.am_predictor.get_output_handle(output_names[0])
         output_lens_handle = self.am_predictor.get_output_handle(
             output_names[1])
         output_state_h_handle = self.am_predictor.get_output_handle(
             output_names[2])
         output_state_c_handle = self.am_predictor.get_output_handle(
             output_names[3])

         self.am_predictor.run()
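These calls follow the standard Paddle Inference predictor pattern (get_input_handle / copy_from_cpu / run / copy_to_cpu). A minimal, self-contained sketch of that pattern outside the engine, with placeholder model paths and tensor shapes rather than the engine's real ones:

# Minimal sketch of the Paddle Inference handle pattern used above.
# Model paths and the input shape are placeholders, not the deepspeech2
# online engine's actual values.
import numpy as np
from paddle.inference import Config, create_predictor

config = Config("model.pdmodel", "model.pdiparams")   # hypothetical files
predictor = create_predictor(config)

x = np.random.rand(1, 16, 161).astype("float32")      # fake feature chunk
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.reshape(x.shape)
input_handle.copy_from_cpu(x)

predictor.run()

output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
probs = output_handle.copy_to_cpu()                    # numpy array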
@@ -389,7 +390,7 @@ class PaddleASRConnectionHanddler:
         self.decoder.next(output_chunk_probs, output_chunk_lens)
         trans_best, trans_beam = self.decoder.decode()
-        logger.info(f"decode one one best result: {trans_best[0]}")
+        logger.info(f"decode one best result: {trans_best[0]}")
         return trans_best[0]

     def advance_decoding(self, is_finished=False):
@@ -500,7 +501,7 @@ class PaddleASRConnectionHanddler:
     def rescoring(self):
         if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
             return

         logger.info("rescoring the final result")
         if "attention_rescoring" != self.ctc_decode_config.decoding_method:
             return
@@ -587,7 +588,7 @@ class ASRServerExecutor(ASRExecutor):
         return decompressed_path

     def _init_from_path(self,
-                        model_type: str='wenetspeech',
+                        model_type: str='deepspeech2online_aishell',
                         am_model: Optional[os.PathLike]=None,
                         am_params: Optional[os.PathLike]=None,
                         lang: str='zh',
@@ -647,7 +648,7 @@ class ASRServerExecutor(ASRExecutor):
             self.download_lm(
                 lm_url,
                 os.path.dirname(self.config.decode.lang_model_path), lm_md5)
-        elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
+        elif "conformer" in model_type or "transformer" in model_type:
             logger.info("start to create the stream conformer asr engine")
             if self.config.spm_model_prefix:
                 self.config.spm_model_prefix = os.path.join(
@@ -711,7 +712,7 @@ class ASRServerExecutor(ASRExecutor):
             self.chunk_state_c_box = np.zeros(
                 (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                 dtype=float32)
-        elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
+        elif "conformer" in model_type or "transformer" in model_type:
             model_name = model_type[:model_type.rindex(
                 '_')]  # model_type: {model_name}_{dataset}
             logger.info(f"model name: {model_name}")
@@ -742,7 +743,7 @@ class ASRServerExecutor(ASRExecutor):
             self.chunk_state_c_box = np.zeros(
                 (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                 dtype=float32)
-        elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
+        elif "conformer" in self.model_type or "transformer" in self.model_type:
             self.transformer_decode_reset()

     def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
@@ -754,7 +755,7 @@ class ASRServerExecutor(ASRExecutor):
             model_type (str): online model type

         Returns:
-            [type]: [description]
+            str: one best result
         """
         logger.info("start to decoce chunk by chunk")
         if "deepspeech2online" in model_type:
@@ -795,7 +796,7 @@ class ASRServerExecutor(ASRExecutor):
             self.decoder.next(output_chunk_probs, output_chunk_lens)
             trans_best, trans_beam = self.decoder.decode()
-            logger.info(f"decode one one best result: {trans_best[0]}")
+            logger.info(f"decode one best result: {trans_best[0]}")
             return trans_best[0]

         elif "conformer" in model_type or "transformer" in model_type:
@@ -972,7 +973,7 @@ class ASRServerExecutor(ASRExecutor):
             x_chunk_lens = np.array([audio_len])
             return x_chunk, x_chunk_lens

-        elif "conformer2online" in self.model_type:
+        elif "conformer_online" in self.model_type:
             if sample_rate != self.sample_rate:
                 logger.info(f"audio sample rate {sample_rate} is not match,"
@@ -1005,7 +1006,7 @@ class ASREngine(BaseEngine):
     def __init__(self):
         super(ASREngine, self).__init__()
-        logger.info("create the online asr engine instache")
+        logger.info("create the online asr engine instance")

     def init(self, config: dict) -> bool:
         """init engine resource
