From babac27a7943b5be254afab8af09e909b0d3151c Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Tue, 19 Apr 2022 18:14:30 +0800
Subject: [PATCH] fix ds2 online edge bug, test=doc

---
 paddlespeech/cli/asr/pretrained_models.py    |  2 ++
 .../server/engine/asr/online/asr_engine.py   | 20 ++++++++++++--------
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/paddlespeech/cli/asr/pretrained_models.py b/paddlespeech/cli/asr/pretrained_models.py
index a16c4750d..cc52c751b 100644
--- a/paddlespeech/cli/asr/pretrained_models.py
+++ b/paddlespeech/cli/asr/pretrained_models.py
@@ -88,6 +88,8 @@ model_alias = {
         "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
     "conformer":
     "paddlespeech.s2t.models.u2:U2Model",
+    "conformer_online":
+    "paddlespeech.s2t.models.u2:U2Model",
     "transformer":
     "paddlespeech.s2t.models.u2:U2Model",
     "wenetspeech":
diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py
index 3c2b066c9..4d15d93b5 100644
--- a/paddlespeech/server/engine/asr/online/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/asr_engine.py
@@ -130,9 +130,10 @@ class PaddleASRConnectionHanddler:
                 cfg.num_proc_bsearch)
 
             # frame window samples length and frame shift samples length
-            self.win_length = int(self.model_config.window_ms *
+            self.win_length = int(self.model_config.window_ms / 1000 *
                                   self.sample_rate)
-            self.n_shift = int(self.model_config.stride_ms * self.sample_rate)
+            self.n_shift = int(self.model_config.stride_ms / 1000 *
+                               self.sample_rate)
 
         elif "conformer" in self.model_type or "transformer" in self.model_type:
             # acoustic model
@@ -158,6 +159,11 @@ class PaddleASRConnectionHanddler:
             samples = np.frombuffer(samples, dtype=np.int16)
             assert samples.ndim == 1
 
+        # pcm16 -> pcm 32
+        # pcm2float will change the original samples,
+        # so we should do pcm2float before concatenating
+        samples = pcm2float(samples)
+
         if self.remained_wav is None:
             self.remained_wav = samples
         else:
@@ -167,11 +173,9 @@ class PaddleASRConnectionHanddler:
             f"The connection remain the audio samples: {self.remained_wav.shape}"
         )
 
-        # pcm16 -> pcm 32
-        samples = pcm2float(self.remained_wav)
         # read audio
         speech_segment = SpeechSegment.from_pcm(
-            samples, self.sample_rate, transcript=" ")
+            self.remained_wav, self.sample_rate, transcript=" ")
         # audio augment
         self.collate_fn_test.augmentation.transform_audio(speech_segment)
 
@@ -474,6 +478,7 @@ class PaddleASRConnectionHanddler:
             self.hyps = self.searcher.get_one_best_hyps()
         assert self.cached_feat.shape[0] == 1
         assert end >= cached_feature_num
+
         self.cached_feat = self.cached_feat[0, end -
                                             cached_feature_num:, :].unsqueeze(0)
         assert len(
@@ -515,7 +520,6 @@ class PaddleASRConnectionHanddler:
             return
 
         # assert len(hyps) == beam_size
-        paddle.save(self.encoder_out, "encoder.out")
         hyp_list = []
         for hyp in hyps:
             hyp_content = hyp[0]
@@ -815,7 +819,7 @@ class ASRServerExecutor(ASRExecutor):
 
     def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
         logger.info("start to decode with advanced_decoding method")
-        encoder_out, encoder_mask = self.decode_forward(xs)
+        encoder_out, encoder_mask = self.encoder_forward(xs)
         ctc_probs = self.model.ctc.log_softmax(
             encoder_out)  # (1, maxlen, vocab_size)
         ctc_probs = ctc_probs.squeeze(0)
@@ -827,7 +831,7 @@ class ASRServerExecutor(ASRExecutor):
         if "attention_rescoring" in self.config.decode.decoding_method:
             self.rescoring(encoder_out, xs.place)
 
-    def decode_forward(self, xs):
+    def encoder_forward(self, xs):
         logger.info("get the model out from the feat")
         cfg = self.config.decode
         decoding_chunk_size = cfg.decoding_chunk_size
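
A minimal standalone sketch of the two behavioral fixes in the patch: window_ms and stride_ms are stored in milliseconds, so they must be divided by 1000 before multiplying by the sample rate, and each incoming chunk should be converted to float before concatenation so remained_wav keeps one dtype. The sample rate and frame values below are illustrative assumptions, and pcm2float is sketched here as the usual int16-to-float32 normalization rather than copied from paddlespeech:

    import numpy as np

    def pcm2float(samples: np.ndarray) -> np.ndarray:
        # assumed behavior: int16 PCM -> float32 in [-1.0, 1.0)
        return samples.astype(np.float32) / 32768.0

    # milliseconds -> seconds (divide by 1000) before scaling by sample rate
    sample_rate = 16000                 # illustrative stream config
    window_ms, stride_ms = 25.0, 10.0

    win_length = int(window_ms / 1000 * sample_rate)  # 400 samples
    n_shift = int(stride_ms / 1000 * sample_rate)     # 160 samples
    # without the / 1000 the window would be 400000 samples, the bug fixed above

    # convert each chunk *before* concatenating, so remained_wav stays float32
    chunk = np.zeros(160, dtype=np.int16)
    remained_wav = pcm2float(chunk)
    remained_wav = np.concatenate([remained_wav, pcm2float(chunk)])
    assert remained_wav.dtype == np.float32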