[TTS] Fix losses of StarGAN v2 VC (#3184)

pull/3200/head
TianYuan 2 years ago committed by GitHub
parent 84cc5fc98f
commit fc670339d1

@@ -8,4 +8,5 @@ python3 ${BIN_DIR}/train.py \
     --dev-metadata=dump/dev/norm/metadata.jsonl \
     --config=${config_path} \
     --output-dir=${train_output_path} \
-    --ngpu=1
+    --ngpu=1 \
+    --speaker-dict=dump/speaker_id_map.txt
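For reference, the new `--speaker-dict` flag points at the speaker id map written during preprocessing. The training script only counts its lines to get the number of speakers, so any whitespace-separated "name id" layout works; a minimal sketch with made-up speaker names:

```python
# Assumed layout of dump/speaker_id_map.txt: one "<speaker-name> <id>" pair
# per line (the names below are hypothetical).
example_speaker_dict = """\
p225 0
p226 1
p227 2
"""

# train.py derives the speaker count from the number of lines.
spk_id = [line.strip().split() for line in example_speaker_dict.splitlines()]
print("spk_num:", len(spk_id))  # -> spk_num: 3
```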

@@ -820,12 +820,13 @@ class StarGANv2VCCollateFn:
         self.max_mel_length = max_mel_length

     def random_clip(self, mel: np.array):
-        # [80, T]
-        mel_length = mel.shape[1]
+        # [T, 80]
+        mel_length = mel.shape[0]
         if mel_length > self.max_mel_length:
             random_start = np.random.randint(0,
                                              mel_length - self.max_mel_length)
-            mel = mel[:, random_start:random_start + self.max_mel_length]
+            mel = mel[random_start:random_start + self.max_mel_length, :]
         return mel

     def __call__(self, exmaples):
@@ -843,7 +844,6 @@ class StarGANv2VCCollateFn:
         mel = [self.random_clip(item["mel"]) for item in examples]
         ref_mel = [self.random_clip(item["ref_mel"]) for item in examples]
         ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples]
-
         mel = batch_sequences(mel)
         ref_mel = batch_sequences(ref_mel)
         ref_mel_2 = batch_sequences(ref_mel_2)
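The collate fix above reflects that mel features are stored time-major, `[T, 80]`, so the random crop has to slice the first (time) axis rather than the last. A small self-contained numpy sketch of the corrected behaviour (the `192` window length is just an illustrative default for `max_mel_length`):

```python
import numpy as np


def random_clip(mel: np.ndarray, max_mel_length: int = 192) -> np.ndarray:
    """Crop a random window along the time axis of a [T, 80] mel."""
    mel_length = mel.shape[0]
    if mel_length > max_mel_length:
        random_start = np.random.randint(0, mel_length - max_mel_length)
        mel = mel[random_start:random_start + max_mel_length, :]
    return mel


mel = np.random.randn(300, 80).astype(np.float32)
print(random_clip(mel).shape)  # (192, 80): time cropped, mel bins untouched
```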

@@ -113,6 +113,16 @@ def train_sp(args, config):
     model_version = '1.0'
     uncompress_path = download_and_decompress(StarGANv2VC_source[model_version],
                                               MODEL_HOME)
+    # modify num_domains according to the number of speakers
+    # the pretrained model from the original repo and default.yaml default to 20
+    if args.speaker_dict is not None:
+        with open(args.speaker_dict, 'rt', encoding='utf-8') as f:
+            spk_id = [line.strip().split() for line in f.readlines()]
+        spk_num = len(spk_id)
+        print("spk_num:", spk_num)
+        config['mapping_network_params']['num_domains'] = spk_num
+        config['style_encoder_params']['num_domains'] = spk_num
+        config['discriminator_params']['num_domains'] = spk_num

     generator = Generator(**config['generator_params'])
     mapping_network = MappingNetwork(**config['mapping_network_params'])
@@ -123,7 +133,7 @@ def train_sp(args, config):
     jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz')
     asr_model_dir = os.path.join(uncompress_path, 'asr.pdz')

-    F0_model = JDCNet(num_class=1, seq_len=192)
+    F0_model = JDCNet(num_class=1, seq_len=config['max_mel_length'])
     F0_model.set_state_dict(paddle.load(jdc_model_dir)['main_params'])
     F0_model.eval()
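Both changes above tie model shapes to the runtime configuration: `num_domains` must match the actual speaker count, and the F0 extractor's `seq_len` must match the `max_mel_length` window produced by the collate function instead of a hard-coded 192. A rough, self-contained sketch of the same overrides on a plain config dict (the `default.yaml` keys and file paths are assumptions here):

```python
import yaml

# Load the training config and the speaker id map (paths are illustrative).
with open("default.yaml", "rt", encoding="utf-8") as f:
    config = yaml.safe_load(f)

with open("dump/speaker_id_map.txt", "rt", encoding="utf-8") as f:
    spk_num = len([line for line in f if line.strip()])

# The pretrained checkpoints and default.yaml assume 20 domains; every
# domain-conditioned module is overridden with the real speaker count.
for key in ("mapping_network_params", "style_encoder_params",
            "discriminator_params"):
    config[key]["num_domains"] = spk_num

# The F0 model should see windows of the same length the collate function
# produces, rather than a hard-coded 192.
seq_len = config["max_mel_length"]
print(spk_num, seq_len)
```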
@@ -234,6 +244,11 @@ def main():
     parser.add_argument("--output-dir", type=str, help="output dir.")
     parser.add_argument(
         "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
+    parser.add_argument(
+        "--speaker-dict",
+        type=str,
+        default=None,
+        help="speaker id map file for multiple speaker model.")

     args = parser.parse_args()

@@ -19,15 +19,18 @@ import paddle.nn.functional as F
 from .transforms import build_transforms

 # these should all be moved into the updater
-def compute_d_loss(nets: Dict[str, Any],
+def compute_d_loss(
+        nets: Dict[str, Any],
         x_real: paddle.Tensor,
         y_org: paddle.Tensor,
         y_trg: paddle.Tensor,
         z_trg: paddle.Tensor=None,
         x_ref: paddle.Tensor=None,
-        use_r1_reg: bool=True,
+        # TODO: should be True here, but r1_reg has some bug now
+        use_r1_reg: bool=False,
         use_adv_cls: bool=False,
         use_con_reg: bool=False,
         lambda_reg: float=1.,
@@ -39,15 +42,15 @@ def compute_d_loss(nets: Dict[str, Any],
     x_real.stop_gradient = False
     out = nets['discriminator'](x_real, y_org)
     loss_real = adv_loss(out, 1)

     # R1 regularizaition (https://arxiv.org/abs/1801.04406v4)
     if use_r1_reg:
         loss_reg = r1_reg(out, x_real)
     else:
-        loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
+        # loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
+        loss_reg = paddle.zeros([1])

     # consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
-    loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32)
+    loss_con_reg = paddle.zeros([1])
     if use_con_reg:
         t = build_transforms()
         out_aug = nets['discriminator'](t(x_real).detach(), y_org)
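For context on the `use_r1_reg` TODO: `r1_reg` is normally the zero-centered gradient penalty from the paper linked in the comment, i.e. the squared norm of the discriminator's gradient with respect to the real input. A minimal Paddle sketch of what such a term computes, not necessarily this repo's exact implementation (which the TODO says still has a bug):

```python
import paddle


def r1_penalty(d_out: paddle.Tensor, x_real: paddle.Tensor) -> paddle.Tensor:
    """0.5 * E[ ||grad_x D(x)||^2 ] on real samples.

    x_real must have stop_gradient = False (set above in compute_d_loss).
    """
    grads = paddle.grad(
        outputs=d_out.sum(),
        inputs=x_real,
        create_graph=True,
        retain_graph=True)[0]
    grads = grads.reshape([grads.shape[0], -1])
    return 0.5 * grads.pow(2).sum(axis=1).mean()
```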
@@ -118,7 +121,8 @@ def compute_g_loss(nets: Dict[str, Any],
     s_trg = nets['style_encoder'](x_ref, y_trg)

     # compute ASR/F0 features (real)
-    with paddle.no_grad():
+    # the original repo does not call .eval() and relies on no_grad() instead
+    # we call .eval(); wrapping this in `with paddle.no_grad()` raises an error
     F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
     ASR_real = nets['asr_model'].get_feature(x_real)
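Dropping the `with paddle.no_grad():` wrapper relies on the pretrained feature extractors being frozen elsewhere; the training script already calls `F0_model.eval()`. A hedged sketch of the general technique (the `freeze` helper is hypothetical, not part of the repo): put the sub-model in eval mode and mark its parameters as non-trainable, so it stays fixed while gradients can still flow through it to the generator.

```python
import paddle


def freeze(layer: paddle.nn.Layer) -> None:
    # Eval mode disables dropout/batch-norm updates; stop_gradient on the
    # parameters keeps them fixed while still letting gradients flow through
    # to upstream inputs (e.g. the generator output).
    layer.eval()
    for param in layer.parameters():
        param.stop_gradient = True
```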

@@ -259,7 +259,7 @@ class StarGANv2VCEvaluator(StandardEvaluator):
             y_org=y_org,
             y_trg=y_trg,
             z_trg=z_trg,
-            use_r1_reg=False,
+            use_r1_reg=self.use_r1_reg,
             use_adv_cls=use_adv_cls,
             **self.d_loss_params)
@@ -269,7 +269,7 @@ class StarGANv2VCEvaluator(StandardEvaluator):
             y_org=y_org,
             y_trg=y_trg,
             x_ref=x_ref,
-            use_r1_reg=False,
+            use_r1_reg=self.use_r1_reg,
             use_adv_cls=use_adv_cls,
             **self.d_loss_params)
