[TTS]add StarGANv2VC preprocess (#3163)
parent c7d24ba42c
commit bd0d69ca74
@@ -0,0 +1,101 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalize feature files and dump them."""
import argparse
import logging
from operator import itemgetter
from pathlib import Path

import jsonlines
import numpy as np
import tqdm

from paddlespeech.t2s.datasets.data_table import DataTable


def main():
    """Run preprocessing process."""
    parser = argparse.ArgumentParser(
        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
    )
    parser.add_argument(
        "--metadata",
        type=str,
        required=True,
        help="metadata file (jsonl) of the dumped raw features to be normalized.")

    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump normalized feature files.")

    parser.add_argument(
        "--speaker-dict", type=str, default=None, help="speaker id map file.")

    args = parser.parse_args()

    dumpdir = Path(args.dumpdir).expanduser()
    # use absolute path
    dumpdir = dumpdir.resolve()
    dumpdir.mkdir(parents=True, exist_ok=True)

    # get dataset
    with jsonlines.open(args.metadata, 'r') as reader:
        metadata = list(reader)
    dataset = DataTable(
        metadata, converters={
            "speech": np.load,
        })
    logging.info(f"The number of files = {len(dataset)}.")

    vocab_speaker = {}
    with open(args.speaker_dict, 'rt') as f:
        spk_ids = [line.strip().split() for line in f.readlines()]
    for spk, spk_id in spk_ids:
        vocab_speaker[spk] = int(spk_id)

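    # The speaker dict file is expected to hold one whitespace-separated
    # "<speaker> <id>" pair per line, as produced by get_spk_id_map in the
    # preprocess script, e.g. (hypothetical speakers):
    #   p225 0
    #   p226 1
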
    # process each file
    output_metadata = []

    for item in tqdm.tqdm(dataset):
        utt_id = item['utt_id']
        speech = item['speech']

        # normalize
        # the statistics are hard-coded here for now
        mean, std = -4, 4
        speech = (speech - mean) / std
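        # under this rescaling a log-mel value of -4.0 maps to 0.0 and a
        # value of 0.0 maps to 1.0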
        speech_path = dumpdir / f"{utt_id}_speech.npy"
        np.save(speech_path, speech.astype(np.float32), allow_pickle=False)

        spk_id = vocab_speaker[item["speaker"]]
        record = {
            "utt_id": item['utt_id'],
            "spk_id": spk_id,
            "speech": str(speech_path),
        }

        output_metadata.append(record)
    output_metadata.sort(key=itemgetter('utt_id'))
    output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
    with jsonlines.open(output_metadata_path, 'w') as writer:
        for item in output_metadata:
            writer.write(item)
    logging.info(f"metadata dumped into {output_metadata_path}")


if __name__ == "__main__":
    main()
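
# A minimal invocation sketch for this normalizer (script name and paths are
# hypothetical, not part of this commit):
#   python normalize.py \
#       --metadata=dump/train/raw/metadata.jsonl \
#       --dumpdir=dump/train/norm \
#       --speaker-dict=dump/speaker_id_map.txt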
@@ -0,0 +1,214 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List

import jsonlines
import librosa
import numpy as np
import tqdm
import yaml
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map

speaker_set = set()


def process_sentence(config: Dict[str, Any],
                     fp: Path,
                     output_dir: Path,
                     mel_extractor=None):
    utt_id = fp.stem
    # for vctk
    if utt_id.endswith("_mic2"):
        utt_id = utt_id[:-5]
    speaker = utt_id.split('_')[0]
    speaker_set.add(speaker)
    # the speaker needs to be obtained in addition to the features
    record = None
    # reading, resampling may occur
    # A bug in the original implementation: the audio is read at 24000 Hz,
    # but the mel is extracted as if the sample rate were 16000 Hz.
    # See https://github.com/PaddlePaddle/PaddleSpeech/blob/c7d24ba42c377fe4c0765c6b1faa202a9aeb136f/paddlespeech/t2s/exps/starganv2_vc/vc.py#L165
    # This should later be changed to both read and extract the mel at 24000 Hz.
    wav, _ = librosa.load(str(fp), sr=24000)
    max_value = np.abs(wav).max()
    if max_value > 1.0:
        wav = wav / max_value
    assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
    assert np.abs(
        wav).max() <= 1.0, f"{utt_id} seems to be different from 16 bit PCM."
    # extract mel feats
    # Note that base='e' here; it should later be changed to base='10',
    # which all our other TTS models use.
    logmel = mel_extractor.get_log_mel_fbank(wav, base='e')
    mel_path = output_dir / (utt_id + "_speech.npy")
    np.save(mel_path, logmel)
    record = {"utt_id": utt_id, "speech": str(mel_path), "speaker": speaker}
    return record


def process_sentences(
        config,
        fps: List[Path],
        output_dir: Path,
        mel_extractor=None,
        nprocs: int=1, ):
    if nprocs == 1:
        results = []
        for fp in tqdm.tqdm(fps, total=len(fps)):
            record = process_sentence(
                config=config,
                fp=fp,
                output_dir=output_dir,
                mel_extractor=mel_extractor)
            if record:
                results.append(record)
    else:
        with ThreadPoolExecutor(nprocs) as pool:
            futures = []
            with tqdm.tqdm(total=len(fps)) as progress:
                for fp in fps:
                    future = pool.submit(process_sentence, config, fp,
                                         output_dir, mel_extractor)
                    future.add_done_callback(lambda p: progress.update())
                    futures.append(future)

                results = []
                for ft in futures:
                    record = ft.result()
                    if record:
                        results.append(record)

    results.sort(key=itemgetter("utt_id"))
    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
        for item in results:
            writer.write(item)
    print("Done")


def main():
    # parse config and args
    parser = argparse.ArgumentParser(
        description="Preprocess audio and then extract features.")

    parser.add_argument(
        "--dataset",
        default="vctk",
        type=str,
        help="name of dataset, should be in {vctk} now")

    parser.add_argument(
        "--rootdir", default=None, type=str, help="directory of the dataset.")

    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump feature files.")

    parser.add_argument("--config", type=str, help="StarGANv2VC config file.")

    parser.add_argument(
        "--num-cpu", type=int, default=1, help="number of processes.")

    args = parser.parse_args()

    rootdir = Path(args.rootdir).expanduser()
    dumpdir = Path(args.dumpdir).expanduser()
    # use absolute path
    dumpdir = dumpdir.resolve()
    dumpdir.mkdir(parents=True, exist_ok=True)

    assert rootdir.is_dir()

    with open(args.config, 'rt') as f:
        config = CfgNode(yaml.safe_load(f))

    if args.dataset == "vctk":
        sub_num_dev = 5
        wav_dir = rootdir / "wav48_silence_trimmed"
        train_wav_files = []
        dev_wav_files = []
        test_wav_files = []
        # only for test
        for speaker in os.listdir(wav_dir):
            wav_files = sorted(list((wav_dir / speaker).rglob("*_mic2.flac")))
            if len(wav_files) > 100:
                train_wav_files += wav_files[:-sub_num_dev * 2]
                dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
                test_wav_files += wav_files[-sub_num_dev:]
            else:
                train_wav_files += wav_files

    else:
        print("dataset should be in {vctk} now!")

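    # e.g. a speaker with 231 "_mic2" recordings contributes the first 221 to
    # train, the next 5 to dev and the last 5 to test; speakers with 100
    # recordings or fewer go entirely to train
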
    train_dump_dir = dumpdir / "train" / "raw"
    train_dump_dir.mkdir(parents=True, exist_ok=True)
    dev_dump_dir = dumpdir / "dev" / "raw"
    dev_dump_dir.mkdir(parents=True, exist_ok=True)
    test_dump_dir = dumpdir / "test" / "raw"
    test_dump_dir.mkdir(parents=True, exist_ok=True)

    # Extractor
    mel_extractor = LogMelFBank(
        sr=config.fs,
        n_fft=config.n_fft,
        hop_length=config.n_shift,
        win_length=config.win_length,
        window=config.window,
        n_mels=config.n_mels,
        fmin=config.fmin,
        fmax=config.fmax,
        # None here
        norm=config.norm,
        htk=config.htk,
        power=config.power)

    # process for the 3 sections
    if train_wav_files:
        process_sentences(
            config=config,
            fps=train_wav_files,
            output_dir=train_dump_dir,
            mel_extractor=mel_extractor,
            nprocs=args.num_cpu)
    if dev_wav_files:
        process_sentences(
            config=config,
            fps=dev_wav_files,
            output_dir=dev_dump_dir,
            mel_extractor=mel_extractor,
            nprocs=args.num_cpu)
    if test_wav_files:
        process_sentences(
            config=config,
            fps=test_wav_files,
            output_dir=test_dump_dir,
            mel_extractor=mel_extractor,
            nprocs=args.num_cpu)

    speaker_id_map_path = dumpdir / "speaker_id_map.txt"
    get_spk_id_map(speaker_set, speaker_id_map_path)


if __name__ == "__main__":
    main()
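
# A minimal invocation sketch for this preprocessor (script name and paths are
# hypothetical, not part of this commit; the config yaml must provide fs,
# n_fft, n_shift, win_length, window, n_mels, fmin, fmax, norm, htk and power):
#   python preprocess.py \
#       --dataset=vctk \
#       --rootdir=~/datasets/VCTK-Corpus-0.92 \
#       --dumpdir=dump \
#       --config=conf/default.yaml \
#       --num-cpu=8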