From 4c3e57a23ccbe085f014cf31163b18dd70cac2a3 Mon Sep 17 00:00:00 2001
From: TianYuan <white-sky@qq.com>
Date: Tue, 25 Jan 2022 06:33:24 +0000
Subject: [PATCH] align preprocess of wavernn, test=tts

---
 examples/csmsc/voc6/local/preprocess.sh       |  48 +++-
 examples/csmsc/voc6/local/synthesize.sh       |   3 +-
 examples/csmsc/voc6/local/train.sh            |   6 +-
 examples/csmsc/voc6/run.sh                    |   7 +-
 paddlespeech/t2s/datasets/vocoder_batch_fn.py | 216 +++++++++---------
 paddlespeech/t2s/exps/wavernn/preprocess.py   | 157 -------------
 paddlespeech/t2s/exps/wavernn/synthesize.py   |  61 +++--
 paddlespeech/t2s/exps/wavernn/train.py        |  36 ++-
 .../t2s/models/wavernn/wavernn_updater.py     |  36 ++-
 9 files changed, 250 insertions(+), 320 deletions(-)
 delete mode 100644 paddlespeech/t2s/exps/wavernn/preprocess.py

diff --git a/examples/csmsc/voc6/local/preprocess.sh b/examples/csmsc/voc6/local/preprocess.sh
index 064aea557..2dcc39ac7 100755
--- a/examples/csmsc/voc6/local/preprocess.sh
+++ b/examples/csmsc/voc6/local/preprocess.sh
@@ -6,10 +6,50 @@ stop_stage=100
 config_path=$1
 
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
-    python3 ${BIN_DIR}/preprocess.py \
-        --input=~/datasets/BZNSYP/ \
-        --output=dump \
-        --dataset=csmsc \
+    # get durations from MFA's result
+    echo "Generate durations.txt from MFA results ..."
+    python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
+        --inputdir=./baker_alignment_tone \
+        --output=durations.txt \
+        --config=${config_path}
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # extract features
+    echo "Extract features ..."
+    python3 ${BIN_DIR}/../gan_vocoder/preprocess.py \
+        --rootdir=~/datasets/BZNSYP/ \
+        --dataset=baker \
+        --dumpdir=dump \
+        --dur-file=durations.txt \
         --config=${config_path} \
+        --cut-sil=True \
         --num-cpu=20
 fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # get features' stats(mean and std)
+    echo "Get features' stats ..."
+    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --field-name="feats"
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # normalize, dev and test should use train's stats
+    echo "Normalize ..."
+   
+    python3 ${BIN_DIR}/../gan_vocoder/normalize.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --dumpdir=dump/train/norm \
+        --stats=dump/train/feats_stats.npy
+    python3 ${BIN_DIR}/../gan_vocoder/normalize.py \
+        --metadata=dump/dev/raw/metadata.jsonl \
+        --dumpdir=dump/dev/norm \
+        --stats=dump/train/feats_stats.npy
+    
+    python3 ${BIN_DIR}/../gan_vocoder/normalize.py \
+        --metadata=dump/test/raw/metadata.jsonl \
+        --dumpdir=dump/test/norm \
+        --stats=dump/train/feats_stats.npy
+fi
diff --git a/examples/csmsc/voc6/local/synthesize.sh b/examples/csmsc/voc6/local/synthesize.sh
index 876c8444e..7f0cbe48c 100755
--- a/examples/csmsc/voc6/local/synthesize.sh
+++ b/examples/csmsc/voc6/local/synthesize.sh
@@ -3,12 +3,11 @@
 config_path=$1
 train_output_path=$2
 ckpt_name=$3
-test_input=$4
 
 FLAGS_allocator_strategy=naive_best_fit \
 FLAGS_fraction_of_gpu_memory_to_use=0.01 \
 python3 ${BIN_DIR}/synthesize.py \
     --config=${config_path} \
     --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
-    --input=${test_input} \
+    --test-metadata=dump/test/norm/metadata.jsonl \
     --output-dir=${train_output_path}/test
diff --git a/examples/csmsc/voc6/local/train.sh b/examples/csmsc/voc6/local/train.sh
index 900450cdd..9695631ef 100755
--- a/examples/csmsc/voc6/local/train.sh
+++ b/examples/csmsc/voc6/local/train.sh
@@ -2,8 +2,12 @@
 
 config_path=$1
 train_output_path=$2
