|
|
@ -26,8 +26,10 @@ from paddle import distributed as dist
|
|
|
|
from paddle.io import DataLoader
|
|
|
|
from paddle.io import DataLoader
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
|
|
|
|
from paddlespeech.s2t.io.collator import SpeechCollator
|
|
|
|
from paddlespeech.s2t.io.collator import SpeechCollator
|
|
|
|
from paddlespeech.s2t.io.collator import TripletSpeechCollator
|
|
|
|
from paddlespeech.s2t.io.collator import TripletSpeechCollator
|
|
|
|
|
|
|
|
from paddlespeech.s2t.io.dataloader import BatchDataLoader
|
|
|
|
from paddlespeech.s2t.io.dataset import ManifestDataset
|
|
|
|
from paddlespeech.s2t.io.dataset import ManifestDataset
|
|
|
|
from paddlespeech.s2t.io.sampler import SortagradBatchSampler
|
|
|
|
from paddlespeech.s2t.io.sampler import SortagradBatchSampler
|
|
|
|
from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler
|
|
|
|
from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler
|
|
|
@ -423,6 +425,30 @@ class U2STTester(U2STTrainer):
|
|
|
|
trans.append(''.join([chr(i) for i in ids]))
|
|
|
|
trans.append(''.join([chr(i) for i in ids]))
|
|
|
|
return trans
|
|
|
|
return trans
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def translate(self, audio, audio_len):
|
|
|
|
|
|
|
|
""""E2E translation from extracted audio feature"""
|
|
|
|
|
|
|
|
cfg = self.config.decoding
|
|
|
|
|
|
|
|
text_feature = self.test_loader.collate_fn.text_feature
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyps = self.model.decode(
|
|
|
|
|
|
|
|
audio,
|
|
|
|
|
|
|
|
audio_len,
|
|
|
|
|
|
|
|
text_feature=text_feature,
|
|
|
|
|
|
|
|
decoding_method=cfg.decoding_method,
|
|
|
|
|
|
|
|
lang_model_path=cfg.lang_model_path,
|
|
|
|
|
|
|
|
beam_alpha=cfg.alpha,
|
|
|
|
|
|
|
|
beam_beta=cfg.beta,
|
|
|
|
|
|
|
|
beam_size=cfg.beam_size,
|
|
|
|
|
|
|
|
cutoff_prob=cfg.cutoff_prob,
|
|
|
|
|
|
|
|
cutoff_top_n=cfg.cutoff_top_n,
|
|
|
|
|
|
|
|
num_processes=cfg.num_proc_bsearch,
|
|
|
|
|
|
|
|
ctc_weight=cfg.ctc_weight,
|
|
|
|
|
|
|
|
word_reward=cfg.word_reward,
|
|
|
|
|
|
|
|
decoding_chunk_size=cfg.decoding_chunk_size,
|
|
|
|
|
|
|
|
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
|
|
|
|
|
|
|
|
simulate_streaming=cfg.simulate_streaming)
|
|
|
|
|
|
|
|
return hyps
|
|
|
|
|
|
|
|
|
|
|
|
def compute_translation_metrics(self,
|
|
|
|
def compute_translation_metrics(self,
|
|
|
|
utts,
|
|
|
|
utts,
|
|
|
|
audio,
|
|
|
|
audio,
|
|
|
|