parent
49fd55dc16
commit
fb0acd40a2
@ -0,0 +1,68 @@
|
|||||||
|
|
||||||
|
###########################################################
|
||||||
|
# FEATURE EXTRACTION SETTING #
|
||||||
|
###########################################################
|
||||||
|
fs: 24000 # Sampling rate.
|
||||||
|
n_fft: 2048 # FFT size (samples).
|
||||||
|
n_shift: 300 # Hop size (samples). 12.5ms
|
||||||
|
win_length: 1200 # Window length (samples). 50ms
|
||||||
|
# If set to null, it will be the same as fft_size.
|
||||||
|
window: "hann" # Window function.
|
||||||
|
n_mels: 80 # Number of mel basis.
|
||||||
|
fmin: 80 # Minimum freq in mel basis calculation. (Hz)
|
||||||
|
fmax: 7600 # Maximum frequency in mel basis calculation. (Hz)
|
||||||
|
mu_law: True # Recommended to suppress noise if using raw bitsexit()
|
||||||
|
peak_norm: True
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# MODEL SETTING #
|
||||||
|
###########################################################
|
||||||
|
model:
|
||||||
|
rnn_dims: 512 # Hidden dims of RNN Layers.
|
||||||
|
fc_dims: 512
|
||||||
|
bits: 9 # Bit depth of signal
|
||||||
|
aux_context_window: 2
|
||||||
|
aux_channels: 80 # Number of channels for auxiliary feature conv.
|
||||||
|
# Must be the same as num_mels.
|
||||||
|
upsample_scales: [4, 5, 3, 5] # Upsampling scales. Prodcut of these must be the same as hop size, same with pwgan here
|
||||||
|
compute_dims: 128
|
||||||
|
res_out_dims: 128
|
||||||
|
res_blocks: 10
|
||||||
|
mode: RAW # either 'raw'(softmax on raw bits) or 'mold' (sample from mixture of logistics)
|
||||||
|
inference:
|
||||||
|
gen_batched: True # whether to genenate sample in batch mode
|
||||||
|
target: 12000 # target number of samples to be generated in each batch entry
|
||||||
|
overlap: 600 # number of samples for crossfading between batches
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# DATA LOADER SETTING #
|
||||||
|
###########################################################
|
||||||
|
batch_size: 64 # Batch size.
|
||||||
|
batch_max_steps: 4500 # Length of each audio in batch. Make sure dividable by hop_size.
|
||||||
|
num_workers: 2 # Number of workers in DataLoader.
|
||||||
|
valid_size: 50
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# OPTIMIZER SETTING #
|
||||||
|
###########################################################
|
||||||
|
grad_clip: 4.0
|
||||||
|
learning_rate: 1.0e-4
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# INTERVAL SETTING #
|
||||||
|
###########################################################
|
||||||
|
|
||||||
|
train_max_steps: 400000 # Number of training steps.
|
||||||
|
save_interval_steps: 5000 # Interval steps to save checkpoint.
|
||||||
|
eval_interval_steps: 1000 # Interval steps to evaluate the network.
|
||||||
|
gen_eval_samples_interval_steps: 5000 # the iteration interval of generating valid samples
|
||||||
|
generate_num: 5 # number of samples to generate at each checkpoint
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# OTHER SETTING #
|
||||||
|
###########################################################
|
||||||
|
num_snapshots: 10 # max number of snapshots to keep while training
|
||||||
|
seed: 42 # random seed for paddle, random, and np.random
|
@ -0,0 +1,15 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
python3 ${BIN_DIR}/preprocess.py \
|
||||||
|
--input=~/datasets/BZNSYP/ \
|
||||||
|
--output=dump \
|
||||||
|
--dataset=csmsc \
|
||||||
|
--config=${config_path} \
|
||||||
|
--num-cpu=20
|
||||||
|
fi
|
@ -0,0 +1,14 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
ckpt_name=$3
|
||||||
|
test_input=$4
|
||||||
|
|
||||||
|
FLAGS_allocator_strategy=naive_best_fit \
|
||||||
|
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
|
||||||
|
python3 ${BIN_DIR}/synthesize.py \
|
||||||
|
--config=${config_path} \
|
||||||
|
--checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
|
||||||
|
--input=${test_input} \
|
||||||
|
--output-dir=${train_output_path}/test
|
@ -0,0 +1,9 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
python ${BIN_DIR}/train.py \
|
||||||
|
--config=${config_path} \
|
||||||
|
--data=dump/ \
|
||||||
|
--output-dir=${train_output_path} \
|
||||||
|
--ngpu=1
|
@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
export MAIN_ROOT=`realpath ${PWD}/../../../`
|
||||||
|
|
||||||
|
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
|
||||||
|
export LC_ALL=C
|
||||||
|
|
||||||
|
export PYTHONDONTWRITEBYTECODE=1
|
||||||
|
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||||
|
export PYTHONIOENCODING=UTF-8
|
||||||
|
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
|
||||||
|
|
||||||
|
MODEL=wavernn
|
||||||
|
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
|
@ -0,0 +1,33 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
source path.sh
|
||||||
|
|
||||||
|
gpus=0,1
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
conf_path=conf/default.yaml
|
||||||
|
train_output_path=exp/default
|
||||||
|
test_input=dump/mel_test
|
||||||
|
ckpt_name=snapshot_iter_100000.pdz
|
||||||
|
|
||||||
|
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# prepare data
|
||||||
|
./local/preprocess.sh ${conf_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# prepare data
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
# copy some test mels from dump
|
||||||
|
mkdir -p ${test_input}
|
||||||
|
cp -r dump/mel/00995*.npy ${test_input}
|
||||||
|
# synthesize
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} ${test_input}|| exit -1
|
||||||
|
fi
|
@ -0,0 +1,56 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from paddle.io import Dataset
|
||||||
|
|
||||||
|
__all__ = ["CSMSCMetaData"]
|
||||||
|
|
||||||
|
|
||||||
|
class CSMSCMetaData(Dataset):
|
||||||
|
def __init__(self, root):
|
||||||
|
"""
|
||||||
|
:param root: the path of baker dataset
|
||||||
|
"""
|
||||||
|
self.root = os.path.abspath(root)
|
||||||
|
records = []
|
||||||
|
index = 1
|
||||||
|
self.meta_info = ["file_path", "text", "pinyin"]
|
||||||
|
|
||||||
|
metadata_path = os.path.join(root, "ProsodyLabeling/000001-010000.txt")
|
||||||
|
wav_dirs = os.path.join(self.root, "Wave")
|
||||||
|
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||||
|
while True:
|
||||||
|
line1 = f.readline().strip()
|
||||||
|
if not line1:
|
||||||
|
break
|
||||||
|
line2 = f.readline().strip()
|
||||||
|
strs = line1.split()
|
||||||
|
wav_fname = line1.split()[0].strip() + '.wav'
|
||||||
|
wav_filepath = os.path.join(wav_dirs, wav_fname)
|
||||||
|
text = strs[1].strip()
|
||||||
|
pinyin = line2
|
||||||
|
records.append([wav_filepath, text, pinyin])
|
||||||
|
|
||||||
|
self.records = records
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
return self.records[i]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.records)
|
||||||
|
|
||||||
|
def get_meta_info(self):
|
||||||
|
return self.meta_info
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,157 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
from multiprocessing import Pool
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import tqdm
|
||||||
|
import yaml
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddlespeech.t2s.data.get_feats import LogMelFBank
|
||||||
|
from paddlespeech.t2s.datasets import CSMSCMetaData
|
||||||
|
from paddlespeech.t2s.datasets import LJSpeechMetaData
|
||||||
|
from paddlespeech.t2s.datasets.vocoder_batch_fn import encode_mu_law
|
||||||
|
from paddlespeech.t2s.datasets.vocoder_batch_fn import float_2_label
|
||||||
|
|
||||||
|
|
||||||
|
class Transform(object):
|
||||||
|
def __init__(self, output_dir: Path, config):
|
||||||
|
self.fs = config.fs
|
||||||
|
self.peak_norm = config.peak_norm
|
||||||
|
self.bits = config.model.bits
|
||||||
|
self.mode = config.model.mode
|
||||||
|
self.mu_law = config.mu_law
|
||||||
|
|
||||||
|
self.wav_dir = output_dir / "wav"
|
||||||
|
self.mel_dir = output_dir / "mel"
|
||||||
|
self.wav_dir.mkdir(exist_ok=True)
|
||||||
|
self.mel_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
self.mel_extractor = LogMelFBank(
|
||||||
|
sr=config.fs,
|
||||||
|
n_fft=config.n_fft,
|
||||||
|
hop_length=config.n_shift,
|
||||||
|
win_length=config.win_length,
|
||||||
|
window=config.window,
|
||||||
|
n_mels=config.n_mels,
|
||||||
|
fmin=config.fmin,
|
||||||
|
fmax=config.fmax)
|
||||||
|
|
||||||
|
if self.mode != 'RAW' and self.mode != 'MOL':
|
||||||
|
raise RuntimeError('Unknown mode value - ', self.mode)
|
||||||
|
|
||||||
|
def __call__(self, example):
|
||||||
|
wav_path, _, _ = example
|
||||||
|
|
||||||
|
base_name = os.path.splitext(os.path.basename(wav_path))[0]
|
||||||
|
# print("self.sample_rate:",self.sample_rate)
|
||||||
|
wav, _ = librosa.load(wav_path, sr=self.fs)
|
||||||
|
peak = np.abs(wav).max()
|
||||||
|
if self.peak_norm or peak > 1.0:
|
||||||
|
wav /= peak
|
||||||
|
|
||||||
|
mel = self.mel_extractor.get_log_mel_fbank(wav).T
|
||||||
|
if self.mode == 'RAW':
|
||||||
|
if self.mu_law:
|
||||||
|
quant = encode_mu_law(wav, mu=2**self.bits)
|
||||||
|
else:
|
||||||
|
quant = float_2_label(wav, bits=self.bits)
|
||||||
|
elif self.mode == 'MOL':
|
||||||
|
quant = float_2_label(wav, bits=16)
|
||||||
|
|
||||||
|
mel = mel.astype(np.float32)
|
||||||
|
audio = quant.astype(np.int64)
|
||||||
|
|
||||||
|
np.save(str(self.wav_dir / base_name), audio)
|
||||||
|
np.save(str(self.mel_dir / base_name), mel)
|
||||||
|
|
||||||
|
return base_name, mel.shape[-1], audio.shape[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset(config,
|
||||||
|
input_dir,
|
||||||
|
output_dir,
|
||||||
|
nprocs: int=1,
|
||||||
|
dataset_type: str="ljspeech"):
|
||||||
|
input_dir = Path(input_dir).expanduser()
|
||||||
|
'''
|
||||||
|
LJSpeechMetaData.records: [filename, normalized text, speaker name(ljspeech)]
|
||||||
|
CSMSCMetaData.records: [filename, normalized text, pinyin]
|
||||||
|
'''
|
||||||
|
if dataset_type == 'ljspeech':
|
||||||
|
dataset = LJSpeechMetaData(input_dir)
|
||||||
|
else:
|
||||||
|
dataset = CSMSCMetaData(input_dir)
|
||||||
|
output_dir = Path(output_dir).expanduser()
|
||||||
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
transform = Transform(output_dir, config)
|
||||||
|
|
||||||
|
file_names = []
|
||||||
|
|
||||||
|
pool = Pool(processes=nprocs)
|
||||||
|
|
||||||
|
for info in tqdm.tqdm(pool.imap(transform, dataset), total=len(dataset)):
|
||||||
|
base_name, mel_len, audio_len = info
|
||||||
|
file_names.append((base_name, mel_len, audio_len))
|
||||||
|
|
||||||
|
meta_data = pd.DataFrame.from_records(file_names)
|
||||||
|
meta_data.to_csv(
|
||||||
|
str(output_dir / "metadata.csv"), sep="\t", index=None, header=None)
|
||||||
|
print("saved meta data in to {}".format(
|
||||||
|
os.path.join(output_dir, "metadata.csv")))
|
||||||
|
|
||||||
|
print("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="create dataset")
|
||||||
|
parser.add_argument(
|
||||||
|
"--config", type=str, help="config file to overwrite default config.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--input", type=str, help="path of the ljspeech dataset")
|
||||||
|
parser.add_argument(
|
||||||
|
"--output", type=str, help="path to save output dataset")
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-cpu",
|
||||||
|
type=int,
|
||||||
|
default=cpu_count() // 2,
|
||||||
|
help="number of process.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
type=str,
|
||||||
|
default="ljspeech",
|
||||||
|
help="The dataset to preprocess, ljspeech or csmsc")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
with open(args.config, 'rt') as f:
|
||||||
|
config = CfgNode(yaml.safe_load(f))
|
||||||
|
|
||||||
|
if args.dataset != "ljspeech" and args.dataset != "csmsc":
|
||||||
|
raise RuntimeError('Unknown dataset - ', args.dataset)
|
||||||
|
|
||||||
|
create_dataset(
|
||||||
|
config,
|
||||||
|
input_dir=args.input,
|
||||||
|
output_dir=args.output,
|
||||||
|
nprocs=args.num_cpu,
|
||||||
|
dataset_type=args.dataset)
|
@ -0,0 +1,89 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import soundfile as sf
|
||||||
|
import yaml
|
||||||
|
from paddle import distributed as dist
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddlespeech.t2s.models.wavernn import WaveRNN
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Synthesize with WaveRNN.")
|
||||||
|
|
||||||
|
parser.add_argument("--config", type=str, help="GANVocoder config file.")
|
||||||
|
parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--input",
|
||||||
|
type=str,
|
||||||
|
help="path of directory containing mel spectrogram (in .npy format)")
|
||||||
|
parser.add_argument("--output-dir", type=str, help="output dir.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
with open(args.config) as f:
|
||||||
|
config = CfgNode(yaml.safe_load(f))
|
||||||
|
|
||||||
|
print("========Args========")
|
||||||
|
print(yaml.safe_dump(vars(args)))
|
||||||
|
print("========Config========")
|
||||||
|
print(config)
|
||||||
|
print(
|
||||||
|
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.ngpu == 0:
|
||||||
|
paddle.set_device("cpu")
|
||||||
|
elif args.ngpu > 0:
|
||||||
|
paddle.set_device("gpu")
|
||||||
|
else:
|
||||||
|
print("ngpu should >= 0 !")
|
||||||
|
|
||||||
|
model = WaveRNN(
|
||||||
|
hop_length=config.n_shift, sample_rate=config.fs, **config["model"])
|
||||||
|
state_dict = paddle.load(args.checkpoint)
|
||||||
|
model.set_state_dict(state_dict["main_params"])
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
mel_dir = Path(args.input).expanduser()
|
||||||
|
output_dir = Path(args.output_dir).expanduser()
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
for file_path in sorted(mel_dir.iterdir()):
|
||||||
|
mel = np.load(str(file_path))
|
||||||
|
mel = paddle.to_tensor(mel)
|
||||||
|
mel = mel.transpose([1, 0])
|
||||||
|
# input shape is (T', C_aux)
|
||||||
|
audio = model.generate(
|
||||||
|
c=mel,
|
||||||
|
batched=config.inference.gen_batched,
|
||||||
|
target=config.inference.target,
|
||||||
|
overlap=config.inference.overlap,
|
||||||
|
mu_law=config.mu_law,
|
||||||
|
gen_display=True)
|
||||||
|
audio_path = output_dir / (os.path.splitext(file_path.name)[0] + ".wav")
|
||||||
|
sf.write(audio_path, audio.numpy(), samplerate=config.fs)
|
||||||
|
print("[synthesize] {} -> {}".format(file_path, audio_path))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1,192 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import yaml
|
||||||
|
from paddle import DataParallel
|
||||||
|
from paddle import distributed as dist
|
||||||
|
from paddle.io import DataLoader
|
||||||
|
from paddle.io import DistributedBatchSampler
|
||||||
|
from paddle.optimizer import Adam
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddlespeech.t2s.data import dataset
|
||||||
|
from paddlespeech.t2s.datasets.vocoder_batch_fn import WaveRNNClip
|
||||||
|
from paddlespeech.t2s.datasets.vocoder_batch_fn import WaveRNNDataset
|
||||||
|
from paddlespeech.t2s.models.wavernn import WaveRNN
|
||||||
|
from paddlespeech.t2s.models.wavernn import WaveRNNEvaluator
|
||||||
|
from paddlespeech.t2s.models.wavernn import WaveRNNUpdater
|
||||||
|
from paddlespeech.t2s.modules.losses import discretized_mix_logistic_loss
|
||||||
|
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
|
||||||
|
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
|
||||||
|
from paddlespeech.t2s.training.seeding import seed_everything
|
||||||
|
from paddlespeech.t2s.training.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
def train_sp(args, config):
|
||||||
|
# decides device type and whether to run in parallel
|
||||||
|
# setup running environment correctly
|
||||||
|
world_size = paddle.distributed.get_world_size()
|
||||||
|
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
|
||||||
|
paddle.set_device("cpu")
|
||||||
|
else:
|
||||||
|
paddle.set_device("gpu")
|
||||||
|
if world_size > 1:
|
||||||
|
paddle.distributed.init_parallel_env()
|
||||||
|
|
||||||
|
# set the random seed, it is a must for multiprocess training
|
||||||
|
seed_everything(config.seed)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
|
||||||
|
)
|
||||||
|
|
||||||
|
wavernn_dataset = WaveRNNDataset(args.data)
|
||||||
|
|
||||||
|
train_dataset, dev_dataset = dataset.split(
|
||||||
|
wavernn_dataset, len(wavernn_dataset) - config.valid_size)
|
||||||
|
|
||||||
|
batch_fn = WaveRNNClip(
|
||||||
|
mode=config.model.mode,
|
||||||
|
aux_context_window=config.model.aux_context_window,
|
||||||
|
hop_size=config.n_shift,
|
||||||
|
batch_max_steps=config.batch_max_steps,
|
||||||
|
bits=config.model.bits)
|
||||||
|
|
||||||
|
# collate function and dataloader
|
||||||
|
train_sampler = DistributedBatchSampler(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
drop_last=True)
|
||||||
|
dev_sampler = DistributedBatchSampler(
|
||||||
|
dev_dataset,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
drop_last=False)
|
||||||
|
print("samplers done!")
|
||||||
|
|
||||||
|
train_dataloader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_sampler=train_sampler,
|
||||||
|
collate_fn=batch_fn,
|
||||||
|
num_workers=config.num_workers)
|
||||||
|
|
||||||
|
dev_dataloader = DataLoader(
|
||||||
|
dev_dataset,
|
||||||
|
collate_fn=batch_fn,
|
||||||
|
batch_sampler=dev_sampler,
|
||||||
|
num_workers=config.num_workers)
|
||||||
|
valid_generate_loader = DataLoader(dev_dataset, batch_size=1)
|
||||||
|
print("dataloaders done!")
|
||||||
|
|
||||||
|
model = WaveRNN(
|
||||||
|
hop_length=config.n_shift, sample_rate=config.fs, **config["model"])
|
||||||
|
if world_size > 1:
|
||||||
|
model = DataParallel(model)
|
||||||
|
print("model done!")
|
||||||
|
|
||||||
|
if config.model.mode == 'RAW':
|
||||||
|
criterion = paddle.nn.CrossEntropyLoss(axis=1)
|
||||||
|
elif config.model.mode == 'MOL':
|
||||||
|
criterion = discretized_mix_logistic_loss
|
||||||
|
else:
|
||||||
|
criterion = None
|
||||||
|
RuntimeError('Unknown model mode value - ', config.model.mode)
|
||||||
|
print("criterions done!")
|
||||||
|
clip = paddle.nn.ClipGradByGlobalNorm(config.grad_clip)
|
||||||
|
optimizer = Adam(
|
||||||
|
parameters=model.parameters(),
|
||||||
|
learning_rate=config.learning_rate,
|
||||||
|
grad_clip=clip)
|
||||||
|
|
||||||
|
print("optimizer done!")
|
||||||
|
|
||||||
|
output_dir = Path(args.output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
config_name = args.config.split("/")[-1]
|
||||||
|
# copy conf to output_dir
|
||||||
|
shutil.copyfile(args.config, output_dir / config_name)
|
||||||
|
|
||||||
|
updater = WaveRNNUpdater(
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
criterion=criterion,
|
||||||
|
dataloader=train_dataloader,
|
||||||
|
output_dir=output_dir,
|
||||||
|
mode=config.model.mode)
|
||||||
|
|
||||||
|
evaluator = WaveRNNEvaluator(
|
||||||
|
model=model,
|
||||||
|
dataloader=dev_dataloader,
|
||||||
|
criterion=criterion,
|
||||||
|
output_dir=output_dir,
|
||||||
|
valid_generate_loader=valid_generate_loader,
|
||||||
|
config=config)
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
updater,
|
||||||
|
stop_trigger=(config.train_max_steps, "iteration"),
|
||||||
|
out=output_dir)
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
trainer.extend(
|
||||||
|
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
|
||||||
|
trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
|
||||||
|
trainer.extend(
|
||||||
|
Snapshot(max_size=config.num_snapshots),
|
||||||
|
trigger=(config.save_interval_steps, 'iteration'))
|
||||||
|
|
||||||
|
print("Trainer Done!")
|
||||||
|
trainer.run()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# parse args and config and redirect to train_sp
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Train a WaveRNN model.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--config", type=str, help="config file to overwrite default config.")
|
||||||
|
parser.add_argument("--data", type=str, help="input")
|
||||||
|
parser.add_argument("--output-dir", type=str, help="output dir.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
with open(args.config, 'rt') as f:
|
||||||
|
config = CfgNode(yaml.safe_load(f))
|
||||||
|
|
||||||
|
print("========Args========")
|
||||||
|
print(yaml.safe_dump(vars(args)))
|
||||||
|
print("========Config========")
|
||||||
|
print(config)
|
||||||
|
print(
|
||||||
|
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# dispatch
|
||||||
|
if args.ngpu > 1:
|
||||||
|
dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
|
||||||
|
else:
|
||||||
|
train_sp(args, config)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1,15 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .wavernn import *
|
||||||
|
from .wavernn_updater import *
|
@ -0,0 +1,592 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
|
||||||
|
from paddlespeech.t2s.datasets.vocoder_batch_fn import decode_mu_law
|
||||||
|
from paddlespeech.t2s.modules.losses import sample_from_discretized_mix_logistic
|
||||||
|
from paddlespeech.t2s.modules.nets_utils import initialize
|
||||||
|
from paddlespeech.t2s.modules.upsample import Stretch2D
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Layer):
|
||||||
|
def __init__(self, dims):
|
||||||
|
super(ResBlock, self).__init__()
|
||||||
|
self.conv1 = nn.Conv1D(dims, dims, kernel_size=1, bias_attr=False)
|
||||||
|
self.conv2 = nn.Conv1D(dims, dims, kernel_size=1, bias_attr=False)
|
||||||
|
self.batch_norm1 = nn.BatchNorm1D(dims)
|
||||||
|
self.batch_norm2 = nn.BatchNorm1D(dims)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
'''
|
||||||
|
conv -> bn -> relu -> conv -> bn + residual connection
|
||||||
|
'''
|
||||||
|
residual = x
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.batch_norm1(x)
|
||||||
|
x = F.relu(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.batch_norm2(x)
|
||||||
|
return x + residual
|
||||||
|
|
||||||
|
|
||||||
|
class MelResNet(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
res_blocks: int=10,
|
||||||
|
compute_dims: int=128,
|
||||||
|
res_out_dims: int=128,
|
||||||
|
aux_channels: int=80,
|
||||||
|
aux_context_window: int=0):
|
||||||
|
super().__init__()
|
||||||
|
k_size = aux_context_window * 2 + 1
|
||||||
|
# pay attention here, the dim reduces aux_context_window * 2
|
||||||
|
self.conv_in = nn.Conv1D(
|
||||||
|
aux_channels, compute_dims, kernel_size=k_size, bias_attr=False)
|
||||||
|
self.batch_norm = nn.BatchNorm1D(compute_dims)
|
||||||
|
self.layers = nn.LayerList()
|
||||||
|
for _ in range(res_blocks):
|
||||||
|
self.layers.append(ResBlock(compute_dims))
|
||||||
|
self.conv_out = nn.Conv1D(compute_dims, res_out_dims, kernel_size=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
'''
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor
|
||||||
|
Input tensor (B, in_dims, T).
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
Tensor
|
||||||
|
Output tensor (B, res_out_dims, T).
|
||||||
|
'''
|
||||||
|
x = self.conv_in(x)
|
||||||
|
x = self.batch_norm(x)
|
||||||
|
x = F.relu(x)
|
||||||
|
for f in self.layers:
|
||||||
|
x = f(x)
|
||||||
|
x = self.conv_out(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class UpsampleNetwork(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
aux_channels: int=80,
|
||||||
|
upsample_scales: List[int]=[4, 5, 3, 5],
|
||||||
|
compute_dims: int=128,
|
||||||
|
res_blocks: int=10,
|
||||||
|
res_out_dims: int=128,
|
||||||
|
aux_context_window: int=2):
|
||||||
|
super().__init__()
|
||||||
|
# total_scale is the total Up sampling multiple
|
||||||
|
total_scale = np.prod(upsample_scales)
|
||||||
|
# TODO pad*total_scale is numpy.int64
|
||||||
|
self.indent = int(aux_context_window * total_scale)
|
||||||
|
self.resnet = MelResNet(
|
||||||
|
res_blocks=res_blocks,
|
||||||
|
aux_channels=aux_channels,
|
||||||
|
compute_dims=compute_dims,
|
||||||
|
res_out_dims=res_out_dims,
|
||||||
|
aux_context_window=aux_context_window)
|
||||||
|
self.resnet_stretch = Stretch2D(total_scale, 1)
|
||||||
|
self.up_layers = nn.LayerList()
|
||||||
|
for scale in upsample_scales:
|
||||||
|
k_size = (1, scale * 2 + 1)
|
||||||
|
padding = (0, scale)
|
||||||
|
stretch = Stretch2D(scale, 1)
|
||||||
|
|
||||||
|
conv = nn.Conv2D(
|
||||||
|
1, 1, kernel_size=k_size, padding=padding, bias_attr=False)
|
||||||
|
weight_ = paddle.full_like(conv.weight, 1. / k_size[1])
|
||||||
|
conv.weight.set_value(weight_)
|
||||||
|
self.up_layers.append(stretch)
|
||||||
|
self.up_layers.append(conv)
|
||||||
|
|
||||||
|
def forward(self, m):
|
||||||
|
'''
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
c : Tensor
|
||||||
|
Input tensor (B, C_aux, T).
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
Tensor
|
||||||
|
Output tensor (B, (T - 2 * pad) * prob(upsample_scales), C_aux).
|
||||||
|
Tensor
|
||||||
|
Output tensor (B, (T - 2 * pad) * prob(upsample_scales), res_out_dims).
|
||||||
|
'''
|
||||||
|
# aux: [B, C_aux, T]
|
||||||
|
# -> [B, res_out_dims, T - 2 * aux_context_window]
|
||||||
|
# -> [B, 1, res_out_dims, T - 2 * aux_context_window]
|
||||||
|
aux = self.resnet(m).unsqueeze(1)
|
||||||
|
# aux: [B, 1, res_out_dims, T - 2 * aux_context_window]
|
||||||
|
# -> [B, 1, res_out_dims, (T - 2 * pad) * prob(upsample_scales)]
|
||||||
|
aux = self.resnet_stretch(aux)
|
||||||
|
# aux: [B, 1, res_out_dims, T * prob(upsample_scales)]
|
||||||
|
# -> [B, res_out_dims, T * prob(upsample_scales)]
|
||||||
|
aux = aux.squeeze(1)
|
||||||
|
# m: [B, C_aux, T] -> [B, 1, C_aux, T]
|
||||||
|
m = m.unsqueeze(1)
|
||||||
|
for f in self.up_layers:
|
||||||
|
m = f(m)
|
||||||
|
# m: [B, 1, C_aux, T*prob(upsample_scales)]
|
||||||
|
# -> [B, C_aux, T * prob(upsample_scales)]
|
||||||
|
# -> [B, C_aux, (T - 2 * pad) * prob(upsample_scales)]
|
||||||
|
m = m.squeeze(1)[:, :, self.indent:-self.indent]
|
||||||
|
# m: [B, (T - 2 * pad) * prob(upsample_scales), C_aux]
|
||||||
|
# aux: [B, (T - 2 * pad) * prob(upsample_scales), res_out_dims]
|
||||||
|
return m.transpose([0, 2, 1]), aux.transpose([0, 2, 1])
|
||||||
|
|
||||||
|
|
||||||
|
class WaveRNN(nn.Layer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
rnn_dims: int=512,
|
||||||
|
fc_dims: int=512,
|
||||||
|
bits: int=9,
|
||||||
|
aux_context_window: int=2,
|
||||||
|
upsample_scales: List[int]=[4, 5, 3, 5],
|
||||||
|
aux_channels: int=80,
|
||||||
|
compute_dims: int=128,
|
||||||
|
res_out_dims: int=128,
|
||||||
|
res_blocks: int=10,
|
||||||
|
hop_length: int=300,
|
||||||
|
sample_rate: int=24000,
|
||||||
|
mode='RAW',
|
||||||
|
init_type: str="xavier_uniform", ):
|
||||||
|
'''
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
rnn_dims : int, optional
|
||||||
|
Hidden dims of RNN Layers.
|
||||||
|
fc_dims : int, optional
|
||||||
|
Dims of FC Layers.
|
||||||
|
bits : int, optional
|
||||||
|
bit depth of signal.
|
||||||
|
aux_context_window : int, optional
|
||||||
|
The context window size of the first convolution applied to the
|
||||||
|
auxiliary input, by default 2
|
||||||
|
upsample_scales : List[int], optional
|
||||||
|
Upsample scales of the upsample network.
|
||||||
|
aux_channels : int, optional
|
||||||
|
Auxiliary channel of the residual blocks.
|
||||||
|
compute_dims : int, optional
|
||||||
|
Dims of Conv1D in MelResNet.
|
||||||
|
res_out_dims : int, optional
|
||||||
|
Dims of output in MelResNet.
|
||||||
|
res_blocks : int, optional
|
||||||
|
Number of residual blocks.
|
||||||
|
mode : str, optional
|
||||||
|
Output mode of the WaveRNN vocoder. `MOL` for Mixture of Logistic Distribution,
|
||||||
|
and `RAW` for quantized bits as the model's output.
|
||||||
|
init_type : str
|
||||||
|
How to initialize parameters.
|
||||||
|
'''
|
||||||
|
super().__init__()
|
||||||
|
self.mode = mode
|
||||||
|
self.aux_context_window = aux_context_window
|
||||||
|
if self.mode == 'RAW':
|
||||||
|
self.n_classes = 2**bits
|
||||||
|
elif self.mode == 'MOL':
|
||||||
|
self.n_classes = 30
|
||||||
|
else:
|
||||||
|
RuntimeError('Unknown model mode value - ', self.mode)
|
||||||
|
|
||||||
|
# List of rnns to call 'flatten_parameters()' on
|
||||||
|
self._to_flatten = []
|
||||||
|
|
||||||
|
self.rnn_dims = rnn_dims
|
||||||
|
self.aux_dims = res_out_dims // 4
|
||||||
|
self.hop_length = hop_length
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
|
||||||
|
# initialize parameters
|
||||||
|
initialize(self, init_type)
|
||||||
|
|
||||||
|
self.upsample = UpsampleNetwork(
|
||||||
|
aux_channels=aux_channels,
|
||||||
|
upsample_scales=upsample_scales,
|
||||||
|
compute_dims=compute_dims,
|
||||||
|
res_blocks=res_blocks,
|
||||||
|
res_out_dims=res_out_dims,
|
||||||
|
aux_context_window=aux_context_window)
|
||||||
|
self.I = nn.Linear(aux_channels + self.aux_dims + 1, rnn_dims)
|
||||||
|
|
||||||
|
self.rnn1 = nn.GRU(rnn_dims, rnn_dims)
|
||||||
|
self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims)
|
||||||
|
self._to_flatten += [self.rnn1, self.rnn2]
|
||||||
|
|
||||||
|
self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
|
||||||
|
self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
|
||||||
|
self.fc3 = nn.Linear(fc_dims, self.n_classes)
|
||||||
|
|
||||||
|
# Avoid fragmentation of RNN parameters and associated warning
|
||||||
|
self._flatten_parameters()
|
||||||
|
|
||||||
|
nn.initializer.set_global_initializer(None)
|
||||||
|
|
||||||
|
def forward(self, x, c):
|
||||||
|
'''
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor
|
||||||
|
wav sequence, [B, T]
|
||||||
|
c : Tensor
|
||||||
|
mel spectrogram [B, C_aux, T']
|
||||||
|
|
||||||
|
T = (T' - 2 * aux_context_window ) * hop_length
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
Tensor
|
||||||
|
[B, T, n_classes]
|
||||||
|
'''
|
||||||
|
# Although we `_flatten_parameters()` on init, when using DataParallel
|
||||||
|
# the model gets replicated, making it no longer guaranteed that the
|
||||||
|
# weights are contiguous in GPU memory. Hence, we must call it again
|
||||||
|
self._flatten_parameters()
|
||||||
|
|
||||||
|
bsize = paddle.shape(x)[0]
|
||||||
|
h1 = paddle.zeros([1, bsize, self.rnn_dims])
|
||||||
|
h2 = paddle.zeros([1, bsize, self.rnn_dims])
|
||||||
|
# c: [B, T, C_aux]
|
||||||
|
# aux: [B, T, res_out_dims]
|
||||||
|
c, aux = self.upsample(c)
|
||||||
|
|
||||||
|
aux_idx = [self.aux_dims * i for i in range(5)]
|
||||||
|
a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
|
||||||
|
a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
|
||||||
|
a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
|
||||||
|
a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
|
||||||
|
|
||||||
|
x = paddle.concat([x.unsqueeze(-1), c, a1], axis=2)
|
||||||
|
x = self.I(x)
|
||||||
|
res = x
|
||||||
|
x, _ = self.rnn1(x, h1)
|
||||||
|
|
||||||
|
x = x + res
|
||||||
|
res = x
|
||||||
|
x = paddle.concat([x, a2], axis=2)
|
||||||
|
x, _ = self.rnn2(x, h2)
|
||||||
|
|
||||||
|
x = x + res
|
||||||
|
x = paddle.concat([x, a3], axis=2)
|
||||||
|
x = F.relu(self.fc1(x))
|
||||||
|
|
||||||
|
x = paddle.concat([x, a4], axis=2)
|
||||||
|
x = F.relu(self.fc2(x))
|
||||||
|
|
||||||
|
return self.fc3(x)
|
||||||
|
|
||||||
|
@paddle.no_grad()
|
||||||
|
def generate(self,
|
||||||
|
c,
|
||||||
|
batched: bool=True,
|
||||||
|
target: int=12000,
|
||||||
|
overlap: int=600,
|
||||||
|
mu_law: bool=True,
|
||||||
|
gen_display: bool=False):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
c : Tensor
|
||||||
|
input mels, (T', C_aux)
|
||||||
|
batched : bool
|
||||||
|
generate in batch or not
|
||||||
|
target : int
|
||||||
|
target number of samples to be generated in each batch entry
|
||||||
|
overlap : int
|
||||||
|
number of samples for crossfading between batches
|
||||||
|
mu_law : bool
|
||||||
|
use mu law or not
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
wav sequence
|
||||||
|
Output (T' * prod(upsample_scales), out_channels, C_out).
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
mu_law = mu_law if self.mode == 'RAW' else False
|
||||||
|
|
||||||
|
output = []
|
||||||
|
start = time.time()
|
||||||
|
rnn1 = self.get_gru_cell(self.rnn1)
|
||||||
|
rnn2 = self.get_gru_cell(self.rnn2)
|
||||||
|
# pseudo batch
|
||||||
|
# (T, C_aux) -> (1, C_aux, T)
|
||||||
|
c = paddle.transpose(c, [1, 0]).unsqueeze(0)
|
||||||
|
|
||||||
|
wave_len = (paddle.shape(c)[-1] - 1) * self.hop_length
|
||||||
|
# TODO remove two transpose op by modifying function pad_tensor
|
||||||
|
c = self.pad_tensor(
|
||||||
|
c.transpose([0, 2, 1]), pad=self.aux_context_window,
|
||||||
|
side='both').transpose([0, 2, 1])
|
||||||
|
c, aux = self.upsample(c)
|
||||||
|
|
||||||
|
if batched:
|
||||||
|
# (num_folds, target + 2 * overlap, features)
|
||||||
|
c = self.fold_with_overlap(c, target, overlap)
|
||||||
|
aux = self.fold_with_overlap(aux, target, overlap)
|
||||||
|
|
||||||
|
b_size, seq_len, _ = paddle.shape(c)
|
||||||
|
h1 = paddle.zeros([b_size, self.rnn_dims])
|
||||||
|
h2 = paddle.zeros([b_size, self.rnn_dims])
|
||||||
|
x = paddle.zeros([b_size, 1])
|
||||||
|
|
||||||
|
d = self.aux_dims
|
||||||
|
aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)]
|
||||||
|
|
||||||
|
for i in range(seq_len):
|
||||||
|
m_t = c[:, i, :]
|
||||||
|
|
||||||
|
a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
|
||||||
|
x = paddle.concat([x, m_t, a1_t], axis=1)
|
||||||
|
x = self.I(x)
|
||||||
|
h1, _ = rnn1(x, h1)
|
||||||
|
x = x + h1
|
||||||
|
inp = paddle.concat([x, a2_t], axis=1)
|
||||||
|
h2, _ = rnn2(inp, h2)
|
||||||
|
|
||||||
|
x = x + h2
|
||||||
|
x = paddle.concat([x, a3_t], axis=1)
|
||||||
|
x = F.relu(self.fc1(x))
|
||||||
|
|
||||||
|
x = paddle.concat([x, a4_t], axis=1)
|
||||||
|
x = F.relu(self.fc2(x))
|
||||||
|
|
||||||
|
logits = self.fc3(x)
|
||||||
|
|
||||||
|
if self.mode == 'MOL':
|
||||||
|
sample = sample_from_discretized_mix_logistic(
|
||||||
|
logits.unsqueeze(0).transpose([0, 2, 1]))
|
||||||
|
output.append(sample.reshape([-1]))
|
||||||
|
x = sample.transpose([1, 0, 2])
|
||||||
|
|
||||||
|
elif self.mode == 'RAW':
|
||||||
|
posterior = F.softmax(logits, axis=1)
|
||||||
|
distrib = paddle.distribution.Categorical(posterior)
|
||||||
|
# corresponding operate [np.floor((fx + 1) / 2 * mu + 0.5)] in enocde_mu_law
|
||||||
|
sample = 2 * distrib.sample([1])[0].cast('float32') / (
|
||||||
|
self.n_classes - 1.) - 1.
|
||||||
|
output.append(sample)
|
||||||
|
x = sample.unsqueeze(-1)
|
||||||
|
else:
|
||||||
|
raise RuntimeError('Unknown model mode value - ', self.mode)
|
||||||
|
|
||||||
|
if gen_display:
|
||||||
|
if i % 1000 == 0:
|
||||||
|
self.gen_display(i, int(seq_len), int(b_size), start)
|
||||||
|
|
||||||
|
output = paddle.stack(output).transpose([1, 0])
|
||||||
|
|
||||||
|
if mu_law:
|
||||||
|
output = decode_mu_law(output, self.n_classes, False)
|
||||||
|
|
||||||
|
if batched:
|
||||||
|
output = self.xfade_and_unfold(output, target, overlap)
|
||||||
|
else:
|
||||||
|
output = output[0]
|
||||||
|
|
||||||
|
# Fade-out at the end to avoid signal cutting out suddenly
|
||||||
|
fade_out = paddle.linspace(1, 0, 20 * self.hop_length)
|
||||||
|
output = output[:wave_len]
|
||||||
|
output[-20 * self.hop_length:] *= fade_out
|
||||||
|
|
||||||
|
self.train()
|
||||||
|
|
||||||
|
# 增加 C_out 维度
|
||||||
|
return output.unsqueeze(-1)
|
||||||
|
|
||||||
|
def get_gru_cell(self, gru):
|
||||||
|
gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
|
||||||
|
gru_cell.weight_hh = gru.weight_hh_l0
|
||||||
|
gru_cell.weight_ih = gru.weight_ih_l0
|
||||||
|
gru_cell.bias_hh = gru.bias_hh_l0
|
||||||
|
gru_cell.bias_ih = gru.bias_ih_l0
|
||||||
|
|
||||||
|
return gru_cell
|
||||||
|
|
||||||
|
def _flatten_parameters(self):
|
||||||
|
[m.flatten_parameters() for m in self._to_flatten]
|
||||||
|
|
||||||
|
def pad_tensor(self, x, pad, side='both'):
|
||||||
|
'''
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor
|
||||||
|
mel, [1, n_frames, 80]
|
||||||
|
pad : int
|
||||||
|
side : str
|
||||||
|
'both', 'before' or 'after'
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
Tensor
|
||||||
|
'''
|
||||||
|
b, t, c = paddle.shape(x)
|
||||||
|
total = t + 2 * pad if side == 'both' else t + pad
|
||||||
|
padded = paddle.zeros([b, total, c])
|
||||||
|
if side == 'before' or side == 'both':
|
||||||
|
padded[:, pad:pad + t, :] = x
|
||||||
|
elif side == 'after':
|
||||||
|
padded[:, :t, :] = x
|
||||||
|
return padded
|
||||||
|
|
||||||
|
def fold_with_overlap(self, x, target, overlap):
|
||||||
|
'''
|
||||||
|
Fold the tensor with overlap for quick batched inference.
|
||||||
|
Overlap will be used for crossfading in xfade_and_unfold()
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor
|
||||||
|
Upsampled conditioning features. mels or aux
|
||||||
|
shape=(1, T, features)
|
||||||
|
mels: [1, T, 80]
|
||||||
|
aux: [1, T, 128]
|
||||||
|
target : int
|
||||||
|
Target timesteps for each index of batch
|
||||||
|
overlap : int
|
||||||
|
Timesteps for both xfade and rnn warmup
|
||||||
|
overlap = hop_length * 2
|
||||||
|
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
Tensor
|
||||||
|
shape=(num_folds, target + 2 * overlap, features)
|
||||||
|
num_flods = (time_seq - overlap) // (target + overlap)
|
||||||
|
mel: [num_folds, target + 2 * overlap, 80]
|
||||||
|
aux: [num_folds, target + 2 * overlap, 128]
|
||||||
|
|
||||||
|
Details
|
||||||
|
----------
|
||||||
|
x = [[h1, h2, ... hn]]
|
||||||
|
|
||||||
|
Where each h is a vector of conditioning features
|
||||||
|
|
||||||
|
Eg: target=2, overlap=1 with x.size(1)=10
|
||||||
|
|
||||||
|
folded = [[h1, h2, h3, h4],
|
||||||
|
[h4, h5, h6, h7],
|
||||||
|
[h7, h8, h9, h10]]
|
||||||
|
'''
|
||||||
|
|
||||||
|
_, total_len, features = paddle.shape(x)
|
||||||
|
|
||||||
|
# Calculate variables needed
|
||||||
|
num_folds = (total_len - overlap) // (target + overlap)
|
||||||
|
extended_len = num_folds * (overlap + target) + overlap
|
||||||
|
remaining = total_len - extended_len
|
||||||
|
|
||||||
|
# Pad if some time steps poking out
|
||||||
|
if remaining != 0:
|
||||||
|
num_folds += 1
|
||||||
|
padding = target + 2 * overlap - remaining
|
||||||
|
x = self.pad_tensor(x, padding, side='after')
|
||||||
|
|
||||||
|
folded = paddle.zeros([num_folds, target + 2 * overlap, features])
|
||||||
|
|
||||||
|
# Get the values for the folded tensor
|
||||||
|
for i in range(num_folds):
|
||||||
|
start = i * (target + overlap)
|
||||||
|
end = start + target + 2 * overlap
|
||||||
|
folded[i] = x[0][start:end, :]
|
||||||
|
return folded
|
||||||
|
|
||||||
|
def xfade_and_unfold(self, y, target: int=12000, overlap: int=600):
|
||||||
|
''' Applies a crossfade and unfolds into a 1d array.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
y : Tensor
|
||||||
|
Batched sequences of audio samples
|
||||||
|
shape=(num_folds, target + 2 * overlap)
|
||||||
|
dtype=paddle.float64
|
||||||
|
overlap : int
|
||||||
|
Timesteps for both xfade and rnn warmup
|
||||||
|
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
Tensor
|
||||||
|
audio samples in a 1d array
|
||||||
|
shape=(total_len)
|
||||||
|
dtype=paddle.float64
|
||||||
|
|
||||||
|
Details
|
||||||
|
----------
|
||||||
|
y = [[seq1],
|
||||||
|
[seq2],
|
||||||
|
[seq3]]
|
||||||
|
|
||||||
|
Apply a gain envelope at both ends of the sequences
|
||||||
|
|
||||||
|
y = [[seq1_in, seq1_target, seq1_out],
|
||||||
|
[seq2_in, seq2_target, seq2_out],
|
||||||
|
[seq3_in, seq3_target, seq3_out]]
|
||||||
|
|
||||||
|
Stagger and add up the groups of samples:
|
||||||
|
|
||||||
|
[seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
|
||||||
|
|
||||||
|
'''
|
||||||
|
# num_folds = (total_len - overlap) // (target + overlap)
|
||||||
|
num_folds, length = y.shape
|
||||||
|
target = length - 2 * overlap
|
||||||
|
total_len = num_folds * (target + overlap) + overlap
|
||||||
|
|
||||||
|
# Need some silence for the run warmup
|
||||||
|
slience_len = overlap // 2
|
||||||
|
fade_len = overlap - slience_len
|
||||||
|
slience = paddle.zeros([slience_len], dtype=paddle.float64)
|
||||||
|
linear = paddle.ones([fade_len], dtype=paddle.float64)
|
||||||
|
|
||||||
|
# Equal power crossfade
|
||||||
|
# fade_in increase from 0 to 1, fade_out reduces from 1 to 0
|
||||||
|
t = paddle.linspace(-1, 1, fade_len, dtype=paddle.float64)
|
||||||
|
fade_in = paddle.sqrt(0.5 * (1 + t))
|
||||||
|
fade_out = paddle.sqrt(0.5 * (1 - t))
|
||||||
|
# Concat the silence to the fades
|
||||||
|
fade_out = paddle.concat([linear, fade_out])
|
||||||
|
fade_in = paddle.concat([slience, fade_in])
|
||||||
|
|
||||||
|
# Apply the gain to the overlap samples
|
||||||
|
y[:, :overlap] *= fade_in
|
||||||
|
y[:, -overlap:] *= fade_out
|
||||||
|
|
||||||
|
unfolded = paddle.zeros([total_len], dtype=paddle.float64)
|
||||||
|
|
||||||
|
# Loop to add up all the samples
|
||||||
|
for i in range(num_folds):
|
||||||
|
start = i * (target + overlap)
|
||||||
|
end = start + target + 2 * overlap
|
||||||
|
unfolded[start:end] += y[i]
|
||||||
|
|
||||||
|
return unfolded
|
||||||
|
|
||||||
|
def gen_display(self, i, seq_len, b_size, start):
|
||||||
|
gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
|
||||||
|
pbar = self.progbar(i, seq_len)
|
||||||
|
msg = f'| {pbar} {i*b_size}/{seq_len*b_size} | Batch Size: {b_size} | Gen Rate: {gen_rate:.1f}kHz | '
|
||||||
|
sys.stdout.write(f"\r{msg}")
|
||||||
|
|
||||||
|
def progbar(self, i, n, size=16):
|
||||||
|
done = int(i * size) // n
|
||||||
|
bar = ''
|
||||||
|
for i in range(size):
|
||||||
|
bar += '█' if i <= done else '░'
|
||||||
|
return bar
|
@ -0,0 +1,203 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import soundfile as sf
|
||||||
|
from paddle import distributed as dist
|
||||||
|
from paddle.io import DataLoader
|
||||||
|
from paddle.nn import Layer
|
||||||
|
from paddle.optimizer import Optimizer
|
||||||
|
|
||||||
|
from paddlespeech.t2s.datasets.vocoder_batch_fn import decode_mu_law
|
||||||
|
from paddlespeech.t2s.datasets.vocoder_batch_fn import label_2_float
|
||||||
|
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
|
||||||
|
from paddlespeech.t2s.training.reporter import report
|
||||||
|
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
|
||||||
|
logging.basicConfig(
|
||||||
|
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
|
||||||
|
datefmt='[%Y-%m-%d %H:%M:%S]')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_grad_norm(parameters, norm_type: str=2):
|
||||||
|
'''
|
||||||
|
calculate grad norm of mdoel's parameters
|
||||||
|
parameters:
|
||||||
|
model's parameters
|
||||||
|
norm_type: str
|
||||||
|
Returns
|
||||||
|
------------
|
||||||
|
Tensor
|
||||||
|
grad_norm
|
||||||
|
'''
|
||||||
|
|
||||||
|
grad_list = [
|
||||||
|
paddle.to_tensor(p.grad) for p in parameters if p.grad is not None
|
||||||
|
]
|
||||||
|
norm_list = paddle.stack(
|
||||||
|
[paddle.norm(grad, norm_type) for grad in grad_list])
|
||||||
|
total_norm = paddle.norm(norm_list)
|
||||||
|
return total_norm
|
||||||
|
|
||||||
|
|
||||||
|
# for save name in gen_valid_samples()
|
||||||
|
ITERATION = 0
|
||||||
|
|
||||||
|
|
||||||
|
class WaveRNNUpdater(StandardUpdater):
|
||||||
|
def __init__(self,
|
||||||
|
model: Layer,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
criterion: Layer,
|
||||||
|
dataloader: DataLoader,
|
||||||
|
init_state=None,
|
||||||
|
output_dir: Path=None,
|
||||||
|
mode='RAW'):
|
||||||
|
super().__init__(model, optimizer, dataloader, init_state=None)
|
||||||
|
|
||||||
|
self.criterion = criterion
|
||||||
|
# self.scheduler = scheduler
|
||||||
|
|
||||||
|
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||||
|
self.filehandler = logging.FileHandler(str(log_file))
|
||||||
|
logger.addHandler(self.filehandler)
|
||||||
|
self.logger = logger
|
||||||
|
self.msg = ""
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def update_core(self, batch):
|
||||||
|
|
||||||
|
self.msg = "Rank: {}, ".format(dist.get_rank())
|
||||||
|
losses_dict = {}
|
||||||
|
# parse batch
|
||||||
|
self.model.train()
|
||||||
|
self.optimizer.clear_grad()
|
||||||
|
|
||||||
|
wav, y, mel = batch
|
||||||
|
|
||||||
|
y_hat = self.model(wav, mel)
|
||||||
|
if self.mode == 'RAW':
|
||||||
|
y_hat = y_hat.transpose([0, 2, 1]).unsqueeze(-1)
|
||||||
|
elif self.mode == 'MOL':
|
||||||
|
y_hat = paddle.cast(y, dtype='float32')
|
||||||
|
|
||||||
|
y = y.unsqueeze(-1)
|
||||||
|
loss = self.criterion(y_hat, y)
|
||||||
|
loss.backward()
|
||||||
|
grad_norm = float(
|
||||||
|
calculate_grad_norm(self.model.parameters(), norm_type=2))
|
||||||
|
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
report("train/loss", float(loss))
|
||||||
|
report("train/grad_norm", float(grad_norm))
|
||||||
|
|
||||||
|
losses_dict["loss"] = float(loss)
|
||||||
|
losses_dict["grad_norm"] = float(grad_norm)
|
||||||
|
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||||
|
for k, v in losses_dict.items())
|
||||||
|
global ITERATION
|
||||||
|
ITERATION = self.state.iteration + 1
|
||||||
|
|
||||||
|
|
||||||
|
class WaveRNNEvaluator(StandardEvaluator):
|
||||||
|
def __init__(self,
|
||||||
|
model: Layer,
|
||||||
|
criterion: Layer,
|
||||||
|
dataloader: Optimizer,
|
||||||
|
output_dir: Path=None,
|
||||||
|
valid_generate_loader=None,
|
||||||
|
config=None):
|
||||||
|
super().__init__(model, dataloader)
|
||||||
|
|
||||||
|
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||||
|
self.filehandler = logging.FileHandler(str(log_file))
|
||||||
|
logger.addHandler(self.filehandler)
|
||||||
|
self.logger = logger
|
||||||
|
self.msg = ""
|
||||||
|
|
||||||
|
self.criterion = criterion
|
||||||
|
self.valid_generate_loader = valid_generate_loader
|
||||||
|
self.config = config
|
||||||
|
self.mode = config.model.mode
|
||||||
|
|
||||||
|
self.valid_samples_dir = output_dir / "valid_samples"
|
||||||
|
self.valid_samples_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def evaluate_core(self, batch):
|
||||||
|
self.msg = "Evaluate: "
|
||||||
|
losses_dict = {}
|
||||||
|
# parse batch
|
||||||
|
wav, y, mel = batch
|
||||||
|
y_hat = self.model(wav, mel)
|
||||||
|
|
||||||
|
if self.mode == 'RAW':
|
||||||
|
y_hat = y_hat.transpose([0, 2, 1]).unsqueeze(-1)
|
||||||
|
elif self.mode == 'MOL':
|
||||||
|
y_hat = paddle.cast(y, dtype='float32')
|
||||||
|
|
||||||
|
y = y.unsqueeze(-1)
|
||||||
|
loss = self.criterion(y_hat, y)
|
||||||
|
report("eval/loss", float(loss))
|
||||||
|
|
||||||
|
losses_dict["loss"] = float(loss)
|
||||||
|
|
||||||
|
self.iteration = ITERATION
|
||||||
|
if self.iteration % self.config.gen_eval_samples_interval_steps == 0:
|
||||||
|
self.gen_valid_samples()
|
||||||
|
|
||||||
|
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||||
|
for k, v in losses_dict.items())
|
||||||
|
self.logger.info(self.msg)
|
||||||
|
|
||||||
|
def gen_valid_samples(self):
|
||||||
|
|
||||||
|
for i, (mel, wav) in enumerate(self.valid_generate_loader):
|
||||||
|
if i >= self.config.generate_num:
|
||||||
|
print("before break")
|
||||||
|
break
|
||||||
|
print(
|
||||||
|
'\n| Generating: {}/{}'.format(i + 1, self.config.generate_num))
|
||||||
|
wav = wav[0]
|
||||||
|
if self.mode == 'MOL':
|
||||||
|
bits = 16
|
||||||
|
else:
|
||||||
|
bits = self.config.model.bits
|
||||||
|
if self.config.mu_law and self.mode != 'MOL':
|
||||||
|
wav = decode_mu_law(wav, 2**bits, from_labels=True)
|
||||||
|
else:
|
||||||
|
wav = label_2_float(wav, bits)
|
||||||
|
origin_save_path = self.valid_samples_dir / '{}_steps_{}_target.wav'.format(
|
||||||
|
self.iteration, i)
|
||||||
|
sf.write(origin_save_path, wav.numpy(), samplerate=self.config.fs)
|
||||||
|
|
||||||
|
if self.config.inference.gen_batched:
|
||||||
|
batch_str = 'gen_batched_target{}_overlap{}'.format(
|
||||||
|
self.config.inference.target, self.config.inference.overlap)
|
||||||
|
else:
|
||||||
|
batch_str = 'gen_not_batched'
|
||||||
|
gen_save_path = str(self.valid_samples_dir /
|
||||||
|
'{}_steps_{}_{}.wav'.format(self.iteration, i,
|
||||||
|
batch_str))
|
||||||
|
# (1, C_aux, T) -> (T, C_aux)
|
||||||
|
mel = mel.squeeze(0).transpose([1, 0])
|
||||||
|
gen_sample = self.model.generate(
|
||||||
|
mel, self.config.inference.gen_batched,
|
||||||
|
self.config.inference.target, self.config.inference.overlap,
|
||||||
|
self.config.mu_law)
|
||||||
|
sf.write(
|
||||||
|
gen_save_path, gen_sample.numpy(), samplerate=self.config.fs)
|
Loading…
Reference in new issue