+
+FLAGS_cudnn_exhaustive_search=true \
+FLAGS_conv_workspace_size_limit=4000 \
 python ${BIN_DIR}/train.py \
+    --train-metadata=dump/train/norm/metadata.jsonl \
+    --dev-metadata=dump/dev/norm/metadata.jsonl \
     --config=${config_path} \
-    --data=dump/ \
     --output-dir=${train_output_path} \
     --ngpu=1
diff --git a/examples/csmsc/voc6/run.sh b/examples/csmsc/voc6/run.sh
index bd32e3d2e..5f754fff3 100755
--- a/examples/csmsc/voc6/run.sh
+++ b/examples/csmsc/voc6/run.sh
@@ -9,7 +9,7 @@ stop_stage=100
 
 conf_path=conf/default.yaml
 train_output_path=exp/default
-test_input=dump/mel_test
+test_input=dump/dump_gta_test
 ckpt_name=snapshot_iter_100000.pdz
 
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
@@ -25,9 +25,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
 fi
 
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-    # copy some test mels from dump
-    mkdir -p ${test_input}
-    cp -r dump/mel/00995*.npy ${test_input}
     # synthesize
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} ${test_input}|| exit -1
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
 fi
diff --git a/paddlespeech/t2s/datasets/vocoder_batch_fn.py b/paddlespeech/t2s/datasets/vocoder_batch_fn.py
index 496bf902a..b1d22db97 100644
--- a/paddlespeech/t2s/datasets/vocoder_batch_fn.py
+++ b/paddlespeech/t2s/datasets/vocoder_batch_fn.py
@@ -12,11 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from pathlib import Path
 
 import numpy as np
 import paddle
-from paddle.io import Dataset
 
 
 def label_2_float(x, bits):
@@ -44,102 +42,6 @@ def decode_mu_law(y, mu, from_labels=True):
     return x
 
 
-class WaveRNNDataset(Dataset):
-    """A simple dataset adaptor for the processed ljspeech dataset."""
-
-    def __init__(self, root):
-        self.root = Path(root).expanduser()
-
-        records = []
-
-        with open(self.root / "metadata.csv", 'r') as rf:
-
-            for line in rf:
-                name = line.split("\t")[0]
-                mel_path = str(self.root / "mel" / (str(name) + ".npy"))
-                wav_path = str(self.root / "wav" / (str(name) + ".npy"))
-                records.append((mel_path, wav_path))
-
-        self.records = records
-
-    def __getitem__(self, i):
-        mel_name, wav_name = self.records[i]
-        mel = np.load(mel_name)
-        wav = np.load(wav_name)
-        return mel, wav
-
-    def __len__(self):
-        return len(self.records)
-
-
-class WaveRNNClip(object):
-    def __init__(self,
-                 mode: str='RAW',
-                 batch_max_steps: int=4500,
-                 hop_size: int=300,
-                 aux_context_window: int=2,
-                 bits: int=9):
-        self.mode = mode
-        self.mel_win = batch_max_steps // hop_size + 2 * aux_context_window
-        self.batch_max_steps = batch_max_steps
-        self.hop_size = hop_size
-        self.aux_context_window = aux_context_window
-        if self.mode == 'MOL':
-            self.bits = 16
-        else:
-            self.bits = bits
-
-    def __call__(self, batch):
-        # batch: [mel, quant]
-        # voc_pad = 2  this will pad the input so that the resnet can 'see' wider than input length
-        # max_offsets = n_frames - 2 - (mel_win + 2 * hp.voc_pad) = n_frames - 15
-        max_offsets = [
-            x[0].shape[-1] - 2 - (self.mel_win + 2 * self.aux_context_window)
-            for x in batch
-        ]
-        # the slice point of mel selecting randomly 
-        mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
-        # the slice point of wav selecting randomly, which is behind 2(=pad) frames 
-        sig_offsets = [(offset + self.aux_context_window) * self.hop_size
-                       for offset in mel_offsets]
-        # mels.sape[1] = voc_seq_len // hop_length + 2 * voc_pad
-        mels = [
-            x[0][:, mel_offsets[i]:mel_offsets[i] + self.mel_win]
-            for i, x in enumerate(batch)
-        ]
-        # label.shape[1] = voc_seq_len + 1
-        labels = [
-            x[1][sig_offsets[i]:sig_offsets[i] + self.batch_max_steps + 1]
-            for i, x in enumerate(batch)
-        ]
-
-        mels = np.stack(mels).astype(np.float32)
-        labels = np.stack(labels).astype(np.int64)
-
-        mels = paddle.to_tensor(mels)
-        labels = paddle.to_tensor(labels, dtype='int64')
-
-        # x is input, y is label
-        x = labels[:, :self.batch_max_steps]
-        y = labels[:, 1:]
-        '''
-        mode = RAW:
-            mu_law = True:
-                quant: bits = 9   0, 1, 2, ..., 509, 510, 511  int
-            mu_law = False
-                quant bits = 9    [0, 511]  float
-        mode = MOL:
-            quant: bits = 16  [0. 65536]  float
-        '''
-        # x should be normalizes in.[0, 1] in RAW mode
-        x = label_2_float(paddle.cast(x, dtype='float32'), self.bits)
-        # y should be normalizes in.[0, 1] in MOL mode
-        if self.mode == 'MOL':
-            y = label_2_float(paddle.cast(y, dtype='float32'), self.bits)
-
-        return x, y, mels
-
-
 class Clip(object):
     """Collate functor for training vocoders.
     """
