|
|
|
@ -187,13 +187,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
vocab=self.config.vocab_filepath,
|
|
|
|
|
spm_model_prefix=self.config.spm_model_prefix)
|
|
|
|
|
self.config.decode.decoding_method = decode_method
|
|
|
|
|
self.max_len = 5000
|
|
|
|
|
if self.config.encoder_conf.get("max_len", None):
|
|
|
|
|
self.max_len = self.config.encoder_conf.max_len
|
|
|
|
|
|
|
|
|
|
logger.info(f"max len: {self.max_len}")
|
|
|
|
|
# we assume that the subsample rate is 4 and every frame step is 40ms
|
|
|
|
|
self.max_len = 40 * self.max_len / 1000
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("wrong type")
|
|
|
|
|
model_name = model_type[:model_type.rindex(
|
|
|
|
@ -208,6 +202,21 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
model_dict = paddle.load(self.ckpt_path)
|
|
|
|
|
self.model.set_state_dict(model_dict)
|
|
|
|
|
|
|
|
|
|
# compute the max len limit
|
|
|
|
|
if "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
|
|
|
|
|
# transformer-like models may use a subsampling CNN network
|
|
|
|
|
subsample_rate = self.model.subsampling_rate()
|
|
|
|
|
frame_shift_ms = self.config.preprocess_config.process[0][
|
|
|
|
|
'n_shift'] / self.config.preprocess_config.process[0]['fs']
|
|
|
|
|
max_len = self.model.encoder.embed.pos_enc.max_len
|
|
|
|
|
|
|
|
|
|
if self.config.encoder_conf.get("max_len", None):
|
|
|
|
|
max_len = self.config.encoder_conf.max_len
|
|
|
|
|
|
|
|
|
|
self.max_len = frame_shift_ms * max_len * subsample_rate
|
|
|
|
|
logger.info(
|
|
|
|
|
f"The asr server limit max duration len: {self.max_len}")
|
|
|
|
|
|
|
|
|
|
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
|
|
|
|
|
"""
|
|
|
|
|
Input preprocess and return paddle.Tensor stored in self.input.
|
|
|
|
|