[TTS]Fix losses of StarGAN v2 VC (#3184)

pull/3200/head
TianYuan 2 years ago committed by GitHub
parent 84cc5fc98f
commit fc670339d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -8,4 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=1
--ngpu=1 \
--speaker-dict=dump/speaker_id_map.txt

@ -820,12 +820,13 @@ class StarGANv2VCCollateFn:
self.max_mel_length = max_mel_length
def random_clip(self, mel: np.array):
# [80, T]
mel_length = mel.shape[1]
# [T, 80]
mel_length = mel.shape[0]
if mel_length > self.max_mel_length:
random_start = np.random.randint(0,
mel_length - self.max_mel_length)
mel = mel[:, random_start:random_start + self.max_mel_length]
mel = mel[random_start:random_start + self.max_mel_length, :]
return mel
def __call__(self, exmaples):
@ -843,7 +844,6 @@ class StarGANv2VCCollateFn:
mel = [self.random_clip(item["mel"]) for item in examples]
ref_mel = [self.random_clip(item["ref_mel"]) for item in examples]
ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples]
mel = batch_sequences(mel)
ref_mel = batch_sequences(ref_mel)
ref_mel_2 = batch_sequences(ref_mel_2)

@ -113,6 +113,16 @@ def train_sp(args, config):
model_version = '1.0'
uncompress_path = download_and_decompress(StarGANv2VC_source[model_version],
MODEL_HOME)
# 根据 speaker 的个数修改 num_domains
# 源码的预训练模型和 default.yaml 里面默认是 20
if args.speaker_dict is not None:
with open(args.speaker_dict, 'rt', encoding='utf-8') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
print("spk_num:", spk_num)
config['mapping_network_params']['num_domains'] = spk_num
config['style_encoder_params']['num_domains'] = spk_num
config['discriminator_params']['num_domains'] = spk_num
generator = Generator(**config['generator_params'])
mapping_network = MappingNetwork(**config['mapping_network_params'])
@ -123,7 +133,7 @@ def train_sp(args, config):
jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz')
asr_model_dir = os.path.join(uncompress_path, 'asr.pdz')
F0_model = JDCNet(num_class=1, seq_len=192)
F0_model = JDCNet(num_class=1, seq_len=config['max_mel_length'])
F0_model.set_state_dict(paddle.load(jdc_model_dir)['main_params'])
F0_model.eval()
@ -234,6 +244,11 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument(
"--speaker-dict",
type=str,
default=None,
help="speaker id map file for multiple speaker model.")
args = parser.parse_args()

@ -19,35 +19,38 @@ import paddle.nn.functional as F
from .transforms import build_transforms
# 这些都写到 updater 里
def compute_d_loss(nets: Dict[str, Any],
x_real: paddle.Tensor,
y_org: paddle.Tensor,
y_trg: paddle.Tensor,
z_trg: paddle.Tensor=None,
x_ref: paddle.Tensor=None,
use_r1_reg: bool=True,
use_adv_cls: bool=False,
use_con_reg: bool=False,
lambda_reg: float=1.,
lambda_adv_cls: float=0.1,
lambda_con_reg: float=10.):
def compute_d_loss(
nets: Dict[str, Any],
x_real: paddle.Tensor,
y_org: paddle.Tensor,
y_trg: paddle.Tensor,
z_trg: paddle.Tensor=None,
x_ref: paddle.Tensor=None,
# TODO: should be True here, but r1_reg has some bug now
use_r1_reg: bool=False,
use_adv_cls: bool=False,
use_con_reg: bool=False,
lambda_reg: float=1.,
lambda_adv_cls: float=0.1,
lambda_con_reg: float=10.):
assert (z_trg is None) != (x_ref is None)
# with real audios
x_real.stop_gradient = False
out = nets['discriminator'](x_real, y_org)
loss_real = adv_loss(out, 1)
# R1 regularizaition (https://arxiv.org/abs/1801.04406v4)
if use_r1_reg:
loss_reg = r1_reg(out, x_real)
else:
loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
# loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
loss_reg = paddle.zeros([1])
# consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32)
loss_con_reg = paddle.zeros([1])
if use_con_reg:
t = build_transforms()
out_aug = nets['discriminator'](t(x_real).detach(), y_org)
@ -118,9 +121,10 @@ def compute_g_loss(nets: Dict[str, Any],
s_trg = nets['style_encoder'](x_ref, y_trg)
# compute ASR/F0 features (real)
with paddle.no_grad():
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
ASR_real = nets['asr_model'].get_feature(x_real)
# 源码没有用 .eval(), 使用了 no_grad()
# 我们使用了 .eval(), 开启 with paddle.no_grad() 会报错
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
ASR_real = nets['asr_model'].get_feature(x_real)
# adversarial loss
x_fake = nets['generator'](x_real, s_trg, masks=None, F0=GAN_F0_real)

@ -259,7 +259,7 @@ class StarGANv2VCEvaluator(StandardEvaluator):
y_org=y_org,
y_trg=y_trg,
z_trg=z_trg,
use_r1_reg=False,
use_r1_reg=self.use_r1_reg,
use_adv_cls=use_adv_cls,
**self.d_loss_params)
@ -269,7 +269,7 @@ class StarGANv2VCEvaluator(StandardEvaluator):
y_org=y_org,
y_trg=y_trg,
x_ref=x_ref,
use_r1_reg=False,
use_r1_reg=self.use_r1_reg,
use_adv_cls=use_adv_cls,
**self.d_loss_params)

Loading…
Cancel
Save