@@ -174,7 +76,7 @@ class Clip(object):
         self.end_offset = -(self.batch_max_frames + aux_context_window)
         self.mel_threshold = self.batch_max_frames + 2 * aux_context_window
 
-    def __call__(self, examples):
+    def __call__(self, batch):
         """Convert into batch tensors.
 
         Parameters
@@ -192,11 +94,11 @@ class Clip(object):
 
         """
         # check length
-        examples = [
-            self._adjust_length(b['wave'], b['feats']) for b in examples
+        batch = [
+            self._adjust_length(b['wave'], b['feats']) for b in batch
             if b['feats'].shape[0] > self.mel_threshold
         ]
-        xs, cs = [b[0] for b in examples], [b[1] for b in examples]
+        xs, cs = [b[0] for b in batch], [b[1] for b in batch]
 
         # make batch with random cut
         c_lengths = [c.shape[0] for c in cs]
@@ -214,7 +116,7 @@ class Clip(object):
         c_batch = np.stack(
             [c[start:end] for c, start, end in zip(cs, c_starts, c_ends)])
 
-        # convert each batch to tensor, asuume that each item in batch has the same length
+        # convert each batch to tensor, assume that each item in batch has the same length
         y_batch = paddle.to_tensor(
             y_batch, dtype=paddle.float32).unsqueeze(1)  # (B, 1, T)
         c_batch = paddle.to_tensor(
@@ -245,3 +147,111 @@ class Clip(object):
             0] * self.hop_size, f"wave length: ({len(x)}), mel length: ({c.shape[0]})"
 
         return x, c
+
+
+class WaveRNNClip(Clip):
+    def __init__(self,
+                 mode: str='RAW',
+                 batch_max_steps: int=4500,
+                 hop_size: int=300,
+                 aux_context_window: int=2,
+                 bits: int=9,
+                 mu_law: bool=True):
+        self.mode = mode
+        self.mel_win = batch_max_steps // hop_size + 2 * aux_context_window
+        self.batch_max_steps = batch_max_steps
+        self.hop_size = hop_size
+        self.aux_context_window = aux_context_window
+        self.mu_law = mu_law
+        self.batch_max_frames = batch_max_steps // hop_size
+        self.mel_threshold = self.batch_max_frames + 2 * aux_context_window
+        if self.mode == 'MOL':
+            self.bits = 16
+        else:
+            self.bits = bits
+
+    def to_quant(self, wav):
+        if self.mode == 'RAW':
+            if self.mu_law:
+                quant = encode_mu_law(wav, mu=2**self.bits)
+            else:
+                quant = float_2_label(wav, bits=self.bits)
+        elif self.mode == 'MOL':
+            quant = float_2_label(wav, bits=16)
+        quant = quant.astype(np.int64)
+        return quant
+
+    def __call__(self, batch):
+        # voc_pad = 2  this will pad the input so that the resnet can 'see' wider than input length
+        # max_offsets = n_frames - 2 - (mel_win + 2 * hp.voc_pad) = n_frames - 15
+        """Convert into batch tensors.
+
+        Parameters
+        ----------
+        batch : list
+            list of tuple of the pair of audio and features. 
+            Audio shape (T, ), features shape(T', C).
+
+        Returns
+        ----------
+        Tensor
+            Auxiliary feature batch (B, C, T'), where
+            T = (T' - 2 * aux_context_window) * hop_size.
+        Tensor
+            Target signal batch (B, 1, T).
+
+        """
+        # check length
+        batch = [
+            self._adjust_length(b['wave'], b['feats']) for b in batch
+            if b['feats'].shape[0] > self.mel_threshold
+        ]
+        wav, mel = [b[0] for b in batch], [b[1] for b in batch]
+        # mel 此处需要转置
+        mel = [x.T for x in mel]
+        max_offsets = [
+            x.shape[-1] - 2 - (self.mel_win + 2 * self.aux_context_window)
+            for x in mel
+        ]
+        # the slice point of mel selecting randomly 
+        mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
+        # the slice point of wav selecting randomly, which is behind 2(=pad) frames 
+        sig_offsets = [(offset + self.aux_context_window) * self.hop_size
+                       for offset in mel_offsets]
+        # mels.shape[1] = voc_seq_len // hop_length + 2 * voc_pad
+        mels = [
+            x[:, mel_offsets[i]:mel_offsets[i] + self.mel_win]
+            for i, x in enumerate(mel)
+        ]
+        # label.shape[1] = voc_seq_len + 1
+        wav = [self.to_quant(x) for x in wav]
+
+        labels = [
+            x[sig_offsets[i]:sig_offsets[i] + self.batch_max_steps + 1]
+            for i, x in enumerate(wav)
+        ]
+
+        mels = np.stack(mels).astype(np.float32)
+        labels = np.stack(labels).astype(np.int64)
+
+        mels = paddle.to_tensor(mels)
+        labels = paddle.to_tensor(labels, dtype='int64')
+        # x is input, y is label
+        x = labels[:, :self.batch_max_steps]
+        y = labels[:, 1:]
+        '''
+        mode = RAW:
+            mu_law = True:
+                quant: bits = 9   0, 1, 2, ..., 509, 510, 511  int
+            mu_law = False
+                quant bits = 9    [0, 511]  float
+        mode = MOL:
+            quant: bits = 16  [0. 65536]  float
+        '''
+        # x should be normalizes in.[0, 1] in RAW mode
+        x = label_2_float(paddle.cast(x, dtype='float32'), self.bits)
+        # y should be normalizes in.[0, 1] in MOL mode
+        if self.mode == 'MOL':
+            y = label_2_float(paddle.cast(y, dtype='float32'), self.bits)
+
+        return x, y, mels
diff --git a/paddlespeech/t2s/exps/wavernn/preprocess.py b/paddlespeech/t2s/exps/wavernn/preprocess.py
deleted file mode 100644
index a26c6702a..000000000
--- a/paddlespeech/t2s/exps/wavernn/preprocess.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import argparse
-import os
-from multiprocessing import cpu_count
-from multiprocessing import Pool
-from pathlib import Path
-
-import librosa
-import numpy as np
-import pandas as pd
-import tqdm
-import yaml
-from yacs.config import CfgNode
-
-from paddlespeech.t2s.data.get_feats import LogMelFBank
-from paddlespeech.t2s.datasets import CSMSCMetaData
-from paddlespeech.t2s.datasets import LJSpeechMetaData
-from paddlespeech.t2s.datasets.vocoder_batch_fn import encode_mu_law
-from paddlespeech.t2s.datasets.vocoder_batch_fn import float_2_label
-
-
-class Transform(object):
-    def __init__(self, output_dir: Path, config):
-        self.fs = config.fs
-        self.peak_norm = config.peak_norm
-        self.bits = config.model.bits
-        self.mode = config.model.mode
-        self.mu_law = config.mu_law
-
-        self.wav_dir = output_dir / "wav"
-        self.mel_dir = output_dir / "mel"
-        self.wav_dir.mkdir(exist_ok=True)
-        self.mel_dir.mkdir(exist_ok=True)
-
-        self.mel_extractor = LogMelFBank(
-            sr=config.fs,
-            n_fft=config.n_fft,
-            hop_length=config.n_shift,
-            win_length=config.win_length,
-            window=config.window,
-            n_mels=config.n_mels,
-            fmin=config.fmin,
-            fmax=config.fmax)
-
-        if self.mode != 'RAW' and self.mode != 'MOL':
-            raise RuntimeError('Unknown mode value - ', self.mode)
-
-    def __call__(self, example):
-        wav_path, _, _ = example
-
-        base_name = os.path.splitext(os.path.basename(wav_path))[0]
-        # print("self.sample_rate:",self.sample_rate)
-        wav, _ = librosa.load(wav_path, sr=self.fs)
-        peak = np.abs(wav).max()
-        if self.peak_norm or peak > 1.0:
-            wav /= peak
-
-        mel = self.mel_extractor.get_log_mel_fbank(wav).T
-        if self.mode == 'RAW':
-            if self.mu_law:
-                quant = encode_mu_law(wav, mu=2**self.bits)
-            else:
-                quant = float_2_label(wav, bits=self.bits)
-        elif self.mode == 'MOL':
-            quant = float_2_label(wav, bits=16)
-
-        mel = mel.astype(np.float32)
-        audio = quant.astype(np.int64)
-
-        np.save(str(self.wav_dir / base_name), audio)
-        np.save(str(self.mel_dir / base_name), mel)
-
-        return base_name, mel.shape[-1], audio.shape[-1]
-
-
-def create_dataset(config,
-                   input_dir,
-                   output_dir,
-                   nprocs: int=1,
-                   dataset_type: str="ljspeech"):
-    input_dir = Path(input_dir).expanduser()
-    '''
-    LJSpeechMetaData.records: [filename, normalized text, speaker name(ljspeech)]
-    CSMSCMetaData.records: [filename, normalized text, pinyin]
-    '''
-    if dataset_type == 'ljspeech':
-        dataset = LJSpeechMetaData(input_dir)
-    else:
-        dataset = CSMSCMetaData(input_dir)
-    output_dir = Path(output_dir).expanduser()
-    output_dir.mkdir(exist_ok=True)
-
-    transform = Transform(output_dir, config)
-
-    file_names = []
-
-    pool = Pool(processes=nprocs)
-
-    for info in tqdm.tqdm(pool.imap(transform, dataset), total=len(dataset)):
-        base_name, mel_len, audio_len = info
-        file_names.append((base_name, mel_len, audio_len))
-
-    meta_data = pd.DataFrame.from_records(file_names)
-    meta_data.to_csv(
-        str(output_dir / "metadata.csv"), sep="\t", index=None, header=None)
-    print("saved meta data in to {}".format(
-        os.path.join(output_dir, "metadata.csv")))
-
-    print("Done!")
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="create dataset")
-    parser.add_argument(
-        "--config", type=str, help="config file to overwrite default config.")
-
-    parser.add_argument(
-        "--input", type=str, help="path of the ljspeech dataset")
-    parser.add_argument(
-        "--output", type=str, help="path to save output dataset")
-    parser.add_argument(
-        "--num-cpu",
-        type=int,
-        default=cpu_count() // 2,
-        help="number of process.")
-    parser.add_argument(
-        "--dataset",
-        type=str,
-        default="ljspeech",
-        help="The dataset to preprocess, ljspeech or csmsc")
-
-    args = parser.parse_args()
-
-    with open(args.config, 'rt') as f:
-        config = CfgNode(yaml.safe_load(f))
-
-    if args.dataset != "ljspeech" and args.dataset != "csmsc":
-        raise RuntimeError('Unknown dataset - ', args.dataset)
-
-    create_dataset(
-        config,
-        input_dir=args.input,
-        output_dir=args.output,
-        nprocs=args.num_cpu,
-        dataset_type=args.dataset)
diff --git a/paddlespeech/t2s/exps/wavernn/synthesize.py b/paddlespeech/t2s/exps/wavernn/synthesize.py
index e08c52b60..61723e039 100644
--- a/paddlespeech/t2s/exps/wavernn/synthesize.py
+++ b/paddlespeech/t2s/exps/wavernn/synthesize.py
@@ -15,13 +15,16 @@ import argparse
 import os
 from pathlib import Path
 
