Cache the TextFeaturizer instance for infer speed improvement. (#1260)

pull/1271/head
AdamBear 3 years ago committed by GitHub
parent 50752f8bc4
commit 36c9eaa437
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -174,12 +174,12 @@ class ASRExecutor(BaseExecutor):
self.config.collator.mean_std_filepath = os.path.join( self.config.collator.mean_std_filepath = os.path.join(
res_path, self.config.collator.cmvn_path) res_path, self.config.collator.cmvn_path)
self.collate_fn_test = SpeechCollator.from_config(self.config) self.collate_fn_test = SpeechCollator.from_config(self.config)
text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type, unit_type=self.config.collator.unit_type,
vocab=self.config.collator.vocab_filepath, vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
self.config.model.input_dim = self.collate_fn_test.feature_size self.config.model.input_dim = self.collate_fn_test.feature_size
self.config.model.output_dim = text_feature.vocab_size self.config.model.output_dim = self.text_feature.vocab_size
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
self.config.collator.vocab_filepath = os.path.join( self.config.collator.vocab_filepath = os.path.join(
res_path, self.config.collator.vocab_filepath) res_path, self.config.collator.vocab_filepath)
@ -187,12 +187,12 @@ class ASRExecutor(BaseExecutor):
res_path, self.config.collator.augmentation_config) res_path, self.config.collator.augmentation_config)
self.config.collator.spm_model_prefix = os.path.join( self.config.collator.spm_model_prefix = os.path.join(
res_path, self.config.collator.spm_model_prefix) res_path, self.config.collator.spm_model_prefix)
text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type, unit_type=self.config.collator.unit_type,
vocab=self.config.collator.vocab_filepath, vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
self.config.model.input_dim = self.config.collator.feat_dim self.config.model.input_dim = self.config.collator.feat_dim
self.config.model.output_dim = text_feature.vocab_size self.config.model.output_dim = self.text_feature.vocab_size
else: else:
raise Exception("wrong type") raise Exception("wrong type")
@ -211,6 +211,7 @@ class ASRExecutor(BaseExecutor):
model_dict = paddle.load(self.ckpt_path) model_dict = paddle.load(self.ckpt_path)
self.model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
def preprocess(self, model_type: str, input: Union[str, os.PathLike]): def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
""" """
Input preprocess and return paddle.Tensor stored in self.input. Input preprocess and return paddle.Tensor stored in self.input.
@ -228,7 +229,7 @@ class ASRExecutor(BaseExecutor):
audio = paddle.to_tensor(audio, dtype='float32') audio = paddle.to_tensor(audio, dtype='float32')
audio_len = paddle.to_tensor(audio_len) audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0) audio = paddle.unsqueeze(audio, axis=0)
vocab_list = collate_fn_test.vocab_list # vocab_list = collate_fn_test.vocab_list
self._inputs["audio"] = audio self._inputs["audio"] = audio
self._inputs["audio_len"] = audio_len self._inputs["audio_len"] = audio_len
logger.info(f"audio feat shape: {audio.shape}") logger.info(f"audio feat shape: {audio.shape}")
@ -274,10 +275,7 @@ class ASRExecutor(BaseExecutor):
audio_len = paddle.to_tensor(audio.shape[0]) audio_len = paddle.to_tensor(audio.shape[0])
audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0) audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
self._inputs["audio"] = audio self._inputs["audio"] = audio
self._inputs["audio_len"] = audio_len self._inputs["audio_len"] = audio_len
logger.info(f"audio feat shape: {audio.shape}") logger.info(f"audio feat shape: {audio.shape}")
@ -290,10 +288,7 @@ class ASRExecutor(BaseExecutor):
""" """
Model inference and result stored in self.output. Model inference and result stored in self.output.
""" """
text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
cfg = self.config.decoding cfg = self.config.decoding
audio = self._inputs["audio"] audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"] audio_len = self._inputs["audio_len"]
@ -301,7 +296,7 @@ class ASRExecutor(BaseExecutor):
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
audio, audio,
audio_len, audio_len,
text_feature.vocab_list, self.text_feature.vocab_list,
decoding_method=cfg.decoding_method, decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path, lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha, beam_alpha=cfg.alpha,
@ -316,7 +311,7 @@ class ASRExecutor(BaseExecutor):
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
audio, audio,
audio_len, audio_len,
text_feature=text_feature, text_feature=self.text_feature,
decoding_method=cfg.decoding_method, decoding_method=cfg.decoding_method,
beam_size=cfg.beam_size, beam_size=cfg.beam_size,
ctc_weight=cfg.ctc_weight, ctc_weight=cfg.ctc_weight,

Loading…
Cancel
Save