From bd0d69ca7417fef97a9d0e4ab15f7fbead7c9ebb Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Tue, 18 Apr 2023 21:31:58 +0800
Subject: [PATCH] [TTS]add StarGANv2VC preprocess (#3163)

---
 examples/vctk/vc3/conf/default.yaml           |  19 +-
 examples/vctk/vc3/local/preprocess.sh         |  27 ++-
 paddlespeech/t2s/datasets/am_batch_fn.py      |  78 ++++++-
 paddlespeech/t2s/datasets/data_table.py       |  53 +++++
 .../t2s/exps/starganv2_vc/normalize.py        | 101 +++++++++
 .../t2s/exps/starganv2_vc/preprocess.py       | 214 ++++++++++++++++++
 paddlespeech/t2s/exps/starganv2_vc/vc.py      |  13 +-
 .../t2s/models/starganv2_vc/starganv2_vc.py   |   3 +-
 8 files changed, 485 insertions(+), 23 deletions(-)
 create mode 100644 paddlespeech/t2s/exps/starganv2_vc/normalize.py
 create mode 100644 paddlespeech/t2s/exps/starganv2_vc/preprocess.py

diff --git a/examples/vctk/vc3/conf/default.yaml b/examples/vctk/vc3/conf/default.yaml
index 461989024..b1168a40e 100644
--- a/examples/vctk/vc3/conf/default.yaml
+++ b/examples/vctk/vc3/conf/default.yaml
@@ -1,12 +1,23 @@
 ###########################################################
 #                FEATURE EXTRACTION SETTING               #
 ###########################################################
-# Not actually used; 16000 is what is really used.
-sr: 24000
+# The original code loads audio at 24k but extracts mel at 16k; both loading and mel extraction should later be unified to 24k.
+fs: 16000
 n_fft: 2048
-win_length: 1200
-hop_length: 300
+n_shift: 300
+win_length: 1200   # Window length in samples (75 ms at fs=16000; 50 ms at the original 24k).
+                   # If set to null, it will be the same as fft_size.
+window: "hann"     # Window function.
+
+fmin: 0            # Minimum frequency of Mel basis.
+fmax: 8000         # Maximum frequency of Mel basis. fs // 2
 n_mels: 80
+# only for StarGANv2 VC
+norm:              # null here, i.e. no normalization of the Mel basis
+htk: True
+power: 2.0
+
+
 ###########################################################
 #                       MODEL SETTING                     #
 ###########################################################
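For orientation, the feature settings above amount roughly to the following librosa-only extraction. This is a hedged sketch, not the repo's implementation; the authoritative path is the LogMelFBank extractor configured in preprocess.py below, whose framing and padding details may differ.

    # A hedged, librosa-only sketch of what the config above computes.
    import librosa
    import numpy as np

    def extract_logmel(wav: np.ndarray, fs: int=16000) -> np.ndarray:
        mel = librosa.feature.melspectrogram(
            y=wav,
            sr=fs,            # fs: 16000
            n_fft=2048,       # n_fft
            hop_length=300,   # n_shift
            win_length=1200,  # win_length
            window="hann",    # window
            n_mels=80,        # n_mels
            fmin=0,           # fmin
            fmax=8000,        # fmax
            htk=True,         # htk: True
            norm=None,        # norm: null
            power=2.0)        # power
        # natural log, matching base='e' used by preprocess.py below
        return np.log(np.maximum(mel, 1e-10)).T  # (T, 80)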
diff --git a/examples/vctk/vc3/local/preprocess.sh b/examples/vctk/vc3/local/preprocess.sh
index ea0fbc43d..058171c58 100755
--- a/examples/vctk/vc3/local/preprocess.sh
+++ b/examples/vctk/vc3/local/preprocess.sh
@@ -6,13 +6,32 @@ stop_stage=100
 config_path=$1
 
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # extract features
+    echo "Extract features ..."
+    python3 ${BIN_DIR}/preprocess.py \
+        --dataset=vctk \
+        --rootdir=~/datasets/VCTK-Corpus-0.92/ \
+        --dumpdir=dump \
+        --config=${config_path} \
+        --num-cpu=20
 fi
 
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-
-fi
-
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    echo "Normalize ..."
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --dumpdir=dump/train/norm \
+        --speaker-dict=dump/speaker_id_map.txt
+
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/dev/raw/metadata.jsonl \
+        --dumpdir=dump/dev/norm \
+        --speaker-dict=dump/speaker_id_map.txt
+
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/test/raw/metadata.jsonl \
+        --dumpdir=dump/test/norm \
+        --speaker-dict=dump/speaker_id_map.txt
 fi

diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py
index 8ec15f5dd..5a18bc216 100644
--- a/paddlespeech/t2s/datasets/am_batch_fn.py
+++ b/paddlespeech/t2s/datasets/am_batch_fn.py
@@ -669,18 +669,72 @@ def vits_multi_spk_batch_fn(examples):
     return batch
 
 
-# unfinished
-def starganv2_vc_batch_fn(examples):
-    batch = {
-        "x_real": None,
-        "y_org": None,
-        "x_ref": None,
-        "x_ref2": None,
-        "y_trg": None,
-        "z_trg": None,
-        "z_trg2": None,
-    }
-    return batch
+# Built via a factory because the collate fn needs extra parameters.
+def build_starganv2_vc_collate_fn(latent_dim: int=16, max_mel_length: int=192):
+
+    return StarGANv2VCCollateFn(
+        latent_dim=latent_dim, max_mel_length=max_mel_length)
+
+
+class StarGANv2VCCollateFn:
+    """Functor class of common_collate_fn()"""
+
+    def __init__(self, latent_dim: int=16, max_mel_length: int=192):
+        self.latent_dim = latent_dim
+        self.max_mel_length = max_mel_length
+
+    def random_clip(self, mel: np.ndarray):
+        # [80, T]
+        mel_length = mel.shape[1]
+        if mel_length > self.max_mel_length:
+            random_start = np.random.randint(0,
+                                             mel_length - self.max_mel_length)
+            mel = mel[:, random_start:random_start + self.max_mel_length]
+        return mel
+
+    def __call__(self, examples):
+        return self.starganv2_vc_batch_fn(examples)
+
+    def starganv2_vc_batch_fn(self, examples):
+        batch_size = len(examples)
+
+        label = [np.array(item["label"], dtype=np.int64) for item in examples]
+        ref_label = [
+            np.array(item["ref_label"], dtype=np.int64) for item in examples
+        ]
+
+        # mels have to be randomly clipped to at most max_mel_length frames
+        mel = [self.random_clip(item["mel"]) for item in examples]
+        ref_mel = [self.random_clip(item["ref_mel"]) for item in examples]
+        ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples]
+
+        mel = batch_sequences(mel)
+        ref_mel = batch_sequences(ref_mel)
+        ref_mel_2 = batch_sequences(ref_mel_2)
+
+        # convert each batch to paddle.Tensor
+        # (B,)
+        label = paddle.to_tensor(label)
+        ref_label = paddle.to_tensor(ref_label)
+        # [B, 80, T] -> [B, 1, 80, T]
+        mel = paddle.to_tensor(mel).unsqueeze(1)
+        ref_mel = paddle.to_tensor(ref_mel).unsqueeze(1)
+        ref_mel_2 = paddle.to_tensor(ref_mel_2).unsqueeze(1)
+
+        z_trg = paddle.randn([batch_size, self.latent_dim])
+        z_trg2 = paddle.randn([batch_size, self.latent_dim])
+
+        batch = {
+            "x_real": mel,
+            "y_org": label,
+            "x_ref": ref_mel,
+            "x_ref2": ref_mel_2,
+            "y_trg": ref_label,
+            "z_trg": z_trg,
+            "z_trg2": z_trg2
+        }
+
+        return batch
 
 
 # for PaddleSlim
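The functor plugs straight into a paddle DataLoader. A hedged usage sketch (the training wiring is not part of this patch); `train_dataset` is assumed to be a StarGANv2VCDataTable instance like the one added below:

    from paddle.io import DataLoader

    from paddlespeech.t2s.datasets.am_batch_fn import build_starganv2_vc_collate_fn

    collate_fn = build_starganv2_vc_collate_fn(latent_dim=16, max_mel_length=192)
    train_loader = DataLoader(
        train_dataset,  # a StarGANv2VCDataTable, see the data_table.py diff below
        batch_size=32,
        shuffle=True,
        collate_fn=collate_fn)

    for batch in train_loader:
        # batch["x_real"]: (B, 1, 80, T), batch["z_trg"]: (B, 16)
        break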
diff --git a/paddlespeech/t2s/datasets/data_table.py b/paddlespeech/t2s/datasets/data_table.py
index c9815af21..4ac675461 100644
--- a/paddlespeech/t2s/datasets/data_table.py
+++ b/paddlespeech/t2s/datasets/data_table.py
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import random
 from multiprocessing import Manager
 from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import List
 
+import numpy as np
 from paddle.io import Dataset
 
 
@@ -131,3 +133,50 @@ class DataTable(Dataset):
             The length of the dataset
         """
         return len(self.data)
+
+
+class StarGANv2VCDataTable(DataTable):
+    def __init__(self, data: List[Dict[str, Any]]):
+        super().__init__(data)
+        raw_data = data
+        spk_id_set = list(set([item['spk_id'] for item in raw_data]))
+        data_list_per_class = {}
+        for spk_id in spk_id_set:
+            data_list_per_class[spk_id] = []
+        for item in raw_data:
+            data_list_per_class[item['spk_id']].append(item)
+        self.data_list_per_class = data_list_per_class
+
+    def __getitem__(self, idx: int) -> Dict[str, Any]:
+        """Get an example given an index.
+        Args:
+            idx (int): Index of the example to get
+
+        Returns:
+            Dict[str, Any]: A converted example
+        """
+        if self.use_cache and self.caches[idx] is not None:
+            return self.caches[idx]
+
+        data = self._get_metadata(idx)
+
+        # clipping is deferred to the batch_fn
+        # `data` is a dict like:
+        # {'utt_id': 'p225_111', 'spk_id': '1', 'speech': 'path of *.npy'}
+        ref_data = random.choice(self.data)
+        ref_label = ref_data['spk_id']
+        ref_data_2 = random.choice(self.data_list_per_class[ref_label])
+        # mel, label, ref_mel, ref_mel_2, ref_label
+        new_example = {
+            'utt_id': data['utt_id'],
+            'mel': np.load(data['speech']),
+            'label': int(data['spk_id']),
+            'ref_mel': np.load(ref_data['speech']),
+            'ref_mel_2': np.load(ref_data_2['speech']),
+            'ref_label': int(ref_label)
+        }
+
+        if self.use_cache:
+            self.caches[idx] = new_example
+
+        return new_example
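Each __getitem__ pairs the indexed utterance with one random reference utterance and a second utterance drawn from that same reference speaker, which is the triplet the StarGANv2 style losses consume. A hedged construction sketch, assuming the {'utt_id', 'spk_id', 'speech'} metadata layout dumped by normalize.py (added below):

    import jsonlines

    from paddlespeech.t2s.datasets.data_table import StarGANv2VCDataTable

    with jsonlines.open("dump/train/norm/metadata.jsonl", 'r') as reader:
        metadata = list(reader)
    train_dataset = StarGANv2VCDataTable(data=metadata)

    example = train_dataset[0]
    # keys: utt_id, mel (80 x T), label, ref_mel, ref_mel_2, ref_label;
    # ref_mel and ref_mel_2 always come from the same (randomly chosen) speaker.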
" + "you need to specify either *-scp or rootdir.") + + parser.add_argument( + "--dumpdir", + type=str, + required=True, + help="directory to dump normalized feature files.") + + parser.add_argument( + "--speaker-dict", type=str, default=None, help="speaker id map file.") + + args = parser.parse_args() + + dumpdir = Path(args.dumpdir).expanduser() + # use absolute path + dumpdir = dumpdir.resolve() + dumpdir.mkdir(parents=True, exist_ok=True) + + # get dataset + with jsonlines.open(args.metadata, 'r') as reader: + metadata = list(reader) + dataset = DataTable( + metadata, converters={ + "speech": np.load, + }) + logging.info(f"The number of files = {len(dataset)}.") + + vocab_speaker = {} + with open(args.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + for spk, id in spk_id: + vocab_speaker[spk] = int(id) + + # process each file + output_metadata = [] + + for item in tqdm.tqdm(dataset): + utt_id = item['utt_id'] + speech = item['speech'] + + # normalize + # 这里暂时写死 + mean, std = -4, 4 + speech = (speech - mean) / std + speech_path = dumpdir / f"{utt_id}_speech.npy" + np.save(speech_path, speech.astype(np.float32), allow_pickle=False) + + spk_id = vocab_speaker[item["speaker"]] + record = { + "utt_id": item['utt_id'], + "spk_id": spk_id, + "speech": str(speech_path), + } + + output_metadata.append(record) + output_metadata.sort(key=itemgetter('utt_id')) + output_metadata_path = Path(args.dumpdir) / "metadata.jsonl" + with jsonlines.open(output_metadata_path, 'w') as writer: + for item in output_metadata: + writer.write(item) + logging.info(f"metadata dumped into {output_metadata_path}") + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/starganv2_vc/preprocess.py b/paddlespeech/t2s/exps/starganv2_vc/preprocess.py new file mode 100644 index 000000000..053c3b32a --- /dev/null +++ b/paddlespeech/t2s/exps/starganv2_vc/preprocess.py @@ -0,0 +1,214 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/paddlespeech/t2s/exps/starganv2_vc/preprocess.py b/paddlespeech/t2s/exps/starganv2_vc/preprocess.py
new file mode 100644
index 000000000..053c3b32a
--- /dev/null
+++ b/paddlespeech/t2s/exps/starganv2_vc/preprocess.py
@@ -0,0 +1,214 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+from concurrent.futures import ThreadPoolExecutor
+from operator import itemgetter
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+
+import jsonlines
+import librosa
+import numpy as np
+import tqdm
+import yaml
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.datasets.get_feats import LogMelFBank
+from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
+
+speaker_set = set()
+
+
+def process_sentence(config: Dict[str, Any],
+                     fp: Path,
+                     output_dir: Path,
+                     mel_extractor=None):
+    utt_id = fp.stem
+    # for vctk
+    if utt_id.endswith("_mic2"):
+        utt_id = utt_id[:-5]
+    speaker = utt_id.split('_')[0]
+    speaker_set.add(speaker)
+    # the speaker has to be recorded as an extra field here
+    record = None
+    # reading, resampling may occur
+    # Bug inherited from the original code: audio is loaded at 24000 Hz but mel is extracted at 16000 Hz,
+    # see https://github.com/PaddlePaddle/PaddleSpeech/blob/c7d24ba42c377fe4c0765c6b1faa202a9aeb136f/paddlespeech/t2s/exps/starganv2_vc/vc.py#L165
+    # Both loading and mel extraction should later be switched to 24000 Hz.
+    wav, _ = librosa.load(str(fp), sr=24000)
+    max_value = np.abs(wav).max()
+    if max_value > 1.0:
+        wav = wav / max_value
+    assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
+    assert np.abs(
+        wav).max() <= 1.0, f"{utt_id} does not seem to be 16-bit PCM."
+    # extract mel feats
+    # Note base='e' here; it should later be changed to base='10', which all our other TTS models use.
+    logmel = mel_extractor.get_log_mel_fbank(wav, base='e')
+    mel_path = output_dir / (utt_id + "_speech.npy")
+    np.save(mel_path, logmel)
+    record = {"utt_id": utt_id, "speech": str(mel_path), "speaker": speaker}
+    return record
+
+
+def process_sentences(
+        config,
+        fps: List[Path],
+        output_dir: Path,
+        mel_extractor=None,
+        nprocs: int=1, ):
+    if nprocs == 1:
+        results = []
+        for fp in tqdm.tqdm(fps, total=len(fps)):
+            record = process_sentence(
+                config=config,
+                fp=fp,
+                output_dir=output_dir,
+                mel_extractor=mel_extractor)
+            if record:
+                results.append(record)
+    else:
+        with ThreadPoolExecutor(nprocs) as pool:
+            futures = []
+            with tqdm.tqdm(total=len(fps)) as progress:
+                for fp in fps:
+                    future = pool.submit(process_sentence, config, fp,
+                                         output_dir, mel_extractor)
+                    future.add_done_callback(lambda p: progress.update())
+                    futures.append(future)
+
+                results = []
+                for ft in futures:
+                    record = ft.result()
+                    if record:
+                        results.append(record)
+
+    results.sort(key=itemgetter("utt_id"))
+    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
+        for item in results:
+            writer.write(item)
+    print("Done")
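The base='e' vs base='10' remark above is only a constant rescaling, since log10(x) = ln(x) / ln(10). A hedged helper sketch (hypothetical names, not part of the patch) for converting features dumped with one base into the other:

    import numpy as np

    def ln_to_log10(logmel_e: np.ndarray) -> np.ndarray:
        # mel stored as natural log (base='e') -> common log (base='10')
        return logmel_e / np.log(10.0)

    def log10_to_ln(logmel_10: np.ndarray) -> np.ndarray:
        # inverse conversion
        return logmel_10 * np.log(10.0)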
def main():
+    # parse config and args
+    parser = argparse.ArgumentParser(
+        description="Preprocess audio and then extract features.")
+
+    parser.add_argument(
+        "--dataset",
+        default="vctk",
+        type=str,
+        help="name of dataset, should be in {vctk} for now")
+
+    parser.add_argument(
+        "--rootdir", default=None, type=str, help="root directory of the dataset.")
+
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        required=True,
+        help="directory to dump feature files.")
+
+    parser.add_argument("--config", type=str, help="StarGANv2VC config file.")
+
+    parser.add_argument(
+        "--num-cpu", type=int, default=1, help="number of processes.")
+
+    args = parser.parse_args()
+
+    rootdir = Path(args.rootdir).expanduser()
+    dumpdir = Path(args.dumpdir).expanduser()
+    # use absolute path
+    dumpdir = dumpdir.resolve()
+    dumpdir.mkdir(parents=True, exist_ok=True)
+
+    assert rootdir.is_dir()
+
+    with open(args.config, 'rt') as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    if args.dataset == "vctk":
+        sub_num_dev = 5
+        wav_dir = rootdir / "wav48_silence_trimmed"
+        train_wav_files = []
+        dev_wav_files = []
+        test_wav_files = []
+        # only for test
+        for speaker in os.listdir(wav_dir):
+            wav_files = sorted(list((wav_dir / speaker).rglob("*_mic2.flac")))
+            if len(wav_files) > 100:
+                train_wav_files += wav_files[:-sub_num_dev * 2]
+                dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
+                test_wav_files += wav_files[-sub_num_dev:]
+            else:
+                train_wav_files += wav_files
+
+    else:
+        print("dataset should be in {vctk} now!")
+
+    train_dump_dir = dumpdir / "train" / "raw"
+    train_dump_dir.mkdir(parents=True, exist_ok=True)
+    dev_dump_dir = dumpdir / "dev" / "raw"
+    dev_dump_dir.mkdir(parents=True, exist_ok=True)
+    test_dump_dir = dumpdir / "test" / "raw"
+    test_dump_dir.mkdir(parents=True, exist_ok=True)
+
+    # Extractor
+    mel_extractor = LogMelFBank(
+        sr=config.fs,
+        n_fft=config.n_fft,
+        hop_length=config.n_shift,
+        win_length=config.win_length,
+        window=config.window,
+        n_mels=config.n_mels,
+        fmin=config.fmin,
+        fmax=config.fmax,
+        # None here
+        norm=config.norm,
+        htk=config.htk,
+        power=config.power)
+
+    # process for the 3 sections
+    if train_wav_files:
+        process_sentences(
+            config=config,
+            fps=train_wav_files,
+            output_dir=train_dump_dir,
+            mel_extractor=mel_extractor,
+            nprocs=args.num_cpu)
+    if dev_wav_files:
+        process_sentences(
+            config=config,
+            fps=dev_wav_files,
+            output_dir=dev_dump_dir,
+            mel_extractor=mel_extractor,
+            nprocs=args.num_cpu)
+    if test_wav_files:
+        process_sentences(
+            config=config,
+            fps=test_wav_files,
+            output_dir=test_dump_dir,
+            mel_extractor=mel_extractor,
+            nprocs=args.num_cpu)
+
+    speaker_id_map_path = dumpdir / "speaker_id_map.txt"
+    get_spk_id_map(speaker_set, speaker_id_map_path)
+
+
+if __name__ == "__main__":
+    main()

diff --git a/paddlespeech/t2s/exps/starganv2_vc/vc.py b/paddlespeech/t2s/exps/starganv2_vc/vc.py
index ffb257418..24d3dcf8b 100644
--- a/paddlespeech/t2s/exps/starganv2_vc/vc.py
+++ b/paddlespeech/t2s/exps/starganv2_vc/vc.py
@@ -57,9 +57,10 @@ def get_mel_extractor():
 
 
 def preprocess(wave, mel_extractor):
+    # (T, 80)
     logmel = mel_extractor.get_log_mel_fbank(wave, base='e')
-    # [1, 80, 1011]
     mean, std = -4, 4
+    # [1, 80, T]
     mel_tensor = (paddle.to_tensor(logmel.T).unsqueeze(0) - mean) / std
     return mel_tensor
 
@@ -67,6 +68,7 @@ def preprocess(wave, mel_extractor):
 def compute_style(speaker_dicts, mel_extractor, style_encoder, mapping_network):
     reference_embeddings = {}
     for key, (path, speaker) in speaker_dicts.items():
+        # path = ''
         if path == '':
             label = paddle.to_tensor([speaker], dtype=paddle.int64)
             latent_dim = mapping_network.shared[0].weight.shape[0]
@@ -164,6 +166,15 @@ def voice_conversion(args, uncompress_path):
     wave, sr = librosa.load(args.source_path, sr=24000)
     source = preprocess(wave=wave, mel_extractor=mel_extractor)
 
+    # # Check whether the output of preprocess.py is OK:
+    # # using raw features and normalizing them here works,
+    # # and loading pre-normalized features directly also works.
+    # import numpy as np
+    # source = np.load("~/PaddleSpeech_stargan_preprocess/PaddleSpeech/examples/vctk/vc3/dump/train/norm/p329_414_speech.npy")
+    # # !!! these are the steps applied after mel_extractor + normalization
+    # # [1, 80, T]
+    # source = paddle.to_tensor(source.T).unsqueeze(0)
+
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
     orig_wav_name = str(output_dir / 'orig_voc.wav')
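A hedged sketch of the consistency check that the commented-out block above performs: a mel dumped by preprocess.py and normalize.py should match what vc.py's preprocess() computes from the raw wave (up to the known 24k-load / 16k-extract inconsistency). `preprocess` and `mel_extractor` are the names from vc.py; the paths are illustrative placeholders.

    import librosa
    import numpy as np
    import paddle

    wave, _ = librosa.load("p329_414_mic2.flac", sr=24000)
    source_a = preprocess(wave=wave, mel_extractor=mel_extractor)  # [1, 80, T]

    mel_norm = np.load("dump/train/norm/p329_414_speech.npy")  # (T, 80)
    source_b = paddle.to_tensor(mel_norm.T).unsqueeze(0)       # [1, 80, T]

    # shapes permitting, the two should agree up to numerical noise
    T = min(source_a.shape[-1], source_b.shape[-1])
    print(float(paddle.abs(source_a[..., :T] - source_b[..., :T]).max()))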
diff --git a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py
index 2a96b30c6..2b6775c48 100644
--- a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py
+++ b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py
@@ -431,14 +431,13 @@ class MappingNetwork(nn.Layer):
         """Calculate forward propagation.
         Args:
             z(Tensor(float32)):
-                Shape (B, 1, n_mels, T).
+                Shape (B, latent_dim).
             y(Tensor(float32)):
                 speaker label. Shape (B, ).
         Returns:
             Tensor:
                 Shape (style_dim, )
         """
-        h = self.shared(z)
         out = []
         for layer in self.unshared:
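With the corrected docstring, the input shapes line up with the z_trg / y_trg entries built by the collate fn above. A hedged call sketch, assuming `mapping_network` is an instantiated MappingNetwork as used in vc.py's compute_style():

    import paddle

    batch_size, latent_dim = 4, 16
    z_trg = paddle.randn([batch_size, latent_dim])                    # (B, latent_dim)
    y_trg = paddle.randint(0, 20, shape=[batch_size], dtype='int64')  # (B,)
    style = mapping_network(z_trg, y_trg)  # style vector for the generator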