+import jsonlines
 import numpy as np
 import paddle
 import soundfile as sf
 import yaml
 from paddle import distributed as dist
+from timer import timer
 from yacs.config import CfgNode
 
+from paddlespeech.t2s.datasets.data_table import DataTable
 from paddlespeech.t2s.models.wavernn import WaveRNN
 
 
@@ -30,10 +33,7 @@ def main():
 
     parser.add_argument("--config", type=str, help="GANVocoder config file.")
     parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
-    parser.add_argument(
-        "--input",
-        type=str,
-        help="path of directory containing mel spectrogram (in .npy format)")
+    parser.add_argument("--test-metadata", type=str, help="dev data.")
     parser.add_argument("--output-dir", type=str, help="output dir.")
     parser.add_argument(
         "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
@@ -65,24 +65,43 @@ def main():
 
     model.eval()
 
-    mel_dir = Path(args.input).expanduser()
-    output_dir = Path(args.output_dir).expanduser()
+    with jsonlines.open(args.test_metadata, 'r') as reader:
+        metadata = list(reader)
+    test_dataset = DataTable(
+        metadata,
+        fields=['utt_id', 'feats'],
+        converters={
+            'utt_id': None,
+            'feats': np.load,
+        })
+    output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
-    for file_path in sorted(mel_dir.iterdir()):
-        mel = np.load(str(file_path))
-        mel = paddle.to_tensor(mel)
-        mel = mel.transpose([1, 0])
-        # input shape is (T', C_aux)
-        audio = model.generate(
-            c=mel,
-            batched=config.inference.gen_batched,
-            target=config.inference.target,
-            overlap=config.inference.overlap,
-            mu_law=config.mu_law,
-            gen_display=True)
-        audio_path = output_dir / (os.path.splitext(file_path.name)[0] + ".wav")
-        sf.write(audio_path, audio.numpy(), samplerate=config.fs)
-        print("[synthesize] {} -> {}".format(file_path, audio_path))
+
+    N = 0
+    T = 0
+    for example in test_dataset:
+        utt_id = example['utt_id']
+        mel = example['feats']
+        mel = paddle.to_tensor(mel)  # (T, C)
+        with timer() as t:
+            with paddle.no_grad():
+                wav = model.generate(
+                    c=mel,
+                    batched=config.inference.gen_batched,
+                    target=config.inference.target,
+                    overlap=config.inference.overlap,
+                    mu_law=config.mu_law,
+                    gen_display=True)
+            wav = wav.numpy()
+            N += wav.size
+            T += t.elapse
+            speed = wav.size / t.elapse
+            rtf = config.fs / speed
+        print(
+            f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+        )
+        sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs)
+    print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }")
 
 
 if __name__ == "__main__":
