diff --git a/examples/vctk/README.md b/examples/vctk/README.md
index 41163dbe..4a589bcc 100644
--- a/examples/vctk/README.md
+++ b/examples/vctk/README.md
@@ -10,3 +10,4 @@
 * voc2 - MelGAN
 * voc3 - MultiBand MelGAN
 * ernie_sat - ERNIE-SAT
+* vc3 - StarGANv2-VC
diff --git a/examples/vctk/vc3/README.md b/examples/vctk/vc3/README.md
new file mode 100644
index 00000000..83e1003c
--- /dev/null
+++ b/examples/vctk/vc3/README.md
@@ -0,0 +1,10 @@
+You can download the test source audio files from [test_wav.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/starganv2vc/test_wav.zip).
+
+
+Test Voice Conversion:
+
+```bash
+wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/starganv2vc/test_wav.zip
+unzip test_wav.zip
+./run.sh --stage 2 --stop-stage 2 --gpus 0
+```
\ No newline at end of file
diff --git a/examples/vctk/vc3/conf/default.yaml b/examples/vctk/vc3/conf/default.yaml
new file mode 100644
index 00000000..0acc2a56
--- /dev/null
+++ b/examples/vctk/vc3/conf/default.yaml
@@ -0,0 +1,22 @@
+generator_params:
+  dim_in: 64
+  style_dim: 64
+  max_conv_dim: 512
+  w_hpf: 0
+  F0_channel: 256
+mapping_network_params:
+  num_domains: 20       # num of speakers in StarGANv2
+  latent_dim: 16
+  style_dim: 64         # same as style_dim in generator_params
+  hidden_dim: 512       # same as max_conv_dim in generator_params
+style_encoder_params:
+  dim_in: 64            # same as dim_in in generator_params
+  style_dim: 64         # same as style_dim in generator_params
+  num_domains: 20       # same as num_domains in mapping_network_params
+  max_conv_dim: 512     # same as max_conv_dim in generator_params
+discriminator_params:
+  dim_in: 64            # same as dim_in in generator_params
+  num_domains: 20       # same as num_domains in mapping_network_params
+  max_conv_dim: 512     # same as max_conv_dim in generator_params
+  n_repeat: 4
+
\ No newline at end of file
diff --git a/examples/vctk/vc3/local/preprocess.sh b/examples/vctk/vc3/local/preprocess.sh
new file mode 100755
index 00000000..ea0fbc43
--- /dev/null
+++ b/examples/vctk/vc3/local/preprocess.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+stage=0
+stop_stage=100
+
+config_path=$1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    :   # TODO: not ready yet; `:` keeps this valid bash (an empty `then` body is a syntax error)
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    :   # TODO: not ready yet
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    :   # TODO: not ready yet
+fi
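The three preprocessing stages above are deliberately left as TODOs ("not ready now" per run.sh). For orientation only, a hypothetical sketch of what they usually contain in sibling VCTK recipes (e.g. `examples/vctk/tts3`) — the script names under `${BIN_DIR}` are assumptions, not part of this PR:

```bash
# Hypothetical sketch, mirroring other VCTK recipes; not part of this PR.
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # extract features (e.g. mel spectrograms) into dump/
    python3 ${BIN_DIR}/preprocess.py \
        --rootdir=~/datasets/VCTK-Corpus-0.92/ \
        --dumpdir=dump \
        --config=${config_path}
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # compute normalization statistics on the training split
    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --field-name="speech"
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # normalize features with the statistics from stage 1
    python3 ${MAIN_ROOT}/utils/normalize.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --dumpdir=dump/train/norm \
        --stats=dump/train/speech_stats.npy
fi
```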
diff --git a/examples/vctk/vc3/local/train.sh b/examples/vctk/vc3/local/train.sh
new file mode 100755
index 00000000..3a507650
--- /dev/null
+++ b/examples/vctk/vc3/local/train.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+
+python3 ${BIN_DIR}/train.py \
+    --train-metadata=dump/train/norm/metadata.jsonl \
+    --dev-metadata=dump/dev/norm/metadata.jsonl \
+    --config=${config_path} \
+    --output-dir=${train_output_path} \
+    --ngpu=1 \
+    --phones-dict=dump/phone_id_map.txt \
+    --speaker-dict=dump/speaker_id_map.txt
diff --git a/examples/vctk/vc3/local/voice_conversion.sh b/examples/vctk/vc3/local/voice_conversion.sh
new file mode 100755
index 00000000..edf8f7ef
--- /dev/null
+++ b/examples/vctk/vc3/local/voice_conversion.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+config_path=$1
+source_path=$2
+output_dir=$3
+
+python3 ${BIN_DIR}/vc.py \
+    --config_path=${config_path} \
+    --source_path=${source_path} \
+    --output_dir=${output_dir}
\ No newline at end of file
diff --git a/examples/vctk/vc3/path.sh b/examples/vctk/vc3/path.sh
new file mode 100755
index 00000000..9de2e4d7
--- /dev/null
+++ b/examples/vctk/vc3/path.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+export MAIN_ROOT=`realpath ${PWD}/../../../`
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+export PYTHONDONTWRITEBYTECODE=1
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+MODEL=starganv2_vc
+export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
diff --git a/examples/vctk/vc3/run.sh b/examples/vctk/vc3/run.sh
new file mode 100755
index 00000000..602a593d
--- /dev/null
+++ b/examples/vctk/vc3/run.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+set -e
+source path.sh
+
+gpus=0,1
+stage=0
+stop_stage=100
+
+conf_path=conf/default.yaml
+train_output_path=exp/default
+ckpt_name=snapshot_iter_331.pdz
+source_path=test_wav/goat_01.wav
+output_dir=vc_output
+
+# with the following command, you can choose the stage range you want to run
+# such as `./run.sh --stage 0 --stop-stage 0`
+# this cannot be mixed with positional args such as `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+# not ready now
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # prepare data
+    ./local/preprocess.sh ${conf_path} || exit -1
+fi
+# not ready now
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train model, all `ckpt` files are saved under `train_output_path/checkpoints/`
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # voice conversion, the vocoder is PWGAN by default
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_conversion.sh ${conf_path} ${source_path} ${output_dir} || exit -1
+fi
+
diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py
index 96010610..7624b735 100644
--- a/paddlespeech/resource/pretrained_models.py
+++ b/paddlespeech/resource/pretrained_models.py
@@ -2003,7 +2003,7 @@ g2pw_onnx_models = {
 }
 
 # ---------------------------------
-# ------------- Rhy_frontend ---------------
+# ---------- Rhy_frontend ---------
 # ---------------------------------
 rhy_frontend_models = {
     'rhy_e2e': {
@@ -2014,3 +2014,15 @@
         },
     },
 }
+
+# ---------------------------------
+# ---------- StarGANv2VC ----------
+# ---------------------------------
+
+StarGANv2VC_source = {
+    '1.0': {
+        'url': 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/starganv2vc/StarGANv2VC_source.zip',
+        'md5': '195e169419163f5648030ba84c71f866',
+    },
+}
+
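The new `StarGANv2VC_source` entry is consumed by `vc.py` further down in this PR; a minimal sketch of how it resolves to a local model directory (this mirrors the call in `main()` of vc.py):

```python
# Minimal sketch mirroring vc.py: fetch and unpack the pretrained bundle.
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.resource.pretrained_models import StarGANv2VC_source
from paddlespeech.utils.env import MODEL_HOME

# downloads the zip (verifying the md5) into MODEL_HOME and unpacks it,
# returning the path of the decompressed directory
uncompress_path = download_and_decompress(StarGANv2VC_source['1.0'], MODEL_HOME)
print(uncompress_path)  # contains jdcnet.pdz, starganv2vc.pdz, Vocoder/, Demo/
```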
diff --git a/paddlespeech/t2s/datasets/get_feats.py b/paddlespeech/t2s/datasets/get_feats.py
index a90f1a41..5ec97b81 100644
--- a/paddlespeech/t2s/datasets/get_feats.py
+++ b/paddlespeech/t2s/datasets/get_feats.py
@@ -17,6 +17,11 @@
 import numpy as np
 import pyworld
 from scipy.interpolate import interp1d
+from typing import Optional
+from typing import Union
+from typing_extensions import Literal
+
+
 
 class LogMelFBank():
     def __init__(self,
@@ -27,7 +32,10 @@
                  window: str="hann",
                  n_mels: int=80,
                  fmin: int=80,
-                 fmax: int=7600):
+                 fmax: int=7600,
+                 norm: Optional[Union[Literal["slaney"], float]]="slaney",
+                 htk: bool=False,
+                 power: float=1.0):
         self.sr = sr
         # stft
         self.n_fft = n_fft
@@ -36,11 +44,14 @@
         self.window = window
         self.center = True
         self.pad_mode = "reflect"
+        self.norm = norm
+        self.htk = htk
 
         # mel
         self.n_mels = n_mels
         self.fmin = 0 if fmin is None else fmin
         self.fmax = sr / 2 if fmax is None else fmax
+        self.power = power
 
         self.mel_filter = self._create_mel_filter()
 
@@ -50,7 +61,9 @@
             n_fft=self.n_fft,
             n_mels=self.n_mels,
             fmin=self.fmin,
-            fmax=self.fmax)
+            fmax=self.fmax,
+            norm=self.norm,
+            htk=self.htk)
         return mel_filter
 
     def _stft(self, wav: np.ndarray):
@@ -66,7 +79,7 @@
 
     def _spectrogram(self, wav: np.ndarray):
         D = self._stft(wav)
-        return np.abs(D)
+        return np.abs(D)**self.power
 
     def _mel_spectrogram(self, wav: np.ndarray):
         S = self._spectrogram(wav)
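The new `norm`/`htk`/`power` knobs let `LogMelFBank` reproduce the torchaudio-style mel features StarGANv2-VC was trained on. A small sketch matching the settings `get_mel_extractor()` uses in vc.py below:

```python
import numpy as np

from paddlespeech.t2s.datasets.get_feats import LogMelFBank

# torchaudio-style settings used by StarGANv2-VC (see get_mel_extractor()):
# no slaney filterbank normalization, HTK mel scale, power spectrogram
mel_extractor = LogMelFBank(
    sr=16000,
    n_fft=2048,
    hop_length=300,
    win_length=1200,
    n_mels=80,
    fmin=0,
    fmax=8000,
    norm=None,
    htk=True,
    power=2.0)

wav = np.random.randn(16000).astype(np.float32)  # 1 second of fake audio
logmel = mel_extractor.get_log_mel_fbank(wav, base='e')  # (n_frames, 80)
```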
diff --git a/paddlespeech/t2s/exps/starganv2_vc/__init__.py b/paddlespeech/t2s/exps/starganv2_vc/__init__.py
new file mode 100644
index 00000000..595add0a
--- /dev/null
+++ b/paddlespeech/t2s/exps/starganv2_vc/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/t2s/exps/starganv2_vc/vc.py b/paddlespeech/t2s/exps/starganv2_vc/vc.py
new file mode 100644
index 00000000..ffb25741
--- /dev/null
+++ b/paddlespeech/t2s/exps/starganv2_vc/vc.py
@@ -0,0 +1,253 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+import time
+from pathlib import Path
+
+import librosa
+import paddle
+import soundfile as sf
+import yaml
+from yacs.config import CfgNode
+
+from paddlespeech.cli.utils import download_and_decompress
+from paddlespeech.resource.pretrained_models import StarGANv2VC_source
+from paddlespeech.t2s.datasets.get_feats import LogMelFBank
+from paddlespeech.t2s.models.parallel_wavegan import PWGGenerator
+from paddlespeech.t2s.models.starganv2_vc import Generator
+from paddlespeech.t2s.models.starganv2_vc import JDCNet
+from paddlespeech.t2s.models.starganv2_vc import MappingNetwork
+from paddlespeech.t2s.models.starganv2_vc import StyleEncoder
+from paddlespeech.utils.env import MODEL_HOME
+
+
+def get_mel_extractor():
+    sr = 16000
+    n_fft = 2048
+    win_length = 1200
+    hop_length = 300
+    n_mels = 80
+    fmin = 0
+    fmax = sr // 2
+
+    mel_extractor = LogMelFBank(
+        sr=sr,
+        n_fft=n_fft,
+        hop_length=hop_length,
+        win_length=win_length,
+        n_mels=n_mels,
+        fmin=fmin,
+        fmax=fmax,
+        norm=None,
+        htk=True,
+        power=2.0)
+    return mel_extractor
+
+
+def preprocess(wave, mel_extractor):
+    logmel = mel_extractor.get_log_mel_fbank(wave, base='e')
+    # [1, 80, 1011]
+    mean, std = -4, 4
+    mel_tensor = (paddle.to_tensor(logmel.T).unsqueeze(0) - mean) / std
+    return mel_tensor
+
+
+def compute_style(speaker_dicts, mel_extractor, style_encoder, mapping_network):
+    reference_embeddings = {}
+    for key, (path, speaker) in speaker_dicts.items():
+        if path == '':
+            label = paddle.to_tensor([speaker], dtype=paddle.int64)
+            latent_dim = mapping_network.shared[0].weight.shape[0]
+            ref = mapping_network(paddle.randn([1, latent_dim]), label)
+        else:
+            wave, sr = librosa.load(path, sr=24000)
+            # use the trimmed audio; the original discarded the trim result
+            # and carried a dead resample branch (librosa.load already set sr)
+            wave, _ = librosa.effects.trim(wave, top_db=30)
+            mel_tensor = preprocess(wave=wave, mel_extractor=mel_extractor)
+            with paddle.no_grad():
+                label = paddle.to_tensor([speaker], dtype=paddle.int64)
+                ref = style_encoder(mel_tensor.unsqueeze(1), label)
+        reference_embeddings[key] = (ref, label)
+
+    return reference_embeddings
+
+
+def get_models(args, uncompress_path):
+    model_dict = {}
+    jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz')
+    voc_model_dir = os.path.join(uncompress_path, 'Vocoder/')
+    starganv2vc_model_dir = os.path.join(uncompress_path, 'starganv2vc.pdz')
+
+    F0_model = JDCNet(num_class=1, seq_len=192)
+    F0_model.set_state_dict(paddle.load(jdc_model_dir)['main_params'])
+    F0_model.eval()
+
+    voc_config_path = os.path.join(voc_model_dir, 'config.yml')
+    with open(voc_config_path) as f:
+        voc_config = CfgNode(yaml.safe_load(f))
+    voc_config["generator_params"].pop("upsample_net")
+    voc_config["generator_params"]["upsample_scales"] = voc_config[
+        "generator_params"].pop("upsample_params")["upsample_scales"]
+    vocoder = PWGGenerator(**voc_config["generator_params"])
+    vocoder.remove_weight_norm()
+    vocoder.eval()
+    voc_model_path = os.path.join(voc_model_dir, 'checkpoint-400000steps.pd')
+    vocoder.set_state_dict(paddle.load(voc_model_path))
+
+    with open(args.config_path) as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    generator = Generator(**config['generator_params'])
+    mapping_network = MappingNetwork(**config['mapping_network_params'])
+    style_encoder = StyleEncoder(**config['style_encoder_params'])
+
+    starganv2vc_model_param = paddle.load(starganv2vc_model_dir)
+
+    generator.set_state_dict(starganv2vc_model_param['generator_params'])
+    mapping_network.set_state_dict(
+        starganv2vc_model_param['mapping_network_params'])
+    style_encoder.set_state_dict(
+        starganv2vc_model_param['style_encoder_params'])
+
+    generator.eval()
+    mapping_network.eval()
+    style_encoder.eval()
+
+    model_dict['F0_model'] = F0_model
+    model_dict['vocoder'] = vocoder
+    model_dict['generator'] = generator
+    model_dict['mapping_network'] = mapping_network
+    model_dict['style_encoder'] = style_encoder
+    return model_dict
+
+
+def voice_conversion(args, uncompress_path):
+    speakers = [
+        225, 228, 229, 230, 231, 233, 236, 239, 240, 244, 226, 227, 232, 243,
+        254, 256, 258, 259, 270, 273
+    ]
+    demo_dir = os.path.join(uncompress_path, 'Demo/VCTK-corpus/')
+    model_dict = get_models(args, uncompress_path=uncompress_path)
+    style_encoder = model_dict['style_encoder']
+    mapping_network = model_dict['mapping_network']
+    generator = model_dict['generator']
+    vocoder = model_dict['vocoder']
+    F0_model = model_dict['F0_model']
+
+    # compute the styles of the reference speakers under the Demo directory
+    speaker_dicts = {}
+    selected_speakers = [273, 259, 258, 243, 254, 244, 236, 233, 230, 228]
+    for s in selected_speakers:
+        k = s
+        speaker_dicts['p' + str(s)] = (
+            demo_dir + 'p' + str(k) + '/p' + str(k) + '_023.wav',
+            speakers.index(s))
+    mel_extractor = get_mel_extractor()
+    reference_embeddings = compute_style(
+        speaker_dicts=speaker_dicts,
+        mel_extractor=mel_extractor,
+        style_encoder=style_encoder,
+        mapping_network=mapping_network)
+
+    wave, sr = librosa.load(args.source_path, sr=24000)
+    source = preprocess(wave=wave, mel_extractor=mel_extractor)
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    orig_wav_name = str(output_dir / 'orig_voc.wav')
+    print('source audio (decoded by the vocoder): %s' % orig_wav_name)
+    c = source.transpose([0, 2, 1]).squeeze()
+    with paddle.no_grad():
+        recon = vocoder.inference(c)
+        recon = recon.reshape([-1]).numpy()
+    sf.write(orig_wav_name, recon, samplerate=24000)
+
+    keys = []
+    converted_samples = {}
+    reconstructed_samples = {}
+    converted_mels = {}
+    start = time.time()
+
+    for key, (ref, _) in reference_embeddings.items():
+        with paddle.no_grad():
+            # Must the features fed to F0_model be normalized the same way as in the original StarGANv2-VC recipe?
+            # !! Yes: the ASR model and the F0 model share one preprocessing pipeline.
+            # Unless both are retrained, our preprocessing has to stay consistent with the original recipe,
+            # but then our own vocoders cannot be reused.
+            # The 16k parameters here presumably follow the ASR model, whose input is 16k audio.
+            f0_feat = F0_model.get_feature_GAN(source.unsqueeze(1))
+            # the generator outputs a normalized mel, so vocoder.inference can be used directly
+            out = generator(source.unsqueeze(1), ref, F0=f0_feat)
+            c = out.transpose([0, 1, 3, 2]).squeeze()
+            y_out = vocoder.inference(c)
+            y_out = y_out.reshape([-1])
+            if key not in speaker_dicts or speaker_dicts[key][0] == "":
+                recon = None
+            else:
+                wave, sr = librosa.load(speaker_dicts[key][0], sr=24000)
+                mel = preprocess(wave=wave, mel_extractor=mel_extractor)
+                c = mel.transpose([0, 2, 1]).squeeze()
+                recon = vocoder.inference(c)
+                recon = recon.reshape([-1]).numpy()
+
+        converted_samples[key] = y_out.numpy()
+        reconstructed_samples[key] = recon
+        converted_mels[key] = out
+        keys.append(key)
+    end = time.time()
+    print('Total time: %.3f sec' % (end - start))
+    for key, wave in converted_samples.items():
+        wav_name = str(output_dir / ('vc_result_' + key + '.wav'))
+        print('voice conversion result: %s' % wav_name)
+        sf.write(wav_name, wave, samplerate=24000)
+        ref_wav_name = str(output_dir / ('ref_voc_' + key + '.wav'))
+        print('reference speaker (decoded by the vocoder): %s' % ref_wav_name)
+        if reconstructed_samples[key] is not None:
+            sf.write(ref_wav_name, reconstructed_samples[key], samplerate=24000)
+
+
+def parse_args():
+    # parse args and config
+    parser = argparse.ArgumentParser(
+        description="StarGANv2-VC Voice Conversion.")
+    parser.add_argument("--source_path", type=str, help="source audio's path.")
+    parser.add_argument("--output_dir", type=str, help="output dir.")
+    parser.add_argument(
+        '--config_path',
+        type=str,
+        default=None,
+        help='Config of StarGANv2-VC model.')
+    parser.add_argument(
+        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    if args.ngpu == 0:
+        paddle.set_device("cpu")
+    elif args.ngpu > 0:
+        paddle.set_device("gpu")
+    else:
+        print("ngpu should be >= 0!")
+    model_version = '1.0'
+    uncompress_path = download_and_decompress(StarGANv2VC_source[model_version],
+                                              MODEL_HOME)
+    voice_conversion(args, uncompress_path=uncompress_path)
+
+
+if __name__ == "__main__":
+    main()
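For reference, vc.py can also be invoked directly, mirroring `local/voice_conversion.sh` above (paths assumed; `test_wav/` comes from the zip in the vc3 README, and the pretrained bundle is fetched into `MODEL_HOME` on first run):

```bash
# run from the recipe directory so path.sh resolves MAIN_ROOT/BIN_DIR correctly
cd examples/vctk/vc3
source path.sh
python3 ${BIN_DIR}/vc.py \
    --config_path=conf/default.yaml \
    --source_path=test_wav/goat_01.wav \
    --output_dir=vc_output \
    --ngpu=1
```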
diff --git a/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/__init__.py b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/__init__.py
new file mode 100644
index 00000000..595add0a
--- /dev/null
+++ b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/config.yml b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/config.yml
new file mode 100644
index 00000000..9f4f9594
--- /dev/null
+++ b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/config.yml
@@ -0,0 +1,29 @@
+log_dir: "logs"
+save_freq: 20
+device: "cuda"
+epochs: 180
+batch_size: 48
+pretrained_model: ""
+train_data: "asr_train_list.txt"
+val_data: "asr_val_list.txt"
+
+dataset_params:
+  data_augmentation: true
+
+preprocess_params:
+  sr: 24000
+  spect_params:
+    n_fft: 2048
+    win_length: 1200
+    hop_length: 300
+  mel_params:
+    n_mels: 80
+
+model_params:
+  input_dim: 80
+  hidden_dim: 256
+  n_token: 80
+  token_embedding_dim: 256
+
+optimizer_params:
+  lr: 0.0005
\ No newline at end of file
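The `model_params` block above maps directly onto the `ASRCNN` constructor defined in `AuxiliaryASR/model.py` later in this PR; a minimal sketch (config path assumed) of building the auxiliary ASR model from this config:

```python
# Minimal sketch: instantiate the auxiliary ASR model from the YAML above.
import yaml

from paddlespeech.t2s.models.starganv2_vc import ASRCNN

with open('paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/config.yml') as f:
    config = yaml.safe_load(f)

# model_params: input_dim=80, hidden_dim=256, n_token=80, token_embedding_dim=256
asr_model = ASRCNN(**config['model_params'])
asr_model.eval()
```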
diff --git a/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/layers.py b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/layers.py
new file mode 100644
index 00000000..71b9753c
--- /dev/null
+++ b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/layers.py
@@ -0,0 +1,480 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+
+import paddle
+import paddle.nn.functional as F
+import paddleaudio.functional as audio_F
+from paddle import nn
+
+from paddlespeech.utils.initialize import _calculate_gain
+from paddlespeech.utils.initialize import xavier_uniform_
+
+
+def _get_activation_fn(activ):
+    if activ == 'relu':
+        return nn.ReLU()
+    elif activ == 'lrelu':
+        return nn.LeakyReLU(0.2)
+    elif activ == 'swish':
+        return nn.Swish()
+    else:
+        raise RuntimeError(
+            'Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
+
+
+class LinearNorm(nn.Layer):
+    def __init__(self,
+                 in_dim: int,
+                 out_dim: int,
+                 bias: bool=True,
+                 w_init_gain: str='linear'):
+        super().__init__()
+        self.linear_layer = nn.Linear(in_dim, out_dim, bias_attr=bias)
+        xavier_uniform_(
+            self.linear_layer.weight, gain=_calculate_gain(w_init_gain))
+
+    def forward(self, x: paddle.Tensor):
+        return self.linear_layer(x)
+
+
+class ConvNorm(nn.Layer):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: int=1,
+                 stride: int=1,
+                 padding: int=None,
+                 dilation: int=1,
+                 bias: bool=True,
+                 w_init_gain: str='linear',
+                 param=None):
+        super().__init__()
+        if padding is None:
+            assert (kernel_size % 2 == 1)
+            padding = int(dilation * (kernel_size - 1) / 2)
+
+        self.conv = nn.Conv1D(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias_attr=bias)
+
+        xavier_uniform_(
+            self.conv.weight, gain=_calculate_gain(w_init_gain, param=param))
+
+    def forward(self, signal: paddle.Tensor):
+        conv_signal = self.conv(signal)
+        return conv_signal
+
+
+class CausualConv(nn.Layer):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: int=1,
+                 stride: int=1,
+                 padding: int=1,
+                 dilation: int=1,
+                 bias: bool=True,
+                 w_init_gain: str='linear',
+                 param=None):
+        super().__init__()
+        if padding is None:
+            assert (kernel_size % 2 == 1)
+            # assign to self.padding; the original set a local variable only,
+            # which broke the forward pass below
+            self.padding = int(dilation * (kernel_size - 1) / 2) * 2
+        else:
+            self.padding = padding * 2
+        self.conv = nn.Conv1D(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=self.padding,
+            dilation=dilation,
+            bias_attr=bias)
+
+        xavier_uniform_(
+            self.conv.weight, gain=_calculate_gain(w_init_gain, param=param))
+
+    def forward(self, x: paddle.Tensor):
+        x = self.conv(x)
+        # drop the trailing frames so the convolution stays causal
+        x = x[:, :, :-self.padding]
+        return x
+
+
+class CausualBlock(nn.Layer):
+    def __init__(self,
+                 hidden_dim: int,
+                 n_conv: int=3,
+                 dropout_p: float=0.2,
+                 activ: str='lrelu'):
+        super().__init__()
+        self.blocks = nn.LayerList([
+            self._get_conv(
+                hidden_dim=hidden_dim,
+                dilation=3**i,
+                activ=activ,
+                dropout_p=dropout_p) for i in range(n_conv)
+        ])
+
+    def forward(self, x):
+        for block in self.blocks:
+            res = x
+            x = block(x)
+            x += res
+        return x
+
+    def _get_conv(self,
+                  hidden_dim: int,
+                  dilation: int,
+                  activ: str='lrelu',
+                  dropout_p: float=0.2):
+        layers = [
+            CausualConv(
+                in_channels=hidden_dim,
+                out_channels=hidden_dim,
+                kernel_size=3,
+                padding=dilation,
+                dilation=dilation), _get_activation_fn(activ),
+            nn.BatchNorm1D(hidden_dim), nn.Dropout(p=dropout_p), CausualConv(
+                in_channels=hidden_dim,
+                out_channels=hidden_dim,
+                kernel_size=3,
+                padding=1,
+                dilation=1), _get_activation_fn(activ), nn.Dropout(p=dropout_p)
+        ]
+        return nn.Sequential(*layers)
+
+
+class ConvBlock(nn.Layer):
+    def __init__(self,
+                 hidden_dim: int,
+                 n_conv: int=3,
+                 dropout_p: float=0.2,
+                 activ: str='relu'):
+        super().__init__()
+        self._n_groups = 8
+        self.blocks = nn.LayerList([
+            self._get_conv(
+                hidden_dim=hidden_dim,
+                dilation=3**i,
+                activ=activ,
+                dropout_p=dropout_p) for i in range(n_conv)
+        ])
+
+    def forward(self, x: paddle.Tensor):
+        for block in self.blocks:
+            res = x
+            x = block(x)
+            x += res
+        return x
+
+    def _get_conv(self,
+                  hidden_dim: int,
+                  dilation: int,
+                  activ: str='relu',
+                  dropout_p: float=0.2):
+        layers = [
+            ConvNorm(
+                in_channels=hidden_dim,
+                out_channels=hidden_dim,
+                kernel_size=3,
+                padding=dilation,
+                dilation=dilation), _get_activation_fn(activ),
+            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
+            nn.Dropout(p=dropout_p), ConvNorm(
+                hidden_dim, hidden_dim, kernel_size=3, padding=1,
+                dilation=1), _get_activation_fn(activ), nn.Dropout(p=dropout_p)
+        ]
+        return nn.Sequential(*layers)
+
+
+class LocationLayer(nn.Layer):
+    def __init__(self,
+                 attention_n_filters: int,
+                 attention_kernel_size: int,
+                 attention_dim: int):
+        super().__init__()
+        padding = int((attention_kernel_size - 1) / 2)
+        self.location_conv = ConvNorm(
+            in_channels=2,
+            out_channels=attention_n_filters,
+            kernel_size=attention_kernel_size,
+            padding=padding,
+            bias=False,
+            stride=1,
+            dilation=1)
+        self.location_dense = LinearNorm(
+            in_dim=attention_n_filters,
+            out_dim=attention_dim,
+            bias=False,
+            w_init_gain='tanh')
+
+    def forward(self, attention_weights_cat: paddle.Tensor):
+        processed_attention = self.location_conv(attention_weights_cat)
+        processed_attention = processed_attention.transpose([0, 2, 1])
+        processed_attention = self.location_dense(processed_attention)
+        return processed_attention
+
+
+class Attention(nn.Layer):
+    def __init__(self,
+                 attention_rnn_dim: int,
+                 embedding_dim: int,
+                 attention_dim: int,
+                 attention_location_n_filters: int,
+                 attention_location_kernel_size: int):
+        super().__init__()
+        self.query_layer = LinearNorm(
+            in_dim=attention_rnn_dim,
+            out_dim=attention_dim,
+            bias=False,
+            w_init_gain='tanh')
+        self.memory_layer = LinearNorm(
+            in_dim=embedding_dim,
+            out_dim=attention_dim,
+            bias=False,
+            w_init_gain='tanh')
+        self.v = LinearNorm(in_dim=attention_dim, out_dim=1, bias=False)
+        self.location_layer = LocationLayer(
+            attention_n_filters=attention_location_n_filters,
+            attention_kernel_size=attention_location_kernel_size,
+            attention_dim=attention_dim)
+        self.score_mask_value = -float("inf")
+
+    def get_alignment_energies(self,
+                               query: paddle.Tensor,
+                               processed_memory: paddle.Tensor,
+                               attention_weights_cat: paddle.Tensor):
+        """
+        Args:
+            query:
+                decoder output (batch, n_mel_channels * n_frames_per_step)
+            processed_memory:
+                processed encoder outputs (B, T_in, attention_dim)
+            attention_weights_cat:
+                cumulative and prev. att weights (B, 2, max_time)
+        Returns:
+            Tensor: alignment (batch, max_time)
+        """
+
+        processed_query = self.query_layer(query.unsqueeze(1))
+        processed_attention_weights = self.location_layer(attention_weights_cat)
+        energies = self.v(
+            paddle.tanh(processed_query + processed_attention_weights +
+                        processed_memory))
+
+        energies = energies.squeeze(-1)
+        return energies
+
+    def forward(self,
+                attention_hidden_state: paddle.Tensor,
+                memory: paddle.Tensor,
+                processed_memory: paddle.Tensor,
+                attention_weights_cat: paddle.Tensor,
+                mask: paddle.Tensor):
+        """
+        Args:
+            attention_hidden_state:
+                attention rnn last output
+            memory:
+                encoder outputs
+            processed_memory:
+                processed encoder outputs
+            attention_weights_cat:
+                previous and cumulative attention weights
+            mask:
+                binary mask for padded data
+        """
+        alignment = self.get_alignment_energies(
+            query=attention_hidden_state,
+            processed_memory=processed_memory,
+            attention_weights_cat=attention_weights_cat)
+
+        if mask is not None:
+            # paddle.Tensor has no torch-style .data.masked_fill_; use the
+            # same paddle.where pattern as ForwardAttentionV2 below
+            alignment = paddle.where(
+                mask,
+                paddle.full(alignment.shape, self.score_mask_value,
+                            alignment.dtype), alignment)
+
+        attention_weights = F.softmax(alignment, axis=1)
+        attention_context = paddle.bmm(attention_weights.unsqueeze(1), memory)
+        attention_context = attention_context.squeeze(1)
+
+        return attention_context, attention_weights
+
+
+class ForwardAttentionV2(nn.Layer):
+    def __init__(self,
+                 attention_rnn_dim: int,
+                 embedding_dim: int,
+                 attention_dim: int,
+                 attention_location_n_filters: int,
+                 attention_location_kernel_size: int):
+        super().__init__()
+        self.query_layer = LinearNorm(
+            in_dim=attention_rnn_dim,
+            out_dim=attention_dim,
+            bias=False,
+            w_init_gain='tanh')
+        self.memory_layer = LinearNorm(
+            in_dim=embedding_dim,
+            out_dim=attention_dim,
+            bias=False,
+            w_init_gain='tanh')
+        self.v = LinearNorm(in_dim=attention_dim, out_dim=1, bias=False)
+        self.location_layer = LocationLayer(
+            attention_n_filters=attention_location_n_filters,
+            attention_kernel_size=attention_location_kernel_size,
+            attention_dim=attention_dim)
+        self.score_mask_value = -float(1e20)
+
+    def get_alignment_energies(self,
+                               query: paddle.Tensor,
+                               processed_memory: paddle.Tensor,
+                               attention_weights_cat: paddle.Tensor):
+        """
+        Args:
+            query:
+                decoder output (batch, n_mel_channels * n_frames_per_step)
+            processed_memory:
+                processed encoder outputs (B, T_in, attention_dim)
+            attention_weights_cat:
+                prev. and cumulative att weights (B, 2, max_time)
+        Returns:
+            Tensor: alignment (batch, max_time)
+        """
+
+        processed_query = self.query_layer(query.unsqueeze(1))
+        processed_attention_weights = self.location_layer(attention_weights_cat)
+        energies = self.v(
+            paddle.tanh(processed_query + processed_attention_weights +
+                        processed_memory))
+
+        energies = energies.squeeze(-1)
+        return energies
+
+    def forward(self,
+                attention_hidden_state: paddle.Tensor,
+                memory: paddle.Tensor,
+                processed_memory: paddle.Tensor,
+                attention_weights_cat: paddle.Tensor,
+                mask: paddle.Tensor,
+                log_alpha: paddle.Tensor):
+        """
+        Args:
+            attention_hidden_state:
+                attention rnn last output
+            memory:
+                encoder outputs
+            processed_memory:
+                processed encoder outputs
+            attention_weights_cat:
+                previous and cumulative attention weights
+            mask:
+                binary mask for padded data
+        """
+        log_energy = self.get_alignment_energies(
+            query=attention_hidden_state,
+            processed_memory=processed_memory,
+            attention_weights_cat=attention_weights_cat)
+
+        if mask is not None:
+            log_energy[:] = paddle.where(
+                mask,
+                paddle.full(log_energy.shape, self.score_mask_value,
+                            log_energy.dtype), log_energy)
+        log_alpha_shift_padded = []
+        max_time = log_energy.shape[1]
+        for sft in range(2):
+            shifted = log_alpha[:, :max_time - sft]
+            shift_padded = F.pad(shifted, (sft, 0), 'constant',
+                                 self.score_mask_value)
+            log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
+
+        # fixed typo: paddle.conat does not exist, this is paddle.concat
+        biased = paddle.logsumexp(paddle.concat(log_alpha_shift_padded, 2), 2)
+        log_alpha_new = biased + log_energy
+        attention_weights = F.softmax(log_alpha_new, axis=1)
+        attention_context = paddle.bmm(attention_weights.unsqueeze(1), memory)
+        attention_context = attention_context.squeeze(1)
+
+        return attention_context, attention_weights, log_alpha_new
+
+
+class PhaseShuffle2D(nn.Layer):
+    def __init__(self, n: int=2):
+        super().__init__()
+        self.n = n
+        self.random = random.Random(1)
+
+    def forward(self, x: paddle.Tensor, move: int=None):
+        # x.size = (B, C, M, L)
+        if move is None:
+            move = self.random.randint(-self.n, self.n)
+
+        if move == 0:
+            return x
+        else:
+            left = x[:, :, :, :move]
+            right = x[:, :, :, move:]
+            shuffled = paddle.concat([right, left], axis=3)
+            return shuffled
+
+
+class PhaseShuffle1D(nn.Layer):
+    def __init__(self, n: int=2):
+        super().__init__()
+        self.n = n
+        self.random = random.Random(1)
+
+    def forward(self, x: paddle.Tensor, move: int=None):
+        # x.size = (B, C, L)
+        if move is None:
+            move = self.random.randint(-self.n, self.n)
+
+        if move == 0:
+            return x
+        else:
+            left = x[:, :, :move]
+            right = x[:, :, move:]
+            shuffled = paddle.concat([right, left], axis=2)
+
+            return shuffled
+
+
+class MFCC(nn.Layer):
+    def __init__(self, n_mfcc: int=40, n_mels: int=80):
+        super().__init__()
+        self.n_mfcc = n_mfcc
+        self.n_mels = n_mels
+        self.norm = 'ortho'
+        dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
+        self.register_buffer('dct_mat', dct_mat)
+
+    def forward(self, mel_specgram: paddle.Tensor):
+        if len(mel_specgram.shape) == 2:
+            mel_specgram = mel_specgram.unsqueeze(0)
+            unsqueezed = True
+        else:
+            unsqueezed = False
+        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
+        # -> (channel, time, n_mfcc).transpose(...)
+        mfcc = paddle.matmul(mel_specgram.transpose([0, 2, 1]),
+                             self.dct_mat).transpose([0, 2, 1])
+
+        # unpack batch
+        if unsqueezed:
+            mfcc = mfcc.squeeze(0)
+        return mfcc
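ASRCNN below converts its 80-band mel input to 40 MFCCs with this DCT before the first convolution; a small shape-check sketch:

```python
import paddle

from paddlespeech.t2s.models.starganv2_vc.AuxiliaryASR.layers import MFCC

mfcc = MFCC(n_mfcc=40, n_mels=80)

mel = paddle.randn([4, 80, 192])  # (batch, n_mels, frames)
out = mfcc(mel)                   # DCT applied over the mel axis
print(out.shape)                  # [4, 40, 192]

# unbatched input is transparently unsqueezed and squeezed back
print(mfcc(paddle.randn([80, 192])).shape)  # [40, 192]
```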
diff --git a/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py
new file mode 100644
index 00000000..48de8af1
--- /dev/null
+++ b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py
@@ -0,0 +1,239 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import paddle
+import paddle.nn.functional as F
+from paddle import nn
+
+from .layers import Attention
+from .layers import ConvBlock
+from .layers import ConvNorm
+from .layers import LinearNorm
+from .layers import MFCC
+from paddlespeech.utils.initialize import uniform_
+
+
+class ASRCNN(nn.Layer):
+    def __init__(
+            self,
+            input_dim: int=80,
+            hidden_dim: int=256,
+            n_token: int=35,
+            n_layers: int=6,
+            token_embedding_dim: int=256, ):
+        super().__init__()
+        self.n_token = n_token
+        self.n_down = 1
+        self.to_mfcc = MFCC()
+        self.init_cnn = ConvNorm(
+            in_channels=input_dim // 2,
+            out_channels=hidden_dim,
+            kernel_size=7,
+            padding=3,
+            stride=2)
+        self.cnns = nn.Sequential(*[
+            nn.Sequential(
+                ConvBlock(hidden_dim),
+                nn.GroupNorm(num_groups=1, num_channels=hidden_dim))
+            for n in range(n_layers)
+        ])
+        self.projection = ConvNorm(
+            in_channels=hidden_dim, out_channels=hidden_dim // 2)
+        self.ctc_linear = nn.Sequential(
+            LinearNorm(in_dim=hidden_dim // 2, out_dim=hidden_dim),
+            nn.ReLU(), LinearNorm(in_dim=hidden_dim, out_dim=n_token))
+        self.asr_s2s = ASRS2S(
+            embedding_dim=token_embedding_dim,
+            hidden_dim=hidden_dim // 2,
+            n_token=n_token)
+
+    def forward(self,
+                x: paddle.Tensor,
+                src_key_padding_mask: paddle.Tensor=None,
+                text_input: paddle.Tensor=None):
+        x = self.to_mfcc(x)
+        x = self.init_cnn(x)
+        x = self.cnns(x)
+        x = self.projection(x)
+        x = x.transpose([0, 2, 1])
+        ctc_logit = self.ctc_linear(x)
+        if text_input is not None:
+            _, s2s_logit, s2s_attn = self.asr_s2s(
+                memory=x,
+                memory_mask=src_key_padding_mask,
+                text_input=text_input)
+            return ctc_logit, s2s_logit, s2s_attn
+        else:
+            return ctc_logit
+
+    def get_feature(self, x: paddle.Tensor):
+        x = self.to_mfcc(x.squeeze(1))
+        x = self.init_cnn(x)
+        x = self.cnns(x)
+        x = self.projection(x)
+        return x
+
+    def length_to_mask(self, lengths: paddle.Tensor):
+        mask = paddle.arange(lengths.max()).unsqueeze(0).expand(
+            (lengths.shape[0], -1)).astype(lengths.dtype)
+        mask = paddle.greater_than(mask + 1, lengths.unsqueeze(1))
+        return mask
+
+    def get_future_mask(self, out_length: int, unmask_future_steps: int=0):
+        """
+        Args:
+            out_length (int):
+                returned mask shape is (out_length, out_length).
+            unmask_future_steps (int):
+                unmasking future step size.
+        Return:
+            mask (paddle.BoolTensor):
+                mask future timesteps; mask[i, j] = True if j > i + unmask_future_steps else False
+        """
+        index_tensor = paddle.arange(out_length).unsqueeze(0).expand(
+            [out_length, -1])
+        mask = paddle.greater_than(index_tensor,
+                                   index_tensor.T + unmask_future_steps)
+        return mask
+
+
+class ASRS2S(nn.Layer):
+    def __init__(self,
+                 embedding_dim: int=256,
+                 hidden_dim: int=512,
+                 n_location_filters: int=32,
+                 location_kernel_size: int=63,
+                 n_token: int=40):
+        super().__init__()
+        self.embedding = nn.Embedding(n_token, embedding_dim)
+        val_range = math.sqrt(6 / hidden_dim)
+        uniform_(self.embedding.weight, -val_range, val_range)
+
+        self.decoder_rnn_dim = hidden_dim
+        self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
+        self.attention_layer = Attention(
+            attention_rnn_dim=self.decoder_rnn_dim,
+            embedding_dim=hidden_dim,
+            attention_dim=hidden_dim,
+            attention_location_n_filters=n_location_filters,
+            attention_location_kernel_size=location_kernel_size)
+        self.decoder_rnn = nn.LSTMCell(self.decoder_rnn_dim + embedding_dim,
+                                       self.decoder_rnn_dim)
+        self.project_to_hidden = nn.Sequential(
+            LinearNorm(in_dim=self.decoder_rnn_dim * 2, out_dim=hidden_dim),
+            nn.Tanh())
+        self.sos = 1
+        self.eos = 2
+
+    def initialize_decoder_states(self,
+                                  memory: paddle.Tensor,
+                                  mask: paddle.Tensor):
+        """
+        memory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
+        """
+        B, L, H = memory.shape
+        dtype = memory.dtype
+        self.decoder_hidden = paddle.zeros(
+            (B, self.decoder_rnn_dim)).astype(dtype)
+        self.decoder_cell = paddle.zeros(
+            (B, self.decoder_rnn_dim)).astype(dtype)
+        self.attention_weights = paddle.zeros((B, L)).astype(dtype)
+        self.attention_weights_cum = paddle.zeros((B, L)).astype(dtype)
+        self.attention_context = paddle.zeros((B, H)).astype(dtype)
+        self.memory = memory
+        self.processed_memory = self.attention_layer.memory_layer(memory)
+        self.mask = mask
+        self.unk_index = 3
+        self.random_mask = 0.1
+
+    def forward(self,
+                memory: paddle.Tensor,
+                memory_mask: paddle.Tensor,
+                text_input: paddle.Tensor):
+        """
+        memory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
+        memory_mask.shape = (B, L, )
+        text_input.shape = (B, T)
+        """
+        self.initialize_decoder_states(memory, memory_mask)
+        # text random mask
+        random_mask = (paddle.rand(text_input.shape) < self.random_mask)
+        _text_input = text_input.clone()
+        _text_input[:] = paddle.where(
+            condition=random_mask,
+            x=paddle.full(
+                shape=_text_input.shape,
+                fill_value=self.unk_index,
+                dtype=_text_input.dtype),
+            y=_text_input)
+        decoder_inputs = self.embedding(_text_input).transpose(
+            [1, 0, 2])  # -> [T, B, channel]
+        # note: paddle has no paddle.long; the int64 dtype is used instead
+        start_embedding = self.embedding(
+            paddle.to_tensor(
+                [self.sos] * decoder_inputs.shape[1], dtype=paddle.int64))
+        decoder_inputs = paddle.concat(
+            (start_embedding.unsqueeze(0), decoder_inputs), axis=0)
+
+        hidden_outputs, logit_outputs, alignments = [], [], []
+        while len(hidden_outputs) < decoder_inputs.shape[0]:
+            decoder_input = decoder_inputs[len(hidden_outputs)]
+            hidden, logit, attention_weights = self.decode(decoder_input)
+            hidden_outputs += [hidden]
+            logit_outputs += [logit]
+            alignments += [attention_weights]
+
+        hidden_outputs, logit_outputs, alignments = \
+            self.parse_decoder_outputs(
+                hidden_outputs, logit_outputs, alignments)
+
+        return hidden_outputs, logit_outputs, alignments
+
+    def decode(self, decoder_input: paddle.Tensor):
+        cell_input = paddle.concat((decoder_input, self.attention_context), -1)
+        # paddle.nn.LSTMCell returns (outputs, (h, c)); the torch-style
+        # flatten_parameters() call has no paddle equivalent and was dropped
+        _, (self.decoder_hidden, self.decoder_cell) = self.decoder_rnn(
+            cell_input, (self.decoder_hidden, self.decoder_cell))
+
+        attention_weights_cat = paddle.concat(
+            (self.attention_weights.unsqueeze(1),
+             self.attention_weights_cum.unsqueeze(1)),
+            axis=1)
+
+        self.attention_context, self.attention_weights = self.attention_layer(
+            self.decoder_hidden, self.memory, self.processed_memory,
+            attention_weights_cat, self.mask)
+
+        self.attention_weights_cum += self.attention_weights
+
+        hidden_and_context = paddle.concat(
+            (self.decoder_hidden, self.attention_context), -1)
+        hidden = self.project_to_hidden(hidden_and_context)
+
+        # dropout to improve generalization
+        logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
+
+        return hidden, logit, self.attention_weights
+
+    def parse_decoder_outputs(self,
+                              hidden: paddle.Tensor,
+                              logit: paddle.Tensor,
+                              alignments: paddle.Tensor):
+        # -> [B, T_out + 1, max_time]
+        alignments = paddle.stack(alignments).transpose([1, 0, 2])
+        # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols]
+        logit = paddle.stack(logit).transpose([1, 0, 2])
+        hidden = paddle.stack(hidden).transpose([1, 0, 2])
+
+        return hidden, logit, alignments
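A quick sanity-check sketch of `get_future_mask` above — with the default `unmask_future_steps=0` it yields a strictly upper-triangular mask (True marks masked future steps):

```python
import paddle

from paddlespeech.t2s.models.starganv2_vc import ASRCNN

asr = ASRCNN()  # default hyper-parameters; only the mask helper is exercised

mask = asr.get_future_mask(4)
print(mask.astype('int32').numpy())
# [[0 1 1 1]
#  [0 0 1 1]
#  [0 0 0 1]
#  [0 0 0 0]]
```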
diff --git a/paddlespeech/t2s/models/starganv2_vc/JDCNet/__init__.py b/paddlespeech/t2s/models/starganv2_vc/JDCNet/__init__.py
new file mode 100644
index 00000000..595add0a
--- /dev/null
+++ b/paddlespeech/t2s/models/starganv2_vc/JDCNet/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/t2s/models/starganv2_vc/JDCNet/model.py b/paddlespeech/t2s/models/starganv2_vc/JDCNet/model.py
new file mode 100644
index 00000000..118b8f0e
--- /dev/null
+++ b/paddlespeech/t2s/models/starganv2_vc/JDCNet/model.py
@@ -0,0 +1,234 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Implementation of model from:
+Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
+Convolutional Recurrent Neural Networks" (2019)
+Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
+"""
+import paddle
+from paddle import nn
+
+
+class JDCNet(nn.Layer):
+    """
+    Joint Detection and Classification Network model for singing voice melody.
+    """
+
+    def __init__(self,
+                 num_class: int=722,
+                 seq_len: int=31,
+                 leaky_relu_slope: float=0.01):
+        super().__init__()
+        self.seq_len = seq_len
+        self.num_class = num_class
+
+        # input = (b, 1, 31, 80), b = batch size
+        # (80 mel bands here; the original paper used 513 spectrogram bins)
+        self.conv_block = nn.Sequential(
+            # out: (b, 64, 31, 80)
+            nn.Conv2D(
+                in_channels=1,
+                out_channels=64,
+                kernel_size=3,
+                padding=1,
+                bias_attr=False),
+            nn.BatchNorm2D(num_features=64),
+            nn.LeakyReLU(leaky_relu_slope),
+            # (b, 64, 31, 80)
+            nn.Conv2D(64, 64, 3, padding=1, bias_attr=False), )
+
+        # res blocks, each halving the frequency axis
+        # (b, 128, 31, 40)
+        self.res_block1 = ResBlock(in_channels=64, out_channels=128)
+        # (b, 192, 31, 20)
+        self.res_block2 = ResBlock(in_channels=128, out_channels=192)
+        # (b, 256, 31, 10)
+        self.res_block3 = ResBlock(in_channels=192, out_channels=256)
+
+        # pool block
+        self.pool_block = nn.Sequential(
+            nn.BatchNorm2D(num_features=256),
+            nn.LeakyReLU(leaky_relu_slope),
+            # (b, 256, 31, 2)
+            nn.MaxPool2D(kernel_size=(1, 4)),
+            nn.Dropout(p=0.5), )
+
+        # maxpool layers (for auxiliary network inputs)
+        # in = (b, 64, 31, 80) from conv_block, out = (b, 64, 31, 2)
+        self.maxpool1 = nn.MaxPool2D(kernel_size=(1, 40))
+        # in = (b, 128, 31, 40) from res_block1, out = (b, 128, 31, 2)
+        self.maxpool2 = nn.MaxPool2D(kernel_size=(1, 20))
+        # in = (b, 192, 31, 20) from res_block2, out = (b, 192, 31, 2)
+        self.maxpool3 = nn.MaxPool2D(kernel_size=(1, 10))
+
+        # in = (b, 640, 31, 2), out = (b, 256, 31, 2); 640 = 64 + 128 + 192 + 256
+        self.detector_conv = nn.Sequential(
+            nn.Conv2D(
+                in_channels=640,
+                out_channels=256,
+                kernel_size=1,
+                bias_attr=False),
+            nn.BatchNorm2D(256),
+            nn.LeakyReLU(leaky_relu_slope),
+            nn.Dropout(p=0.5), )
+
+        # input: (b, 31, 512) - resized from (b, 256, 31, 2)
+        # output: (b, 31, 512)
+        self.bilstm_classifier = nn.LSTM(
+            input_size=512,
+            hidden_size=256,
+            time_major=False,
+            direction='bidirectional')
+
+        # input: (b, 31, 512) - resized from (b, 256, 31, 2)
+        # output: (b, 31, 512)
+        self.bilstm_detector = nn.LSTM(
+            input_size=512,
+            hidden_size=256,
+            time_major=False,
+            direction='bidirectional')
+
+        # input: (b * 31, 512)
+        # output: (b * 31, num_class)
+        self.classifier = nn.Linear(
+            in_features=512, out_features=self.num_class)
+
+        # input: (b * 31, 512)
+        # output: (b * 31, 2) - binary classifier
+        self.detector = nn.Linear(in_features=512, out_features=2)
+
+        # initialize weights
+        self.apply(self.init_weights)
+
+    def get_feature_GAN(self, x: paddle.Tensor):
+        seq_len = x.shape[-2]
+        x = x.astype(paddle.float32).transpose([0, 1, 3, 2] if len(x.shape) == 4
+                                               else [0, 2, 1])
+
+        convblock_out = self.conv_block(x)
+
+        resblock1_out = self.res_block1(convblock_out)
+        resblock2_out = self.res_block2(resblock1_out)
+        resblock3_out = self.res_block3(resblock2_out)
+        poolblock_out = self.pool_block[0](resblock3_out)
+        poolblock_out = self.pool_block[1](poolblock_out)
+
+        return poolblock_out.transpose([0, 1, 3, 2] if len(poolblock_out.shape)
+                                       == 4 else [0, 2, 1])
+
+    def forward(self, x: paddle.Tensor):
+        """
+        Returns:
+            classification prediction (b, 31, num_class), the GAN feature map
+            and the pooled feature map
+        """
+        ###############################
+        # forward pass for classifier #
+        ###############################
+        x = x.transpose([0, 1, 3, 2] if len(x.shape) == 4 else
+                        [0, 2, 1]).astype(paddle.float32)
+
+        convblock_out = self.conv_block(x)
+
+        resblock1_out = self.res_block1(convblock_out)
+        resblock2_out = self.res_block2(resblock1_out)
+        resblock3_out = self.res_block3(resblock2_out)
+
+        poolblock_out = self.pool_block[0](resblock3_out)
+        poolblock_out = self.pool_block[1](poolblock_out)
+        GAN_feature = poolblock_out.transpose([0, 1, 3, 2] if len(
+            poolblock_out.shape) == 4 else [0, 2, 1])
+        poolblock_out = self.pool_block[2](poolblock_out)
+
+        # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
+        classifier_out = poolblock_out.transpose([0, 2, 1, 3]).reshape(
+            (-1, self.seq_len, 512))
+        classifier_out, _ = self.bilstm_classifier(
+            classifier_out)  # ignore the hidden states
+
+        classifier_out = classifier_out.reshape((-1, 512))  # (b * 31, 512)
+        classifier_out = self.classifier(classifier_out)
+        classifier_out = classifier_out.reshape(
+            (-1, self.seq_len, self.num_class))  # (b, 31, num_class)
+
+        # classifier output consists of predicted pitch classes per frame
+        return paddle.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
+
+    @staticmethod
+    def init_weights(m):
+        if isinstance(m, nn.Linear):
+            nn.initializer.KaimingUniform()(m.weight)
+            if m.bias is not None:
+                nn.initializer.Constant(0)(m.bias)
+        elif isinstance(m, nn.Conv2D):
+            nn.initializer.XavierNormal()(m.weight)
+        elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
+            for p in m.parameters():
+                if len(p.shape) >= 2:
+                    nn.initializer.Orthogonal()(p)
+                else:
+                    nn.initializer.Normal()(p)
+
+
+class ResBlock(nn.Layer):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 leaky_relu_slope=0.01):
+        super().__init__()
+        self.downsample = in_channels != out_channels
+
+        # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
+        self.pre_conv = nn.Sequential(
+            nn.BatchNorm2D(num_features=in_channels),
+            nn.LeakyReLU(leaky_relu_slope),
+            # apply downsampling on the y axis only
+            nn.MaxPool2D(kernel_size=(1, 2)), )
+
+        # conv layers
+        self.conv = nn.Sequential(
+            nn.Conv2D(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                padding=1,
+                bias_attr=False),
+            nn.BatchNorm2D(out_channels),
+            nn.LeakyReLU(leaky_relu_slope),
+            nn.Conv2D(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                padding=1,
+                bias_attr=False), )
+
+        # 1 x 1 convolution layer to match the feature dimensions
+        self.conv1by1 = None
+        if self.downsample:
+            self.conv1by1 = nn.Conv2D(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+                bias_attr=False)
+
+    def forward(self, x: paddle.Tensor):
+        x = self.pre_conv(x)
+        if self.downsample:
+            x = self.conv(x) + self.conv1by1(x)
+        else:
+            x = self.conv(x) + x
+        return x
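vc.py above loads this network as the F0 extractor with `JDCNet(num_class=1, seq_len=192)`; a small sketch of pulling the GAN feature that conditions the generator:

```python
import paddle

from paddlespeech.t2s.models.starganv2_vc import JDCNet

# configuration used by vc.py for F0 feature extraction
f0_model = JDCNet(num_class=1, seq_len=192)
f0_model.eval()

# (batch, 1, n_mels, frames) normalized log-mel, matching
# source.unsqueeze(1) in vc.py (random values for illustration)
mel = paddle.randn([1, 1, 80, 192])
with paddle.no_grad():
    f0_feat = f0_model.get_feature_GAN(mel)
print(f0_feat.shape)  # feature map consumed as Generator(..., F0=f0_feat)
```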
diff --git a/paddlespeech/t2s/models/starganv2_vc/__init__.py b/paddlespeech/t2s/models/starganv2_vc/__init__.py
new file mode 100644
index 00000000..e3327867
--- /dev/null
+++ b/paddlespeech/t2s/models/starganv2_vc/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .starganv2_vc import *
+from .starganv2_vc_updater import *
+from .AuxiliaryASR.model import *
+from .JDCNet.model import *
diff --git a/paddlespeech/t2s/models/starganv2_vc/losses.py b/paddlespeech/t2s/models/starganv2_vc/losses.py
new file mode 100644
index 00000000..8086a595
--- /dev/null
+++ b/paddlespeech/t2s/models/starganv2_vc/losses.py
@@ -0,0 +1,255 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn.functional as F
+from munch import Munch
+# NOTE: the original line imported from the external `starganv2vc_paddle`
+# package; the in-repo transforms module is used instead
+from .transforms import build_transforms
+
+
+# TODO: move these loss functions into the updater
+def compute_d_loss(nets,
+                   args,
+                   x_real,
+                   y_org,
+                   y_trg,
+                   z_trg=None,
+                   x_ref=None,
+                   use_r1_reg=True,
+                   use_adv_cls=False,
+                   use_con_reg=False):
+    args = Munch(args)
+
+    assert (z_trg is None) != (x_ref is None)
+    # with real audios
+    x_real.stop_gradient = False
+    out = nets.discriminator(x_real, y_org)
+    loss_real = adv_loss(out, 1)
+
+    # R1 regularizaition (https://arxiv.org/abs/1801.04406v4)
+    if use_r1_reg:
+        loss_reg = r1_reg(out, x_real)
+    else:
+        loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
+
+    # consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
+    loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32)
+    if use_con_reg:
+        t = build_transforms()
+        out_aug = nets.discriminator(t(x_real).detach(), y_org)
+        loss_con_reg += F.smooth_l1_loss(out, out_aug)
+
+    # with fake audios
+    with paddle.no_grad():
+        if z_trg is not None:
+            s_trg = nets.mapping_network(z_trg, y_trg)
+        else:  # x_ref is not None
+            s_trg = nets.style_encoder(x_ref, y_trg)
+
+        F0 = nets.f0_model.get_feature_GAN(x_real)
+        x_fake = nets.generator(x_real, s_trg, masks=None, F0=F0)
+    out = nets.discriminator(x_fake, y_trg)
+    loss_fake = adv_loss(out, 0)
+    if use_con_reg:
+        out_aug = nets.discriminator(t(x_fake).detach(), y_trg)
+        loss_con_reg += F.smooth_l1_loss(out, out_aug)
+
+    # adversarial classifier loss
+    if use_adv_cls:
+        out_de = nets.discriminator.classifier(x_fake)
+        loss_real_adv_cls = F.cross_entropy(out_de[y_org != y_trg],
+                                            y_org[y_org != y_trg])
+
+        if use_con_reg:
+            out_de_aug = nets.discriminator.classifier(t(x_fake).detach())
+            loss_con_reg += F.smooth_l1_loss(out_de, out_de_aug)
+    else:
+        loss_real_adv_cls = paddle.zeros([1]).mean()
+
+    loss = loss_real + loss_fake + args.lambda_reg * loss_reg + \
+        args.lambda_adv_cls * loss_real_adv_cls + \
+        args.lambda_con_reg * loss_con_reg
+
+    return loss, Munch(
+        real=loss_real.item(),
+        fake=loss_fake.item(),
+        reg=loss_reg.item(),
+        real_adv_cls=loss_real_adv_cls.item(),
+        con_reg=loss_con_reg.item())
+
+
+def compute_g_loss(nets,
+                   args,
+                   x_real,
+                   y_org,
+                   y_trg,
+                   z_trgs=None,
+                   x_refs=None,
+                   use_adv_cls=False):
+    args = Munch(args)
+
+    assert (z_trgs is None) != (x_refs is None)
+    if z_trgs is not None:
+        z_trg, z_trg2 = z_trgs
+    if x_refs is not None:
+        x_ref, x_ref2 = x_refs
+
+    # compute style vectors
+    if z_trgs is not None:
+        s_trg = nets.mapping_network(z_trg, y_trg)
+    else:
+        s_trg = nets.style_encoder(x_ref, y_trg)
+
+    # compute ASR/F0 features (real)
+    with paddle.no_grad():
+        F0_real, GAN_F0_real, cyc_F0_real = nets.f0_model(x_real)
+        ASR_real = nets.asr_model.get_feature(x_real)
+
+    # adversarial loss
+    x_fake = nets.generator(x_real, s_trg, masks=None, F0=GAN_F0_real)
+    out = nets.discriminator(x_fake, y_trg)
+    loss_adv = adv_loss(out, 1)
+
+    # compute ASR/F0 features (fake)
+    F0_fake, GAN_F0_fake, _ = nets.f0_model(x_fake)
+    ASR_fake = nets.asr_model.get_feature(x_fake)
+
+    # norm consistency loss
+    x_fake_norm = log_norm(x_fake)
+    x_real_norm = log_norm(x_real)
+    loss_norm = ((paddle.nn.ReLU()(
+        paddle.abs(x_fake_norm - x_real_norm) - args.norm_bias))**2).mean()
+
+    # F0 loss
+    loss_f0 = f0_loss(F0_fake, F0_real)
+
+    # style F0 loss (style initialization)
+    if x_refs is not None and args.lambda_f0_sty > 0 and not use_adv_cls:
+        F0_sty, _, _ = nets.f0_model(x_ref)
+        loss_f0_sty = F.l1_loss(
+            compute_mean_f0(F0_fake), compute_mean_f0(F0_sty))
+    else:
+        loss_f0_sty = paddle.zeros([1]).mean()
+
+    # ASR loss
+    loss_asr = F.smooth_l1_loss(ASR_fake, ASR_real)
+
+    # style reconstruction loss
+    s_pred = nets.style_encoder(x_fake, y_trg)
+    loss_sty = paddle.mean(paddle.abs(s_pred - s_trg))
+
+    # diversity sensitive loss
+    if z_trgs is not None:
+        s_trg2 = nets.mapping_network(z_trg2, y_trg)
+    else:
+        s_trg2 = nets.style_encoder(x_ref2, y_trg)
+    x_fake2 = nets.generator(x_real, s_trg2, masks=None, F0=GAN_F0_real)
+    x_fake2 = x_fake2.detach()
+    _, GAN_F0_fake2, _ = nets.f0_model(x_fake2)
+    loss_ds = paddle.mean(paddle.abs(x_fake - x_fake2))
+    loss_ds += F.smooth_l1_loss(GAN_F0_fake, GAN_F0_fake2.detach())
+
+    # cycle-consistency loss
+    s_org = nets.style_encoder(x_real, y_org)
+    x_rec = nets.generator(x_fake, s_org, masks=None, F0=GAN_F0_fake)
+    loss_cyc = paddle.mean(paddle.abs(x_rec - x_real))
+    # F0 loss in cycle-consistency loss
+    if args.lambda_f0 > 0:
+        _, _, cyc_F0_rec = nets.f0_model(x_rec)
+        loss_cyc += F.smooth_l1_loss(cyc_F0_rec, cyc_F0_real)
+    if args.lambda_asr > 0:
+        ASR_recon = nets.asr_model.get_feature(x_rec)
+        loss_cyc += F.smooth_l1_loss(ASR_recon, ASR_real)
+
+    # adversarial classifier loss
+    if use_adv_cls:
+        out_de = nets.discriminator.classifier(x_fake)
+        loss_adv_cls = F.cross_entropy(out_de[y_org != y_trg],
+                                       y_trg[y_org != y_trg])
+    else:
+        loss_adv_cls = paddle.zeros([1]).mean()
+
+    loss = args.lambda_adv * loss_adv + args.lambda_sty * loss_sty \
+        - args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc \
+        + args.lambda_norm * loss_norm \
+        + args.lambda_asr * loss_asr \
+        + args.lambda_f0 * loss_f0 \
+        + args.lambda_f0_sty * loss_f0_sty \
+        + args.lambda_adv_cls * loss_adv_cls
+
+    return loss, Munch(
+        adv=loss_adv.item(),
+        sty=loss_sty.item(),
+        ds=loss_ds.item(),
+        cyc=loss_cyc.item(),
+        norm=loss_norm.item(),
+        asr=loss_asr.item(),
+        f0=loss_f0.item(),
+        adv_cls=loss_adv_cls.item())
+
+
+# for norm consistency loss
+def log_norm(x, mean=-4, std=4, axis=2):
+    """
+    normalized log mel -> mel -> norm -> log(norm)
+    """
+    x = paddle.log(paddle.exp(x * std + mean).norm(axis=axis))
+    return x
+
+
+# for adversarial loss
+def adv_loss(logits, target):
+    assert target in [1, 0]
+    if len(logits.shape) > 1:
+        logits = logits.reshape([-1])
+    targets = paddle.full_like(logits, fill_value=target)
+    logits = logits.clip(min=-10, max=10)  # prevent nan
+    loss = F.binary_cross_entropy_with_logits(logits, targets)
+    return loss
+
+
+# for R1 regularization loss
+def r1_reg(d_out, x_in):
+    # zero-centered gradient penalty for real images
+    batch_size = x_in.shape[0]
+    grad_dout = paddle.grad(
+        outputs=d_out.sum(),
+        inputs=x_in,
+        create_graph=True,
+        retain_graph=True,
+        only_inputs=True)[0]
+    grad_dout2 = grad_dout.pow(2)
+    assert (grad_dout2.shape == x_in.shape)
+    reg = 0.5 * grad_dout2.reshape((batch_size, -1)).sum(1).mean(0)
+    return reg
+
+
+# for F0 consistency loss
+def compute_mean_f0(f0):
+    f0_mean = f0.mean(-1)
+    f0_mean = f0_mean.expand((f0.shape[-1], f0_mean.shape[0])).transpose(
+        (1, 0))  # (B, M)
+    return f0_mean
+
+
+def f0_loss(x_f0, y_f0):
+    """
+    x.shape = (B, 1, M, L): predict
+    y.shape = (B, 1, M, L): target
+    """
+    # compute the mean
+    x_mean = compute_mean_f0(x_f0)
+    y_mean = compute_mean_f0(y_f0)
+    loss = F.l1_loss(x_f0 / x_mean, y_f0 / y_mean)
+    return loss
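A toy sketch of how these helpers combine in a discriminator step (the discriminator is a stand-in and the weight is illustrative; the real wiring will live in the updater, per the TODO above, and importing `losses` assumes the `transforms` module and `munch` are available):

```python
import paddle

from paddlespeech.t2s.models.starganv2_vc.losses import adv_loss, r1_reg

# a batch of 4 "real" mel inputs; gradients w.r.t. the input are needed for R1
x_real = paddle.randn([4, 1, 80, 192])
x_real.stop_gradient = False

# stand-in for nets.discriminator(x_real, y_org): any differentiable score
d_out = x_real.sum(axis=[1, 2, 3])

loss_real = adv_loss(d_out, 1)    # push real scores towards the "real" label
loss_reg = r1_reg(d_out, x_real)  # zero-centered gradient penalty on real data
loss = loss_real + 1.0 * loss_reg  # 1.0 stands in for args.lambda_reg
```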
+""" +# import copy +import math + +import paddle +import paddle.nn.functional as F +from paddle import nn + +from paddlespeech.utils.initialize import _calculate_gain +from paddlespeech.utils.initialize import xavier_uniform_ + +# from munch import Munch + + +class DownSample(nn.Layer): + def __init__(self, layer_type: str): + super().__init__() + self.layer_type = layer_type + + def forward(self, x): + if self.layer_type == 'none': + return x + elif self.layer_type == 'timepreserve': + return F.avg_pool2d(x, (2, 1)) + elif self.layer_type == 'half': + return F.avg_pool2d(x, 2) + else: + raise RuntimeError( + 'Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' + % self.layer_type) + + +class UpSample(nn.Layer): + def __init__(self, layer_type: str): + super().__init__() + self.layer_type = layer_type + + def forward(self, x): + if self.layer_type == 'none': + return x + elif self.layer_type == 'timepreserve': + return F.interpolate(x, scale_factor=(2, 1), mode='nearest') + elif self.layer_type == 'half': + return F.interpolate(x, scale_factor=2, mode='nearest') + else: + raise RuntimeError( + 'Got unexpected upsampletype %s, expected is [none, timepreserve, half]' + % self.layer_type) + + +class ResBlk(nn.Layer): + def __init__(self, + dim_in: int, + dim_out: int, + actv: nn.LeakyReLU=nn.LeakyReLU(0.2), + normalize: bool=False, + downsample: str='none'): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = DownSample(layer_type=downsample) + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in: int, dim_out: int): + self.conv1 = nn.Conv2D( + in_channels=dim_in, + out_channels=dim_in, + kernel_size=3, + stride=1, + padding=1) + self.conv2 = nn.Conv2D( + in_channels=dim_in, + out_channels=dim_out, + kernel_size=3, + stride=1, + padding=1) + if self.normalize: + self.norm1 = nn.InstanceNorm2D(dim_in) + self.norm2 = nn.InstanceNorm2D(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2D( + in_channels=dim_in, + out_channels=dim_out, + kernel_size=1, + stride=1, + padding=0, + bias_attr=False) + + def _shortcut(self, x: paddle.Tensor): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = self.downsample(x) + return x + + def _residual(self, x: paddle.Tensor): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + x = self.downsample(x) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x: paddle.Tensor): + x = self._shortcut(x) + self._residual(x) + # unit variance + return x / math.sqrt(2) + + +class AdaIN(nn.Layer): + def __init__(self, style_dim: int, num_features: int): + super().__init__() + self.norm = nn.InstanceNorm2D( + num_features=num_features, weight_attr=False, bias_attr=False) + self.fc = nn.Linear(style_dim, num_features * 2) + + def forward(self, x: paddle.Tensor, s: paddle.Tensor): + if len(s.shape) == 1: + s = s[None] + h = self.fc(s) + h = h.reshape((h.shape[0], h.shape[1], 1, 1)) + gamma, beta = paddle.split(h, 2, axis=1) + return (1 + gamma) * self.norm(x) + beta + + +class AdainResBlk(nn.Layer): + def __init__(self, + dim_in: int, + dim_out: int, + style_dim: int=64, + w_hpf: int=0, + actv: nn.Layer=nn.LeakyReLU(0.2), + upsample: str='none'): + super().__init__() + self.w_hpf = w_hpf + self.actv = actv + self.upsample = UpSample(layer_type=upsample) + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out, style_dim) + + def 
+class AdaIN(nn.Layer):
+    def __init__(self, style_dim: int, num_features: int):
+        super().__init__()
+        self.norm = nn.InstanceNorm2D(
+            num_features=num_features, weight_attr=False, bias_attr=False)
+        self.fc = nn.Linear(style_dim, num_features * 2)
+
+    def forward(self, x: paddle.Tensor, s: paddle.Tensor):
+        if len(s.shape) == 1:
+            s = s[None]
+        h = self.fc(s)
+        h = h.reshape((h.shape[0], h.shape[1], 1, 1))
+        gamma, beta = paddle.split(h, 2, axis=1)
+        return (1 + gamma) * self.norm(x) + beta
+
+
+class AdainResBlk(nn.Layer):
+    def __init__(self,
+                 dim_in: int,
+                 dim_out: int,
+                 style_dim: int=64,
+                 w_hpf: int=0,
+                 actv: nn.Layer=nn.LeakyReLU(0.2),
+                 upsample: str='none'):
+        super().__init__()
+        self.w_hpf = w_hpf
+        self.actv = actv
+        self.upsample = UpSample(layer_type=upsample)
+        self.learned_sc = dim_in != dim_out
+        self._build_weights(dim_in, dim_out, style_dim)
+
+    def _build_weights(self, dim_in: int, dim_out: int, style_dim: int=64):
+        self.conv1 = nn.Conv2D(
+            in_channels=dim_in,
+            out_channels=dim_out,
+            kernel_size=3,
+            stride=1,
+            padding=1)
+        self.conv2 = nn.Conv2D(
+            in_channels=dim_out,
+            out_channels=dim_out,
+            kernel_size=3,
+            stride=1,
+            padding=1)
+        self.norm1 = AdaIN(style_dim=style_dim, num_features=dim_in)
+        self.norm2 = AdaIN(style_dim=style_dim, num_features=dim_out)
+        if self.learned_sc:
+            self.conv1x1 = nn.Conv2D(
+                in_channels=dim_in,
+                out_channels=dim_out,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias_attr=False)
+
+    def _shortcut(self, x: paddle.Tensor):
+        x = self.upsample(x)
+        if self.learned_sc:
+            x = self.conv1x1(x)
+        return x
+
+    def _residual(self, x: paddle.Tensor, s: paddle.Tensor):
+        x = self.norm1(x, s)
+        x = self.actv(x)
+        x = self.upsample(x)
+        x = self.conv1(x)
+        x = self.norm2(x, s)
+        x = self.actv(x)
+        x = self.conv2(x)
+        return x
+
+    def forward(self, x: paddle.Tensor, s: paddle.Tensor):
+        out = self._residual(x, s)
+        if self.w_hpf == 0:
+            out = (out + self._shortcut(x)) / math.sqrt(2)
+        return out
+
+
+class HighPass(nn.Layer):
+    def __init__(self, w_hpf: int):
+        super().__init__()
+        self.filter = paddle.to_tensor([[-1, -1, -1], [-1, 8., -1],
+                                        [-1, -1, -1]]) / w_hpf
+
+    def forward(self, x: paddle.Tensor):
+        filter = self.filter.unsqueeze(0).unsqueeze(1).tile(
+            [x.shape[1], 1, 1, 1])
+        return F.conv2d(x, filter, padding=1, groups=x.shape[1])
+
+
+class Generator(nn.Layer):
+    def __init__(self,
+                 dim_in: int=48,
+                 style_dim: int=48,
+                 max_conv_dim: int=48 * 8,
+                 w_hpf: int=1,
+                 F0_channel: int=0):
+        super().__init__()
+
+        self.stem = nn.Conv2D(
+            in_channels=1,
+            out_channels=dim_in,
+            kernel_size=3,
+            stride=1,
+            padding=1)
+        self.encode = nn.LayerList()
+        self.decode = nn.LayerList()
+        self.to_out = nn.Sequential(
+            nn.InstanceNorm2D(dim_in),
+            nn.LeakyReLU(0.2),
+            nn.Conv2D(
+                in_channels=dim_in,
+                out_channels=1,
+                kernel_size=1,
+                stride=1,
+                padding=0))
+        self.F0_channel = F0_channel
+        # down/up-sampling blocks
+        # int(np.log2(img_size)) - 4
+        repeat_num = 4
+        if w_hpf > 0:
+            repeat_num += 1
+
+        for lid in range(repeat_num):
+            if lid in [1, 3]:
+                _downtype = 'timepreserve'
+            else:
+                _downtype = 'half'
+
+            dim_out = min(dim_in * 2, max_conv_dim)
+            self.encode.append(
+                ResBlk(
+                    dim_in=dim_in,
+                    dim_out=dim_out,
+                    normalize=True,
+                    downsample=_downtype))
+            # build the decoder in reverse order so that it mirrors the
+            # encoder (stack-like); the first block is appended because the
+            # list is still empty, later blocks are pushed to the front
+            block = AdainResBlk(
+                dim_in=dim_out,
+                dim_out=dim_in,
+                style_dim=style_dim,
+                w_hpf=w_hpf,
+                upsample=_downtype)
+            if lid == 0:
+                self.decode.append(block)
+            else:
+                self.decode.insert(0, block)
+            dim_in = dim_out
+
+        # bottleneck blocks (encoder)
+        for _ in range(2):
+            self.encode.append(
+                ResBlk(dim_in=dim_out, dim_out=dim_out, normalize=True))
+
+        # F0 blocks
+        if F0_channel != 0:
+            self.decode.insert(0,
+                               AdainResBlk(
+                                   dim_in=dim_out + int(F0_channel / 2),
+                                   dim_out=dim_out,
+                                   style_dim=style_dim,
+                                   w_hpf=w_hpf))
+
+        # bottleneck blocks (decoder)
+        for _ in range(2):
+            self.decode.insert(0,
+                               AdainResBlk(
+                                   dim_in=dim_out + int(F0_channel / 2),
+                                   dim_out=dim_out + int(F0_channel / 2),
+                                   style_dim=style_dim,
+                                   w_hpf=w_hpf))
+
+        if F0_channel != 0:
+            self.F0_conv = nn.Sequential(
+                ResBlk(
+                    dim_in=F0_channel,
+                    dim_out=int(F0_channel / 2),
+                    normalize=True,
+                    downsample="half"), )
+
+        if w_hpf > 0:
+            self.hpf = HighPass(w_hpf)
+
+    def forward(self,
+                x: paddle.Tensor,
+                s: paddle.Tensor,
+                masks: paddle.Tensor=None,
+                F0: paddle.Tensor=None):
+        x = self.stem(x)
+        cache = {}
+        for block in self.encode:
+            if (masks is not None) and (x.shape[2] in [32, 64, 128]):
+                cache[x.shape[2]] = x
+            x = block(x)
+
+        if F0 is not None:
+            F0 = self.F0_conv(F0)
+            F0 = F.adaptive_avg_pool2d(F0, [x.shape[-2], x.shape[-1]])
+            x = paddle.concat([x, F0], axis=1)
+
+        for block in self.decode:
+            x = block(x, s)
+            if (masks is not None) and (x.shape[2] in [32, 64, 128]):
+                mask = masks[0] if x.shape[2] == 32 else masks[1]
+                mask = F.interpolate(mask, size=x.shape[2], mode='bilinear')
+                x = x + self.hpf(mask * cache[x.shape[2]])
+
+        return self.to_out(x)
+
+
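+# Rough data flow of Generator.forward (assuming the default vc3 config:
+# dim_in=64, F0_channel=256, w_hpf=0, and a (B, 1, n_mels, T) mel input):
+#   stem -> encode: alternating 'half' / 'timepreserve' ResBlks plus two
+#   bottleneck blocks; the F0 feature passes through F0_conv, is pooled to
+#   the encoder output size and concatenated on the channel axis; decode:
+#   the mirrored AdainResBlk stack, each block conditioned on the style
+#   code s; to_out projects back to a single-channel mel of the input size.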
+class MappingNetwork(nn.Layer):
+    def __init__(self,
+                 latent_dim: int=16,
+                 style_dim: int=48,
+                 num_domains: int=2,
+                 hidden_dim: int=384):
+        super().__init__()
+        layers = []
+        layers += [nn.Linear(latent_dim, hidden_dim)]
+        layers += [nn.ReLU()]
+        for _ in range(3):
+            layers += [nn.Linear(hidden_dim, hidden_dim)]
+            layers += [nn.ReLU()]
+        self.shared = nn.Sequential(*layers)
+
+        # one output head per domain (speaker)
+        self.unshared = nn.LayerList()
+        for _ in range(num_domains):
+            self.unshared.append(
+                nn.Sequential(
+                    nn.Linear(hidden_dim, hidden_dim),
+                    nn.ReLU(),
+                    nn.Linear(hidden_dim, hidden_dim),
+                    nn.ReLU(),
+                    nn.Linear(hidden_dim, hidden_dim),
+                    nn.ReLU(), nn.Linear(hidden_dim, style_dim)))
+
+    def forward(self, z: paddle.Tensor, y: paddle.Tensor):
+        h = self.shared(z)
+        out = []
+        for layer in self.unshared:
+            out += [layer(h)]
+        # (batch, num_domains, style_dim)
+        out = paddle.stack(out, axis=1)
+        idx = paddle.arange(y.shape[0])
+        # pick each sample's target-domain head -> (batch, style_dim)
+        s = out[idx, y]
+        return s
+
+
+class StyleEncoder(nn.Layer):
+    def __init__(self,
+                 dim_in: int=48,
+                 style_dim: int=48,
+                 num_domains: int=2,
+                 max_conv_dim: int=384):
+        super().__init__()
+        blocks = []
+        blocks += [
+            nn.Conv2D(
+                in_channels=1,
+                out_channels=dim_in,
+                kernel_size=3,
+                stride=1,
+                padding=1)
+        ]
+        repeat_num = 4
+        for _ in range(repeat_num):
+            dim_out = min(dim_in * 2, max_conv_dim)
+            blocks += [
+                ResBlk(dim_in=dim_in, dim_out=dim_out, downsample='half')
+            ]
+            dim_in = dim_out
+
+        blocks += [nn.LeakyReLU(0.2)]
+        blocks += [
+            nn.Conv2D(
+                in_channels=dim_out,
+                out_channels=dim_out,
+                kernel_size=5,
+                stride=1,
+                padding=0)
+        ]
+        blocks += [nn.AdaptiveAvgPool2D(1)]
+        blocks += [nn.LeakyReLU(0.2)]
+        self.shared = nn.Sequential(*blocks)
+        self.unshared = nn.LayerList()
+        for _ in range(num_domains):
+            self.unshared.append(nn.Linear(dim_out, style_dim))
+
+    def forward(self, x: paddle.Tensor, y: paddle.Tensor):
+        h = self.shared(x)
+        h = h.reshape((h.shape[0], -1))
+        out = []
+        for layer in self.unshared:
+            out += [layer(h)]
+        # (batch, num_domains, style_dim)
+        out = paddle.stack(out, axis=1)
+        idx = paddle.arange(y.shape[0])
+        # pick each sample's target-domain head -> (batch, style_dim)
+        s = out[idx, y]
+        return s
+
+
+class Discriminator(nn.Layer):
+    def __init__(self,
+                 dim_in: int=48,
+                 num_domains: int=2,
+                 max_conv_dim: int=384,
+                 repeat_num: int=4):
+        super().__init__()
+        # real/fake discriminator
+        self.dis = Discriminator2D(
+            dim_in=dim_in,
+            num_domains=num_domains,
+            max_conv_dim=max_conv_dim,
+            repeat_num=repeat_num)
+        # adversarial classifier
+        self.cls = Discriminator2D(
+            dim_in=dim_in,
+            num_domains=num_domains,
+            max_conv_dim=max_conv_dim,
+            repeat_num=repeat_num)
+        self.num_domains = num_domains
+
+    def forward(self, x: paddle.Tensor, y: paddle.Tensor):
+        return self.dis(x, y)
+
+    def classifier(self, x: paddle.Tensor):
+        return self.cls.get_feature(x)
+
+
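+# A minimal usage sketch (hypothetical tensors; sizes follow conf/default.yaml):
+#   se = StyleEncoder(dim_in=64, style_dim=64, num_domains=20, max_conv_dim=512)
+#   mel = paddle.randn([4, 1, 80, 192])     # (B, 1, n_mels, T)
+#   spk = paddle.to_tensor([3, 3, 7, 0])    # target domain per sample
+#   s = se(mel, spk)                        # -> style codes of shape [4, 64]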
+class LinearNorm(nn.Layer):
+    def __init__(self,
+                 in_dim: int,
+                 out_dim: int,
+                 bias: bool=True,
+                 w_init_gain: str='linear'):
+        super().__init__()
+        self.linear_layer = nn.Linear(in_dim, out_dim, bias_attr=bias)
+        xavier_uniform_(
+            self.linear_layer.weight, gain=_calculate_gain(w_init_gain))
+
+    def forward(self, x):
+        return self.linear_layer(x)
+
+
+class Discriminator2D(nn.Layer):
+    def __init__(self,
+                 dim_in: int=48,
+                 num_domains: int=2,
+                 max_conv_dim: int=384,
+                 repeat_num: int=4):
+        super().__init__()
+        blocks = []
+        blocks += [
+            nn.Conv2D(
+                in_channels=1,
+                out_channels=dim_in,
+                kernel_size=3,
+                stride=1,
+                padding=1)
+        ]
+
+        for _ in range(repeat_num):
+            dim_out = min(dim_in * 2, max_conv_dim)
+            blocks += [ResBlk(dim_in, dim_out, downsample='half')]
+            dim_in = dim_out
+
+        blocks += [nn.LeakyReLU(0.2)]
+        blocks += [
+            nn.Conv2D(
+                in_channels=dim_out,
+                out_channels=dim_out,
+                kernel_size=5,
+                stride=1,
+                padding=0)
+        ]
+        blocks += [nn.LeakyReLU(0.2)]
+        blocks += [nn.AdaptiveAvgPool2D(1)]
+        blocks += [
+            nn.Conv2D(
+                in_channels=dim_out,
+                out_channels=num_domains,
+                kernel_size=1,
+                stride=1,
+                padding=0)
+        ]
+        self.main = nn.Sequential(*blocks)
+
+    def get_feature(self, x: paddle.Tensor):
+        out = self.main(x)
+        # (batch, num_domains)
+        out = out.reshape((out.shape[0], -1))
+        return out
+
+    def forward(self, x: paddle.Tensor, y: paddle.Tensor):
+        out = self.get_feature(x)
+        idx = paddle.arange(y.shape[0])
+        # (batch)
+        out = out[idx, y]
+        return out
+
+
+'''
+def build_model(args, F0_model: nn.Layer, ASR_model: nn.Layer):
+    generator = Generator(
+        dim_in=args.dim_in,
+        style_dim=args.style_dim,
+        max_conv_dim=args.max_conv_dim,
+        w_hpf=args.w_hpf,
+        F0_channel=args.F0_channel)
+    mapping_network = MappingNetwork(
+        latent_dim=args.latent_dim,
+        style_dim=args.style_dim,
+        num_domains=args.num_domains,
+        hidden_dim=args.max_conv_dim)
+    style_encoder = StyleEncoder(
+        dim_in=args.dim_in,
+        style_dim=args.style_dim,
+        num_domains=args.num_domains,
+        max_conv_dim=args.max_conv_dim)
+    discriminator = Discriminator(
+        dim_in=args.dim_in,
+        num_domains=args.num_domains,
+        max_conv_dim=args.max_conv_dim,
+        repeat_num=args.n_repeat)
+    generator_ema = copy.deepcopy(generator)
+    mapping_network_ema = copy.deepcopy(mapping_network)
+    style_encoder_ema = copy.deepcopy(style_encoder)
+
+    nets = Munch(
+        generator=generator,
+        mapping_network=mapping_network,
+        style_encoder=style_encoder,
+        discriminator=discriminator,
+        f0_model=F0_model,
+        asr_model=ASR_model)
+
+    nets_ema = Munch(
+        generator=generator_ema,
+        mapping_network=mapping_network_ema,
+        style_encoder=style_encoder_ema)
+
+    return nets, nets_ema
+
+
+class StarGANv2VC(nn.Layer):
+    def __init__(
+            self,
+            # spk_num
+            num_domains: int=20,
+            dim_in: int=64,
+            style_dim: int=64,
+            latent_dim: int=16,
+            max_conv_dim: int=512,
+            n_repeat: int=4,
+            w_hpf: int=0,
+            F0_channel: int=256):
+        super().__init__()
+
+        self.generator = Generator(
+            dim_in=dim_in,
+            style_dim=style_dim,
+            max_conv_dim=max_conv_dim,
+            w_hpf=w_hpf,
+            F0_channel=F0_channel)
+        # MappingNetwork and StyleEncoder are used to generate reference_embeddings
+        self.mapping_network = MappingNetwork(
+            latent_dim=latent_dim,
+            style_dim=style_dim,
+            num_domains=num_domains,
+            hidden_dim=max_conv_dim)
+
+        self.style_encoder = StyleEncoder(
+            dim_in=dim_in,
+            style_dim=style_dim,
+            num_domains=num_domains,
+            max_conv_dim=max_conv_dim)
+
+        self.discriminator = Discriminator(
+            dim_in=dim_in,
+            num_domains=num_domains,
+            max_conv_dim=max_conv_dim,
+            repeat_num=n_repeat)
+'''
diff --git a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py
new file mode 100644
index 00000000..595add0a
--- /dev/null
+++ b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.