add translate function

pull/1050/head
Junkun 3 years ago
parent 6a50211c80
commit cdd0845127

@ -26,8 +26,10 @@ from paddle import distributed as dist
from paddle.io import DataLoader from paddle.io import DataLoader
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.io.collator import TripletSpeechCollator from paddlespeech.s2t.io.collator import TripletSpeechCollator
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataset import ManifestDataset from paddlespeech.s2t.io.dataset import ManifestDataset
from paddlespeech.s2t.io.sampler import SortagradBatchSampler from paddlespeech.s2t.io.sampler import SortagradBatchSampler
from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler
@ -423,6 +425,30 @@ class U2STTester(U2STTrainer):
trans.append(''.join([chr(i) for i in ids])) trans.append(''.join([chr(i) for i in ids]))
return trans return trans
def translate(self, audio, audio_len):
""""E2E translation from extracted audio feature"""
cfg = self.config.decoding
text_feature = self.test_loader.collate_fn.text_feature
hyps = self.model.decode(
audio,
audio_len,
text_feature=text_feature,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight,
word_reward=cfg.word_reward,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming)
return hyps
def compute_translation_metrics(self, def compute_translation_metrics(self,
utts, utts,
audio, audio,

Loading…
Cancel
Save