diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py index 90b7d8a1..66ea29d0 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py @@ -20,8 +20,8 @@ import paddle import soundfile from yacs.config import CfgNode +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.models.ds2 import DeepSpeech2Model from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils import mp_tools @@ -38,24 +38,24 @@ class DeepSpeech2Tester_hub(): self.args = args self.config = config self.audio_file = args.audio_file - self.collate_fn_test = SpeechCollator.from_config(config) - self._text_featurizer = TextFeaturizer( - unit_type=config.unit_type, vocab=None) - def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): - result_transcripts = self.model.decode( - audio, - audio_len, - vocab_list, - decoding_method=cfg.decoding_method, - lang_model_path=cfg.lang_model_path, - beam_alpha=cfg.alpha, - beam_beta=cfg.beta, - beam_size=cfg.beam_size, - cutoff_prob=cfg.cutoff_prob, - cutoff_top_n=cfg.cutoff_top_n, - num_processes=cfg.num_proc_bsearch) + self.preprocess_conf = config.preprocess_config + self.preprocess_args = {"train": False} + self.preprocessing = Transformation(self.preprocess_conf) + + self.text_feature = TextFeaturizer( + unit_type=config.unit_type, + vocab=config.vocab_filepath, + spm_model_prefix=config.spm_model_prefix) + paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu') + def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): + decode_batch_size = cfg.decode_batch_size + self.model.decoder.init_decoder( + decode_batch_size, vocab_list, cfg.decoding_method, + cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size, + cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) + result_transcripts = self.model.decode(audio, audio_len) return result_transcripts @mp_tools.rank_zero_only @@ -64,16 +64,23 @@ class DeepSpeech2Tester_hub(): self.model.eval() cfg = self.config audio_file = self.audio_file - collate_fn_test = self.collate_fn_test - audio, _ = collate_fn_test.process_utterance( - audio_file=audio_file, transcript=" ") - audio_len = audio.shape[0] - audio = paddle.to_tensor(audio, dtype='float32') - audio_len = paddle.to_tensor(audio_len) - audio = paddle.unsqueeze(audio, axis=0) - vocab_list = collate_fn_test.vocab_list + + audio, sample_rate = soundfile.read( + self.audio_file, dtype="int16", always_2d=True) + + audio = audio[:, 0] + logger.info(f"audio shape: {audio.shape}") + + # fbank + feat = self.preprocessing(audio, **self.preprocess_args) + logger.info(f"feat shape: {feat.shape}") + + audio_len = paddle.to_tensor(feat.shape[0]) + audio = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0) + result_transcripts = self.compute_result_transcripts( - audio, audio_len, vocab_list, cfg.decode) + audio, audio_len, self.text_feature.vocab_list, cfg.decode) + logger.info("result_transcripts: " + result_transcripts[0]) def run_test(self): @@ -109,11 +116,9 @@ class DeepSpeech2Tester_hub(): def setup_model(self): config = self.config.clone() with UpdateConfig(config): - config.input_dim = self.collate_fn_test.feature_size - config.output_dim = self.collate_fn_test.vocab_size - + config.input_dim = config.feat_dim + config.output_dim = self.text_feature.vocab_size model = DeepSpeech2Model.from_config(config) - self.model = model def setup_checkpointer(self):