diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index 4b6711327..89408786c 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -292,7 +292,8 @@ class U2STTrainer(Trainer): n_iter_processes=config.collator.num_workers, subsampling_factor=1, load_aux_output=load_transcript, - num_encs=1) + num_encs=1, + dist_sampler=True) self.valid_loader = BatchDataLoader( json_file=config.data.dev_manifest, @@ -313,7 +314,8 @@ class U2STTrainer(Trainer): n_iter_processes=config.collator.num_workers, subsampling_factor=1, load_aux_output=load_transcript, - num_encs=1) + num_encs=1, + dist_sampler=True) logger.info("Setup train/valid Dataloader!") else: # test dataset, return raw text @@ -335,7 +337,8 @@ class U2STTrainer(Trainer): augmentation_config, # aug will be off when train_mode=False n_iter_processes=config.collator.num_workers, subsampling_factor=1, - num_encs=1) + num_encs=1, + dist_sampler=False) logger.info("Setup test Dataloader!") @@ -542,7 +545,8 @@ class U2STTester(U2STTrainer): len_refs += metrics['len_refs'] num_ins += metrics['num_ins'] rtf = num_time / (num_frames * stride_ms) - logger.info("RTF: %f, instance (%d), batch BELU = %f" % (rtf, num_ins, bleu)) + logger.info("RTF: %f, instance (%d), batch BLEU = %f" % + (rtf, num_ins, bleu)) rtf = num_time / (num_frames * stride_ms) msg = "Test: " diff --git a/paddlespeech/s2t/io/converter.py b/paddlespeech/s2t/io/converter.py index c92ef0174..a802ac749 100644 --- a/paddlespeech/s2t/io/converter.py +++ b/paddlespeech/s2t/io/converter.py @@ -65,8 +65,9 @@ class CustomConverter(): # text data (output): (text_len, ) ys_data.append(ud) - assert xs_data[0][0] is not None, "please check Reader and Augmentation impl." - + assert xs_data[0][ + 0] is not None, "please check Reader and Augmentation impl." 
+ xs_pad, ilens = [], [] for xs in xs_data: # perform subsampling @@ -79,22 +80,26 @@ class CustomConverter(): # perform padding and convert to tensor # currently only support real number xs_pad.append(pad_list(xs, 0).astype(self.dtype)) - + if not self.load_aux_input: xs_pad, ilens = xs_pad[0], ilens[0] break - + # NOTE: this is for multi-output (e.g., speech translation) ys_pad, olens = [], [] - + for ys in ys_data: - ys_pad.append(pad_list( - [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys], - self.ignore_id)) + ys_pad.append( + pad_list([ + np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys + ], self.ignore_id)) + + olens.append( + np.array([ + y[0].shape[0] if isinstance(y, tuple) else y.shape[0] + for y in ys + ])) - olens.append(np.array( - [y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])) - if not self.load_aux_output: ys_pad, olens = ys_pad[0], olens[0] break diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index 8330b1daa..455303f70 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -18,6 +18,7 @@ from typing import Text import jsonlines import numpy as np +from paddle.io import BatchSampler from paddle.io import DataLoader from paddle.io import DistributedBatchSampler @@ -76,7 +77,8 @@ class BatchDataLoader(): subsampling_factor: int=1, load_aux_input: bool=False, load_aux_output: bool=False, - num_encs: int=1): + num_encs: int=1, + dist_sampler: bool=False): self.json_file = json_file self.train_mode = train_mode self.use_sortagrad = sortagrad == -1 or sortagrad > 0 @@ -94,6 +96,7 @@ class BatchDataLoader(): self.n_iter_processes = n_iter_processes self.load_aux_input = load_aux_input self.load_aux_output = load_aux_output + self.dist_sampler = dist_sampler # read json data with jsonlines.open(json_file, 'r') as reader: @@ -145,11 +148,18 @@ class BatchDataLoader(): self.dataset = TransformDataset(self.minibaches, self.converter, 
self.reader) - self.sampler = DistributedBatchSampler( - dataset=self.dataset, - batch_size=1, - shuffle=not self.use_sortagrad if self.train_mode else False, - ) + if self.dist_sampler: + self.sampler = DistributedBatchSampler( + dataset=self.dataset, + batch_size=1, + shuffle=not self.use_sortagrad if self.train_mode else False, + drop_last=False, ) + else: + self.sampler = BatchSampler( + dataset=self.dataset, + batch_size=1, + shuffle=not self.use_sortagrad if self.train_mode else False, + drop_last=False, ) self.dataloader = DataLoader( dataset=self.dataset, @@ -181,5 +191,8 @@ class BatchDataLoader(): echo += f"subsampling_factor: {self.subsampling_factor}, " echo += f"num_encs: {self.num_encs}, " echo += f"num_workers: {self.n_iter_processes}, " + echo += f"load_aux_input: {self.load_aux_input}, " + echo += f"load_aux_output: {self.load_aux_output}, " + echo += f"dist_sampler: {self.dist_sampler}, " echo += f"file: {self.json_file}" return echo diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py index fc822b21f..15ed1e4d4 100644 --- a/paddlespeech/t2s/exps/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/synthesize_e2e.py @@ -203,12 +203,15 @@ def evaluate(args): get_tone_ids = True if args.lang == 'zh': input_ids = frontend.get_input_ids( - sentence, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) + sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) phone_ids = input_ids["phone_ids"] if get_tone_ids: tone_ids = input_ids["tone_ids"] elif args.lang == 'en': - input_ids = frontend.get_input_ids(sentence, merge_sentences=merge_sentences) + input_ids = frontend.get_input_ids( + sentence, merge_sentences=merge_sentences) phone_ids = input_ids["phone_ids"] else: print("lang should in {'zh', 'en'}!")