diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py
index 53f71a70..f1e46ca1 100644
--- a/paddlespeech/cli/asr/infer.py
+++ b/paddlespeech/cli/asr/infer.py
@@ -129,7 +129,7 @@ model_alias = {
     "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
     "conformer":
     "paddlespeech.s2t.models.u2:U2Model",
-    "conformer2online":
+    "conformer_online":
     "paddlespeech.s2t.models.u2:U2Model",
     "transformer":
     "paddlespeech.s2t.models.u2:U2Model",
diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml
index 1a775f85..89a861ef 100644
--- a/paddlespeech/server/conf/ws_conformer_application.yaml
+++ b/paddlespeech/server/conf/ws_conformer_application.yaml
@@ -21,7 +21,7 @@ engine_list: ['asr_online']
 ################################### ASR #########################################
 ################### speech task: asr; engine_type: online #######################
 asr_online:
-    model_type: 'conformer2online_aishell'
+    model_type: 'conformer_online_multi-cn'
    am_model: # the pdmodel file of am static model [optional]
    am_params: # the pdiparams file of am static model [optional]
    lang: 'zh'
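Note on the rename: the model_type in the yaml and the model_alias key patched above must stay in sync, because the executor derives the alias from the configured model_type. Below is a minimal, self-contained sketch of the "module:Class" convention the model_alias entries follow; resolve_model_class is a hypothetical helper for illustration only, not the actual PaddleSpeech loader.

from importlib import import_module

# "module:Class" entry, copied from the model_alias dict patched above
model_alias = {
    "conformer_online": "paddlespeech.s2t.models.u2:U2Model",
}

def resolve_model_class(alias: str):
    # hypothetical helper: split "package.module:ClassName" and import the class
    module_name, class_name = model_alias[alias].split(":")
    return getattr(import_module(module_name), class_name)

With model_type 'conformer_online_multi-cn', the derived alias is 'conformer_online', which is why the stale 'conformer2online' key had to be renamed in lockstep.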
diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py
index 77eb5a21..3c2b066c 100644
--- a/paddlespeech/server/engine/asr/online/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/asr_engine.py
@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import os
 from typing import Optional
-import copy
+
 import numpy as np
 import paddle
 from numpy import float32
@@ -58,7 +59,7 @@ pretrained_models = {
         'lm_md5':
         '29e02312deb2e59b3c8686c7966d4fe3'
     },
-    "conformer2online_aishell-zh-16k": {
+    "conformer_online_multi-cn-zh-16k": {
         'url':
         'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz',
         'md5':
@@ -93,19 +94,22 @@ class PaddleASRConnectionHanddler:
         )
         self.config = asr_engine.config
         self.model_config = asr_engine.executor.config
-        # self.model = asr_engine.executor.model
         self.asr_engine = asr_engine
 
         self.init()
         self.reset()
 
     def init(self):
+        # model_type, sample_rate and text_feature are shared by deepspeech2 and conformer
         self.model_type = self.asr_engine.executor.model_type
+        self.sample_rate = self.asr_engine.executor.sample_rate
+        # tokens to text
+        self.text_feature = self.asr_engine.executor.text_feature
+
         if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
             from paddlespeech.s2t.io.collator import SpeechCollator
-            self.sample_rate = self.asr_engine.executor.sample_rate
             self.am_predictor = self.asr_engine.executor.am_predictor
-            self.text_feature = self.asr_engine.executor.text_feature
+
             self.collate_fn_test = SpeechCollator.from_config(self.model_config)
             self.decoder = CTCDecoder(
                 odim=self.model_config.output_dim,  # <blank> is in vocab
@@ -114,7 +118,8 @@ class PaddleASRConnectionHanddler:
                 dropout_rate=0.0,
                 reduction=True,  # sum
                 batch_average=True,  # sum / batch_size
-                grad_norm_type=self.model_config.get('ctc_grad_norm_type', None))
+                grad_norm_type=self.model_config.get('ctc_grad_norm_type',
+                                                     None))
 
             cfg = self.model_config.decode
             decode_batch_size = 1  # for online
@@ -123,20 +128,16 @@ class PaddleASRConnectionHanddler:
                 cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
                 cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                 cfg.num_proc_bsearch)
-            # frame window samples length and frame shift samples length
-
-            self.win_length = int(self.model_config.window_ms * self.sample_rate)
-            self.n_shift = int(self.model_config.stride_ms * self.sample_rate)
+            # frame window samples length and frame shift samples length
-        elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
-            self.sample_rate = self.asr_engine.executor.sample_rate
+            self.win_length = int(self.model_config.window_ms *
+                                  self.sample_rate)
+            self.n_shift = int(self.model_config.stride_ms * self.sample_rate)
+        elif "conformer" in self.model_type or "transformer" in self.model_type:
             # acoustic model
             self.model = self.asr_engine.executor.model
-            # tokens to text
-            self.text_feature = self.asr_engine.executor.text_feature
-
             # ctc decoding config
             self.ctc_decode_config = self.asr_engine.executor.config.decode
             self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)
@@ -189,7 +190,7 @@ class PaddleASRConnectionHanddler:
             audio = paddle.to_tensor(audio, dtype='float32')
             # audio_len = paddle.to_tensor(audio_len)
             audio = paddle.unsqueeze(audio, axis=0)
-
+
             if self.cached_feat is None:
                 self.cached_feat = audio
             else:
@@ -211,7 +212,7 @@ class PaddleASRConnectionHanddler:
             logger.info(
                 f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
             )
-        elif "conformer2online" in self.model_type:
+        elif "conformer_online" in self.model_type:
             logger.info("Online ASR extract the feat")
             samples = np.frombuffer(samples, dtype=np.int16)
             assert samples.ndim == 1
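Side note on the hunk above: np.frombuffer(samples, dtype=np.int16) assumes the websocket delivers raw 16-bit PCM. Here is a standalone sketch of that conversion with a synthetic buffer; the [-1, 1] normalization shown is an assumption for illustration, since the actual scaling step lies outside this hunk.

import numpy as np

# synthetic stand-in for one websocket message of 16-bit PCM audio
pcm_bytes = (np.sin(np.linspace(0, 100, 160)) * 32767).astype(np.int16).tobytes()

samples = np.frombuffer(pcm_bytes, dtype=np.int16)
assert samples.ndim == 1

# assumption: normalize to [-1, 1] float32 before feature extraction
waveform = samples.astype(np.float32) / 32768.0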
@@ -264,41 +265,43 @@ class PaddleASRConnectionHanddler:
     def reset(self):
         if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
             # for deepspeech2
-            self.chunk_state_h_box = copy.deepcopy(self.asr_engine.executor.chunk_state_h_box)
-            self.chunk_state_c_box = copy.deepcopy(self.asr_engine.executor.chunk_state_c_box)
+            self.chunk_state_h_box = copy.deepcopy(
+                self.asr_engine.executor.chunk_state_h_box)
+            self.chunk_state_c_box = copy.deepcopy(
+                self.asr_engine.executor.chunk_state_c_box)
             self.decoder.reset_decoder(batch_size=1)
-        elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
-            # for conformer online
-            self.subsampling_cache = None
-            self.elayers_output_cache = None
-            self.conformer_cnn_cache = None
-            self.encoder_out = None
-            self.cached_feat = None
-            self.remained_wav = None
-            self.offset = 0
-            self.num_samples = 0
-            self.device = None
-            self.hyps = []
-            self.num_frames = 0
-            self.chunk_num = 0
-            self.global_frame_offset = 0
-            self.result_transcripts = ['']
+
+        # for conformer online
+        self.subsampling_cache = None
+        self.elayers_output_cache = None
+        self.conformer_cnn_cache = None
+        self.encoder_out = None
+        self.cached_feat = None
+        self.remained_wav = None
+        self.offset = 0
+        self.num_samples = 0
+        self.device = None
+        self.hyps = []
+        self.num_frames = 0
+        self.chunk_num = 0
+        self.global_frame_offset = 0
+        self.result_transcripts = ['']
 
     def decode(self, is_finished=False):
         if "deepspeech2online" in self.model_type:
             # x_chunk is the feature data
-            decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model
-            context = 7 # context=7 in deepspeech2 model
-            subsampling = 4 # subsampling=4 in deepspeech2 model
+            decoding_chunk_size = 1  # decoding_chunk_size=1 in deepspeech2 model
+            context = 7  # context=7 in deepspeech2 model
+            subsampling = 4  # subsampling=4 in deepspeech2 model
             stride = subsampling * decoding_chunk_size
             cached_feature_num = context - subsampling
             # decoding window for model
-            decoding_window = (decoding_chunk_size - 1) * subsampling + context
-
+            decoding_window = (decoding_chunk_size - 1) * subsampling + context
+
             if self.cached_feat is None:
                 logger.info("no audio feat, please input more pcm data")
-                return
-
+                return
+
             num_frames = self.cached_feat.shape[1]
             logger.info(
                 f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
@@ -306,14 +309,14 @@ class PaddleASRConnectionHanddler:
             # the cached feat must be larger decoding_window
             if num_frames < decoding_window and not is_finished:
                 logger.info(
-                    f"frame feat num is less than {decoding_window}, please input more pcm data"
+                    f"frame feat num is less than {decoding_window}, please input more pcm data"
                 )
                 return None, None
 
             # if is_finished=True, we need at least context frames
             if num_frames < context:
                 logger.info(
-                    "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
+                    "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
                 )
                 return None, None
             logger.info("start to do model forward")
@@ -334,8 +337,7 @@ class PaddleASRConnectionHanddler:
 
             self.result_transcripts = [trans_best]
 
-            self.cached_feat = self.cached_feat[:, end -
-                                                cached_feature_num:, :]
+            self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
 
             # return trans_best[0]
         elif "conformer" in self.model_type or "transformer" in self.model_type:
             try:
@@ -354,8 +356,7 @@ class PaddleASRConnectionHanddler:
         logger.info("start to decoce one chunk with deepspeech2 model")
         input_names = self.am_predictor.get_input_names()
         audio_handle = self.am_predictor.get_input_handle(input_names[0])
-        audio_len_handle = self.am_predictor.get_input_handle(
-            input_names[1])
+        audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
         h_box_handle = self.am_predictor.get_input_handle(input_names[2])
         c_box_handle = self.am_predictor.get_input_handle(input_names[3])
@@ -374,11 +375,11 @@ class PaddleASRConnectionHanddler:
         output_names = self.am_predictor.get_output_names()
         output_handle = self.am_predictor.get_output_handle(output_names[0])
         output_lens_handle = self.am_predictor.get_output_handle(
-            output_names[1])
+            output_names[1])
         output_state_h_handle = self.am_predictor.get_output_handle(
-            output_names[2])
+            output_names[2])
         output_state_c_handle = self.am_predictor.get_output_handle(
-            output_names[3])
+            output_names[3])
 
         self.am_predictor.run()
@@ -389,7 +390,7 @@ class PaddleASRConnectionHanddler:
         self.decoder.next(output_chunk_probs, output_chunk_lens)
 
         trans_best, trans_beam = self.decoder.decode()
-        logger.info(f"decode one one best result: {trans_best[0]}")
+        logger.info(f"decode one best result: {trans_best[0]}")
         return trans_best[0]
 
     def advance_decoding(self, is_finished=False):
@@ -500,7 +501,7 @@ class PaddleASRConnectionHanddler:
     def rescoring(self):
         if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
             return
-
+
         logger.info("rescoring the final result")
         if "attention_rescoring" != self.ctc_decode_config.decoding_method:
             return
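To make the deepspeech2 window bookkeeping in decode() above concrete: plugging in the hard-coded constants gives a 7-frame decoding window that advances 4 frames per step, with the feature cache keeping the last 3 frames of the consumed window so consecutive windows overlap. A quick self-check:

# hard-coded deepspeech2 streaming constants from decode() above
decoding_chunk_size = 1
context = 7
subsampling = 4

stride = subsampling * decoding_chunk_size  # 4 frames consumed per step
cached_feature_num = context - subsampling  # 3 frames kept for the next window
decoding_window = (decoding_chunk_size - 1) * subsampling + context  # 7 frames

assert (stride, cached_feature_num, decoding_window) == (4, 3, 7)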
@@ -587,7 +588,7 @@ class ASRServerExecutor(ASRExecutor):
         return decompressed_path
 
     def _init_from_path(self,
-                        model_type: str='wenetspeech',
+                        model_type: str='deepspeech2online_aishell',
                         am_model: Optional[os.PathLike]=None,
                         am_params: Optional[os.PathLike]=None,
                         lang: str='zh',
@@ -647,7 +648,7 @@ class ASRServerExecutor(ASRExecutor):
                 self.download_lm(
                     lm_url,
                     os.path.dirname(self.config.decode.lang_model_path), lm_md5)
-        elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
+        elif "conformer" in model_type or "transformer" in model_type:
             logger.info("start to create the stream conformer asr engine")
             if self.config.spm_model_prefix:
                 self.config.spm_model_prefix = os.path.join(
@@ -711,7 +712,7 @@ class ASRServerExecutor(ASRExecutor):
             self.chunk_state_c_box = np.zeros(
                 (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                 dtype=float32)
-        elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
+        elif "conformer" in model_type or "transformer" in model_type:
             model_name = model_type[:model_type.rindex(
                 '_')]  # model_type: {model_name}_{dataset}
             logger.info(f"model name: {model_name}")
@@ -742,7 +743,7 @@ class ASRServerExecutor(ASRExecutor):
             self.chunk_state_c_box = np.zeros(
                 (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                 dtype=float32)
-        elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
+        elif "conformer" in self.model_type or "transformer" in self.model_type:
             self.transformer_decode_reset()
 
     def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
@@ -754,7 +755,7 @@ class ASRServerExecutor(ASRExecutor):
             model_type (str): online model type
 
         Returns:
-            [type]: [description]
+            str: one best result
         """
         logger.info("start to decoce chunk by chunk")
         if "deepspeech2online" in model_type:
@@ -795,7 +796,7 @@ class ASRServerExecutor(ASRExecutor):
             self.decoder.next(output_chunk_probs, output_chunk_lens)
 
             trans_best, trans_beam = self.decoder.decode()
-            logger.info(f"decode one one best result: {trans_best[0]}")
+            logger.info(f"decode one best result: {trans_best[0]}")
             return trans_best[0]
 
         elif "conformer" in model_type or "transformer" in model_type:
@@ -972,7 +973,7 @@ class ASRServerExecutor(ASRExecutor):
             x_chunk_lens = np.array([audio_len])
 
             return x_chunk, x_chunk_lens
-        elif "conformer2online" in self.model_type:
+        elif "conformer_online" in self.model_type:
             if sample_rate != self.sample_rate:
                 logger.info(f"audio sample rate {sample_rate} is not match,"
@@ -1005,7 +1006,7 @@ class ASREngine(BaseEngine):
 
     def __init__(self):
         super(ASREngine, self).__init__()
-        logger.info("create the online asr engine instache")
+        logger.info("create the online asr engine instance")
 
     def init(self, config: dict) -> bool:
         """init engine resource
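A closing note on the naming convention: _init_from_path splits model_type on the last underscore into {model_name}_{dataset} (the rindex('_') line above), so both the renamed pretrained key and the new yaml value must keep the dataset as the final underscore-separated field. A quick illustration; split_model_type is a hypothetical helper mirroring that line, not part of the patch:

def split_model_type(model_type: str):
    # mirrors: model_name = model_type[:model_type.rindex('_')]
    idx = model_type.rindex('_')
    return model_type[:idx], model_type[idx + 1:]

assert split_model_type('conformer_online_multi-cn') == ('conformer_online', 'multi-cn')
assert split_model_type('deepspeech2online_aishell') == ('deepspeech2online', 'aishell')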