diff --git a/paddlespeech/t2s/exps/wavernn/train.py b/paddlespeech/t2s/exps/wavernn/train.py
index d7bfc49bf..aec745f76 100644
--- a/paddlespeech/t2s/exps/wavernn/train.py
+++ b/paddlespeech/t2s/exps/wavernn/train.py
@@ -16,6 +16,8 @@ import os
 import shutil
 from pathlib import Path
 
+import jsonlines
+import numpy as np
 import paddle
 import yaml
 from paddle import DataParallel
@@ -25,9 +27,8 @@ from paddle.io import DistributedBatchSampler
 from paddle.optimizer import Adam
 from yacs.config import CfgNode
 
-from paddlespeech.t2s.data import dataset
+from paddlespeech.t2s.datasets.data_table import DataTable
 from paddlespeech.t2s.datasets.vocoder_batch_fn import WaveRNNClip
-from paddlespeech.t2s.datasets.vocoder_batch_fn import WaveRNNDataset
 from paddlespeech.t2s.models.wavernn import WaveRNN
 from paddlespeech.t2s.models.wavernn import WaveRNNEvaluator
 from paddlespeech.t2s.models.wavernn import WaveRNNUpdater
@@ -56,10 +57,26 @@ def train_sp(args, config):
         f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
     )
 
