fix ds2 online edge bug, test=doc

pull/1704/head
xiongxinlei 3 years ago
parent dcab04a799
commit babac27a79

@@ -88,6 +88,8 @@ model_alias = {
     "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
     "conformer":
     "paddlespeech.s2t.models.u2:U2Model",
+    "conformer_online":
+    "paddlespeech.s2t.models.u2:U2Model",
     "transformer":
     "paddlespeech.s2t.models.u2:U2Model",
     "wenetspeech":

@@ -130,9 +130,10 @@ class PaddleASRConnectionHanddler:
                 cfg.num_proc_bsearch)
             # frame window samples length and frame shift samples length
-            self.win_length = int(self.model_config.window_ms *
-                                  self.sample_rate)
-            self.n_shift = int(self.model_config.stride_ms * self.sample_rate)
+            self.win_length = int(self.model_config.window_ms / 1000 *
+                                  self.sample_rate)
+            self.n_shift = int(self.model_config.stride_ms / 1000 *
+                               self.sample_rate)
         elif "conformer" in self.model_type or "transformer" in self.model_type:
             # acoustic model
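The bug fixed here is a unit mismatch: `window_ms` and `stride_ms` are given in milliseconds, so they must be divided by 1000 before being multiplied by the sample rate. A quick worked check with typical frontend values (25 ms window and 10 ms shift at 16 kHz, chosen for illustration):

sample_rate = 16000              # Hz
window_ms, stride_ms = 25, 10    # milliseconds (illustrative values)

win_length = int(window_ms / 1000 * sample_rate)   # 400 samples
n_shift = int(stride_ms / 1000 * sample_rate)      # 160 samples

# Without the "/ 1000", the old code produced 400000 and 160000 samples,
# i.e. window and shift lengths a thousand times too long.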
@@ -158,6 +159,11 @@ class PaddleASRConnectionHanddler:
         samples = np.frombuffer(samples, dtype=np.int16)
         assert samples.ndim == 1
+        # pcm16 -> pcm 32
+        # pcm2float will change the original samples,
+        # so we should do pcm2float before concatenate
+        samples = pcm2float(samples)
         if self.remained_wav is None:
             self.remained_wav = samples
         else:
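The conversion is moved so that each incoming int16 chunk is converted to float exactly once, before it is appended to `self.remained_wav`; converting the concatenated buffer later would rescale samples that are already float. A minimal sketch of what a `pcm2float`-style conversion does (an illustrative reimplementation, not the PaddleSpeech helper itself):

import numpy as np

def pcm2float_sketch(sig: np.ndarray) -> np.ndarray:
    # Map int16 PCM values in [-32768, 32767] to float32 in roughly [-1.0, 1.0).
    assert sig.dtype == np.int16
    return sig.astype(np.float32) / np.float32(np.iinfo(np.int16).max + 1)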
@@ -167,11 +173,9 @@ class PaddleASRConnectionHanddler:
                 f"The connection remain the audio samples: {self.remained_wav.shape}"
             )
-        # pcm16 -> pcm 32
-        samples = pcm2float(self.remained_wav)
         # read audio
         speech_segment = SpeechSegment.from_pcm(
-            samples, self.sample_rate, transcript=" ")
+            self.remained_wav, self.sample_rate, transcript=" ")
         # audio augment
         self.collate_fn_test.augmentation.transform_audio(speech_segment)
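With the conversion done up front, `self.remained_wav` always holds float samples and can be passed to `SpeechSegment.from_pcm` directly instead of being converted again here. A condensed, self-contained sketch of the resulting per-chunk accumulation (variable names follow the diff; the int16-to-float scaling is inlined for brevity):

import numpy as np

def append_chunk(remained_wav, pcm_bytes):
    # Convert the new int16 chunk to float once, then append it to the float buffer.
    chunk = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0
    if remained_wav is None:
        return chunk
    # Both operands are already float32, so concatenation keeps one consistent scale.
    return np.concatenate([remained_wav, chunk])

The accumulated buffer is then handed to `SpeechSegment.from_pcm(self.remained_wav, self.sample_rate, transcript=" ")` as in the hunk above.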
@@ -474,6 +478,7 @@ class PaddleASRConnectionHanddler:
             self.hyps = self.searcher.get_one_best_hyps()
             assert self.cached_feat.shape[0] == 1
+            assert end >= cached_feature_num
             self.cached_feat = self.cached_feat[0, end -
                                                 cached_feature_num:, :].unsqueeze(0)
             assert len(
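The added assertion presumably guards the slice start index: if `end` were smaller than `cached_feature_num`, then `end - cached_feature_num` would be negative, and negative indexing would silently keep the wrong frames instead of raising. A toy NumPy illustration of that edge case (values chosen for illustration):

import numpy as np

feat = np.arange(10).reshape(1, 10, 1)   # (batch, frames, dim) toy feature cache
end, cached_feature_num = 3, 5           # hypothetical edge case: end < cached_feature_num

start = end - cached_feature_num         # -2: no error, but it means "2 frames from the end"
kept = feat[0, start:, :]
print(kept.shape[0])                     # 2, silently different from the intended slice

# The new "assert end >= cached_feature_num" turns this silent mis-slice
# into an immediate failure instead.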
@@ -515,7 +520,6 @@ class PaddleASRConnectionHanddler:
             return
         # assert len(hyps) == beam_size
-        paddle.save(self.encoder_out, "encoder.out")
         hyp_list = []
         for hyp in hyps:
             hyp_content = hyp[0]
@@ -815,7 +819,7 @@ class ASRServerExecutor(ASRExecutor):
     def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
         logger.info("start to decode with advanced_decoding method")
-        encoder_out, encoder_mask = self.decode_forward(xs)
+        encoder_out, encoder_mask = self.encoder_forward(xs)
         ctc_probs = self.model.ctc.log_softmax(
             encoder_out)  # (1, maxlen, vocab_size)
         ctc_probs = ctc_probs.squeeze(0)
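The rename makes the call's role explicit: the method only runs the encoder, and decoding then operates on per-frame CTC log-probabilities. For intuition about the shapes in the comment above, the CTC head is essentially a projection over the vocabulary followed by log-softmax; a shape-only sketch with made-up sizes, where the plain Linear layer merely stands in for `self.model.ctc`:

import paddle

maxlen, encoder_dim, vocab_size = 50, 256, 5000        # illustrative sizes
encoder_out = paddle.randn([1, maxlen, encoder_dim])   # what encoder_forward would return

ctc_head = paddle.nn.Linear(encoder_dim, vocab_size)   # stand-in for the real CTC head
ctc_probs = paddle.nn.functional.log_softmax(ctc_head(encoder_out), axis=-1)
print(ctc_probs.shape)    # [1, 50, 5000] = (1, maxlen, vocab_size); squeeze(0) then
                          # yields (maxlen, vocab_size) for the searcher.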
@@ -827,7 +831,7 @@ class ASRServerExecutor(ASRExecutor):
         if "attention_rescoring" in self.config.decode.decoding_method:
             self.rescoring(encoder_out, xs.place)
-    def decode_forward(self, xs):
+    def encoder_forward(self, xs):
         logger.info("get the model out from the feat")
         cfg = self.config.decode
         decoding_chunk_size = cfg.decoding_chunk_size
