diff --git a/paddlespeech/t2s/exps/tacotron2/preprocess.py b/paddlespeech/t2s/exps/tacotron2/preprocess.py index 96eb64616..46b725916 100644 --- a/paddlespeech/t2s/exps/tacotron2/preprocess.py +++ b/paddlespeech/t2s/exps/tacotron2/preprocess.py @@ -228,9 +228,9 @@ def main(): if args.dataset == "baker": wav_files = sorted(list((rootdir / "Wave").rglob("*.wav"))) - # split data into 3 sections, the max number of dev/test is 10% or 100 - num_dev = min(int(len(wav_files) * 0.1), 100) - num_train = len(wav_files) - num_dev * 2 + # split data into 3 sections + num_train = 9800 + num_dev = 100 train_wav_files = wav_files[:num_train] dev_wav_files = wav_files[num_train:num_train + num_dev] test_wav_files = wav_files[num_train + num_dev:]