-    wavernn_dataset = WaveRNNDataset(args.data)
-
-    train_dataset, dev_dataset = dataset.split(
-        wavernn_dataset, len(wavernn_dataset) - config.valid_size)
+    # construct dataset for training and validation
+    with jsonlines.open(args.train_metadata, 'r') as reader:
+        train_metadata = list(reader)
+    train_dataset = DataTable(
+        data=train_metadata,
+        fields=["wave", "feats"],
+        converters={
+            "wave": np.load,
+            "feats": np.load,
+        }, )
+
+    with jsonlines.open(args.dev_metadata, 'r') as reader:
+        dev_metadata = list(reader)
+    dev_dataset = DataTable(
+        data=dev_metadata,
+        fields=["wave", "feats"],
+        converters={
+            "wave": np.load,
+            "feats": np.load,
+        }, )
 
     batch_fn = WaveRNNClip(
         mode=config.model.mode,
@@ -92,7 +109,9 @@ def train_sp(args, config):
         collate_fn=batch_fn,
         batch_sampler=dev_sampler,
         num_workers=config.num_workers)
+
     valid_generate_loader = DataLoader(dev_dataset, batch_size=1)
+
     print("dataloaders done!")
 
     model = WaveRNN(
@@ -160,10 +179,11 @@ def train_sp(args, config):
 def main():
     # parse args and config and redirect to train_sp
 
-    parser = argparse.ArgumentParser(description="Train a WaveRNN model.")
+    parser = argparse.ArgumentParser(description="Train a HiFiGAN model.")
     parser.add_argument(
         "--config", type=str, help="config file to overwrite default config.")
-    parser.add_argument("--data", type=str, help="input")
+    parser.add_argument("--train-metadata", type=str, help="training data.")
+    parser.add_argument("--dev-metadata", type=str, help="dev data.")
     parser.add_argument("--output-dir", type=str, help="output dir.")
     parser.add_argument(
         "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
diff --git a/paddlespeech/t2s/models/wavernn/wavernn_updater.py b/paddlespeech/t2s/models/wavernn/wavernn_updater.py
index e6064e4cb..b2756d00c 100644
--- a/paddlespeech/t2s/models/wavernn/wavernn_updater.py
+++ b/paddlespeech/t2s/models/wavernn/wavernn_updater.py
@@ -21,8 +21,6 @@ from paddle.io import DataLoader
 from paddle.nn import Layer
 from paddle.optimizer import Optimizer
 
-from paddlespeech.t2s.datasets.vocoder_batch_fn import decode_mu_law
-from paddlespeech.t2s.datasets.vocoder_batch_fn import label_2_float
 from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
 from paddlespeech.t2s.training.reporter import report
 from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
@@ -156,31 +154,22 @@ class WaveRNNEvaluator(StandardEvaluator):
 
         losses_dict["loss"] = float(loss)
 
-        self.iteration = ITERATION
-        if self.iteration % self.config.gen_eval_samples_interval_steps == 0:
-            self.gen_valid_samples()
-
         self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                               for k, v in losses_dict.items())
         self.logger.info(self.msg)
 
     def gen_valid_samples(self):
 
-        for i, (mel, wav) in enumerate(self.valid_generate_loader):
+        for i, item in enumerate(self.valid_generate_loader):
             if i >= self.config.generate_num:
-                print("before break")
                 break
             print(
                 '\n| Generating: {}/{}'.format(i + 1, self.config.generate_num))
-            wav = wav[0]
-            if self.mode == 'MOL':
-                bits = 16
-            else:
-                bits = self.config.model.bits
-            if self.config.mu_law and self.mode != 'MOL':
-                wav = decode_mu_law(wav, 2**bits, from_labels=True)
-            else:
-                wav = label_2_float(wav, bits)
+
+            mel = item['feats']
+            wav = item['wave']
+            wav = wav.squeeze(0)
+
             origin_save_path = self.valid_samples_dir / '{}_steps_{}_target.wav'.format(
                 self.iteration, i)
             sf.write(origin_save_path, wav.numpy(), samplerate=self.config.fs)
@@ -193,11 +182,20 @@ class WaveRNNEvaluator(StandardEvaluator):
             gen_save_path = str(self.valid_samples_dir /
                                 '{}_steps_{}_{}.wav'.format(self.iteration, i,
                                                             batch_str))
-            # (1, C_aux, T) -> (T, C_aux)
-            mel = mel.squeeze(0).transpose([1, 0])
+            # (1, T, C_aux) -> (T, C_aux)
+            mel = mel.squeeze(0)
             gen_sample = self.model.generate(
                 mel, self.config.inference.gen_batched,
                 self.config.inference.target, self.config.inference.overlap,
                 self.config.mu_law)
             sf.write(
                 gen_save_path, gen_sample.numpy(), samplerate=self.config.fs)
+
+    def __call__(self, trainer=None):
+        summary = self.evaluate()
+        for k, v in summary.items():
+            report(k, v)
+        # gen samples at then end of evaluate
+        self.iteration = ITERATION
+        if self.iteration % self.config.gen_eval_samples_interval_steps == 0:
+            self.gen_valid_samples()