|
|
|
@ -174,12 +174,12 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
self.config.collator.mean_std_filepath = os.path.join(
|
|
|
|
|
res_path, self.config.collator.cmvn_path)
|
|
|
|
|
self.collate_fn_test = SpeechCollator.from_config(self.config)
|
|
|
|
|
text_feature = TextFeaturizer(
|
|
|
|
|
self.text_feature = TextFeaturizer(
|
|
|
|
|
unit_type=self.config.collator.unit_type,
|
|
|
|
|
vocab=self.config.collator.vocab_filepath,
|
|
|
|
|
spm_model_prefix=self.config.collator.spm_model_prefix)
|
|
|
|
|
self.config.model.input_dim = self.collate_fn_test.feature_size
|
|
|
|
|
self.config.model.output_dim = text_feature.vocab_size
|
|
|
|
|
self.config.model.output_dim = self.text_feature.vocab_size
|
|
|
|
|
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
|
|
|
|
|
self.config.collator.vocab_filepath = os.path.join(
|
|
|
|
|
res_path, self.config.collator.vocab_filepath)
|
|
|
|
@ -187,12 +187,12 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
res_path, self.config.collator.augmentation_config)
|
|
|
|
|
self.config.collator.spm_model_prefix = os.path.join(
|
|
|
|
|
res_path, self.config.collator.spm_model_prefix)
|
|
|
|
|
text_feature = TextFeaturizer(
|
|
|
|
|
self.text_feature = TextFeaturizer(
|
|
|
|
|
unit_type=self.config.collator.unit_type,
|
|
|
|
|
vocab=self.config.collator.vocab_filepath,
|
|
|
|
|
spm_model_prefix=self.config.collator.spm_model_prefix)
|
|
|
|
|
self.config.model.input_dim = self.config.collator.feat_dim
|
|
|
|
|
self.config.model.output_dim = text_feature.vocab_size
|
|
|
|
|
self.config.model.output_dim = self.text_feature.vocab_size
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("wrong type")
|
|
|
|
@ -211,6 +211,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
model_dict = paddle.load(self.ckpt_path)
|
|
|
|
|
self.model.set_state_dict(model_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
|
|
|
|
|
"""
|
|
|
|
|
Input preprocess and return paddle.Tensor stored in self.input.
|
|
|
|
@ -228,7 +229,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
audio = paddle.to_tensor(audio, dtype='float32')
|
|
|
|
|
audio_len = paddle.to_tensor(audio_len)
|
|
|
|
|
audio = paddle.unsqueeze(audio, axis=0)
|
|
|
|
|
vocab_list = collate_fn_test.vocab_list
|
|
|
|
|
# vocab_list = collate_fn_test.vocab_list
|
|
|
|
|
self._inputs["audio"] = audio
|
|
|
|
|
self._inputs["audio_len"] = audio_len
|
|
|
|
|
logger.info(f"audio feat shape: {audio.shape}")
|
|
|
|
@ -274,10 +275,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
|
|
|
|
|
audio_len = paddle.to_tensor(audio.shape[0])
|
|
|
|
|
audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
|
|
|
|
|
text_feature = TextFeaturizer(
|
|
|
|
|
unit_type=self.config.collator.unit_type,
|
|
|
|
|
vocab=self.config.collator.vocab_filepath,
|
|
|
|
|
spm_model_prefix=self.config.collator.spm_model_prefix)
|
|
|
|
|
|
|
|
|
|
self._inputs["audio"] = audio
|
|
|
|
|
self._inputs["audio_len"] = audio_len
|
|
|
|
|
logger.info(f"audio feat shape: {audio.shape}")
|
|
|
|
@ -290,10 +288,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
"""
|
|
|
|
|
Model inference and result stored in self.output.
|
|
|
|
|
"""
|
|
|
|
|
text_feature = TextFeaturizer(
|
|
|
|
|
unit_type=self.config.collator.unit_type,
|
|
|
|
|
vocab=self.config.collator.vocab_filepath,
|
|
|
|
|
spm_model_prefix=self.config.collator.spm_model_prefix)
|
|
|
|
|
|
|
|
|
|
cfg = self.config.decoding
|
|
|
|
|
audio = self._inputs["audio"]
|
|
|
|
|
audio_len = self._inputs["audio_len"]
|
|
|
|
@ -301,7 +296,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
result_transcripts = self.model.decode(
|
|
|
|
|
audio,
|
|
|
|
|
audio_len,
|
|
|
|
|
text_feature.vocab_list,
|
|
|
|
|
self.text_feature.vocab_list,
|
|
|
|
|
decoding_method=cfg.decoding_method,
|
|
|
|
|
lang_model_path=cfg.lang_model_path,
|
|
|
|
|
beam_alpha=cfg.alpha,
|
|
|
|
@ -316,7 +311,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
result_transcripts = self.model.decode(
|
|
|
|
|
audio,
|
|
|
|
|
audio_len,
|
|
|
|
|
text_feature=text_feature,
|
|
|
|
|
text_feature=self.text_feature,
|
|
|
|
|
decoding_method=cfg.decoding_method,
|
|
|
|
|
beam_size=cfg.beam_size,
|
|
|
|
|
ctc_weight=cfg.ctc_weight,
|
|
|
|
|