diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index 52d3c3b7..034463fe 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -26,8 +26,10 @@ from paddle import distributed as dist from paddle.io import DataLoader from yacs.config import CfgNode +from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.collator import TripletSpeechCollator +from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.io.dataset import ManifestDataset from paddlespeech.s2t.io.sampler import SortagradBatchSampler from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler @@ -423,6 +425,30 @@ class U2STTester(U2STTrainer): trans.append(''.join([chr(i) for i in ids])) return trans + def translate(self, audio, audio_len): + """E2E translation from extracted audio features, decoded with the options in config.decoding.""" + cfg = self.config.decoding + text_feature = self.test_loader.collate_fn.text_feature + + hyps = self.model.decode( + audio, + audio_len, + text_feature=text_feature, + decoding_method=cfg.decoding_method, + lang_model_path=cfg.lang_model_path, + beam_alpha=cfg.alpha, + beam_beta=cfg.beta, + beam_size=cfg.beam_size, + cutoff_prob=cfg.cutoff_prob, + cutoff_top_n=cfg.cutoff_top_n, + num_processes=cfg.num_proc_bsearch, + ctc_weight=cfg.ctc_weight, + word_reward=cfg.word_reward, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) + return hyps + def compute_translation_metrics(self, utts, audio,