update the max len compute method, test=doc

pull/1913/head
xiongxinlei 3 years ago
parent 0ea39f837b
commit b1ef434983

@ -187,13 +187,7 @@ class ASRExecutor(BaseExecutor):
vocab=self.config.vocab_filepath, vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix) spm_model_prefix=self.config.spm_model_prefix)
self.config.decode.decoding_method = decode_method self.config.decode.decoding_method = decode_method
self.max_len = 5000
if self.config.encoder_conf.get("max_len", None):
self.max_len = self.config.encoder_conf.max_len
logger.info(f"max len: {self.max_len}")
# we assume that the subsample rate is 4 and every frame step is 40ms
self.max_len = 40 * self.max_len / 1000
else: else:
raise Exception("wrong type") raise Exception("wrong type")
model_name = model_type[:model_type.rindex( model_name = model_type[:model_type.rindex(
@ -208,6 +202,21 @@ class ASRExecutor(BaseExecutor):
model_dict = paddle.load(self.ckpt_path) model_dict = paddle.load(self.ckpt_path)
self.model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
# compute the max len limit
if "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
# in transformer like model, we may use the subsample rate cnn network
subsample_rate = self.model.subsampling_rate()
frame_shift_ms = self.config.preprocess_config.process[0][
'n_shift'] / self.config.preprocess_config.process[0]['fs']
max_len = self.model.encoder.embed.pos_enc.max_len
if self.config.encoder_conf.get("max_len", None):
max_len = self.config.encoder_conf.max_len
self.max_len = frame_shift_ms * max_len * subsample_rate
logger.info(
f"The asr server limit max duration len: {self.max_len}")
def preprocess(self, model_type: str, input: Union[str, os.PathLike]): def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
""" """
Input preprocess and return paddle.Tensor stored in self.input. Input preprocess and return paddle.Tensor stored in self.input.

@ -332,7 +332,7 @@ class BaseEncoder(nn.Layer):
# fake mask, just for jit script and compatibility with `forward` api # fake mask, just for jit script and compatibility with `forward` api
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1) masks = masks.unsqueeze(1)
return ys, masks, offset return ys, masks
class TransformerEncoder(BaseEncoder): class TransformerEncoder(BaseEncoder):

Loading…
Cancel
Save