fix ds2 online edge bug, test=doc

pull/1704/head
xiongxinlei 3 years ago
parent dcab04a799
commit babac27a79

@ -88,6 +88,8 @@ model_alias = {
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
"conformer":
"paddlespeech.s2t.models.u2:U2Model",
"conformer_online":
"paddlespeech.s2t.models.u2:U2Model",
"transformer":
"paddlespeech.s2t.models.u2:U2Model",
"wenetspeech":

@ -130,9 +130,10 @@ class PaddleASRConnectionHanddler:
cfg.num_proc_bsearch)
# frame window samples length and frame shift samples length
self.win_length = int(self.model_config.window_ms *
self.win_length = int(self.model_config.window_ms / 1000 *
self.sample_rate)
self.n_shift = int(self.model_config.stride_ms * self.sample_rate)
self.n_shift = int(self.model_config.stride_ms / 1000 *
self.sample_rate)
elif "conformer" in self.model_type or "transformer" in self.model_type:
# acoustic model
@ -158,6 +159,11 @@ class PaddleASRConnectionHanddler:
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
# pcm16 -> pcm 32
# pcm2float will change the orignal samples,
# so we shoule do pcm2float before concatenate
samples = pcm2float(samples)
if self.remained_wav is None:
self.remained_wav = samples
else:
@ -167,11 +173,9 @@ class PaddleASRConnectionHanddler:
f"The connection remain the audio samples: {self.remained_wav.shape}"
)
# pcm16 -> pcm 32
samples = pcm2float(self.remained_wav)
# read audio
speech_segment = SpeechSegment.from_pcm(
samples, self.sample_rate, transcript=" ")
self.remained_wav, self.sample_rate, transcript=" ")
# audio augment
self.collate_fn_test.augmentation.transform_audio(speech_segment)
@ -474,6 +478,7 @@ class PaddleASRConnectionHanddler:
self.hyps = self.searcher.get_one_best_hyps()
assert self.cached_feat.shape[0] == 1
assert end >= cached_feature_num
self.cached_feat = self.cached_feat[0, end -
cached_feature_num:, :].unsqueeze(0)
assert len(
@ -515,7 +520,6 @@ class PaddleASRConnectionHanddler:
return
# assert len(hyps) == beam_size
paddle.save(self.encoder_out, "encoder.out")
hyp_list = []
for hyp in hyps:
hyp_content = hyp[0]
@ -815,7 +819,7 @@ class ASRServerExecutor(ASRExecutor):
def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
logger.info("start to decode with advanced_decoding method")
encoder_out, encoder_mask = self.decode_forward(xs)
encoder_out, encoder_mask = self.encoder_forward(xs)
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
@ -827,7 +831,7 @@ class ASRServerExecutor(ASRExecutor):
if "attention_rescoring" in self.config.decode.decoding_method:
self.rescoring(encoder_out, xs.place)
def decode_forward(self, xs):
def encoder_forward(self, xs):
logger.info("get the model out from the feat")
cfg = self.config.decode
decoding_chunk_size = cfg.decoding_chunk_size

Loading…
Cancel
Save