rename num_speakers

pull/1003/head
TianYuan 4 years ago
parent a97c7b5206
commit 133ee7db0b

@@ -3,7 +3,7 @@
 set -e
 source path.sh
-gpus=4,5
+gpus=0,1
 stage=0
 stop_stage=100

@@ -46,14 +46,14 @@ def evaluate(args, fastspeech2_config, pwg_config):
     print("vocab_size:", vocab_size)
     with open(args.speaker_dict, 'rt') as f:
         spk_id = [line.strip().split() for line in f.readlines()]
-    num_speakers = len(spk_id)
-    print("num_speakers:", num_speakers)
+    spk_num = len(spk_id)
+    print("spk_num:", spk_num)
     odim = fastspeech2_config.n_mels
     model = FastSpeech2(
         idim=vocab_size,
         odim=odim,
-        num_speakers=num_speakers,
+        spk_num=spk_num,
         **fastspeech2_config["model"])
     model.set_state_dict(

@@ -51,14 +51,14 @@ def evaluate(args, fastspeech2_config, pwg_config):
     print("vocab_size:", vocab_size)
     with open(args.speaker_dict, 'rt') as f:
         spk_id = [line.strip().split() for line in f.readlines()]
-    num_speakers = len(spk_id)
-    print("num_speakers:", num_speakers)
+    spk_num = len(spk_id)
+    print("spk_num:", spk_num)
     odim = fastspeech2_config.n_mels
     model = FastSpeech2(
         idim=vocab_size,
         odim=odim,
-        num_speakers=num_speakers,
+        spk_num=spk_num,
         **fastspeech2_config["model"])
     model.set_state_dict(

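Both synthesis hunks above compute the renamed spk_num by counting lines in the speaker-ID map produced at preprocessing time. A minimal standalone sketch of that parsing (the file name and its contents below are illustrative assumptions, not part of this change):

# speaker_id_map.txt is assumed to hold one "<speaker> <id>" pair per line, e.g.:
#   SSB0005 0
#   SSB0009 1
with open("speaker_id_map.txt", 'rt') as f:
    spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)  # one entry per speaker, so the line count is the speaker count
print("spk_num:", spk_num)
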
@@ -40,19 +40,19 @@ def evaluate(args, fastspeech2_config, pwg_config):
     fields = ["utt_id", "text"]
-    num_speakers = None
+    spk_num = None
     if args.speaker_dict is not None:
         print("multiple speaker fastspeech2!")
         with open(args.speaker_dict, 'rt') as f:
             spk_id = [line.strip().split() for line in f.readlines()]
-        num_speakers = len(spk_id)
+        spk_num = len(spk_id)
         fields += ["spk_id"]
     elif args.voice_cloning:
         print("voice cloning!")
         fields += ["spk_emb"]
     else:
         print("single speaker fastspeech2!")
-    print("num_speakers:", num_speakers)
+    print("spk_num:", spk_num)
     test_dataset = DataTable(data=test_metadata, fields=fields)
@@ -65,7 +65,7 @@ def evaluate(args, fastspeech2_config, pwg_config):
     model = FastSpeech2(
         idim=vocab_size,
         odim=odim,
-        num_speakers=num_speakers,
+        spk_num=spk_num,
         **fastspeech2_config["model"])
     model.set_state_dict(

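In all of these call sites the renamed keyword is threaded straight into the constructor: spk_num stays None for single-speaker and voice-cloning runs and becomes an integer only when a speaker dict is given. A sketch of the resulting call, assuming vocab_size, odim, and the config dict from the surrounding script:

model = FastSpeech2(
    idim=vocab_size,  # phone vocabulary size
    odim=odim,        # number of mel bins
    spk_num=spk_num,  # None -> single speaker; int -> enables speaker embedding
    **fastspeech2_config["model"])
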
@@ -62,13 +62,13 @@ def train_sp(args, config):
         "pitch", "energy"
     ]
     converters = {"speech": np.load, "pitch": np.load, "energy": np.load}
-    num_speakers = None
+    spk_num = None
     if args.speaker_dict is not None:
         print("multiple speaker fastspeech2!")
         collate_fn = fastspeech2_multi_spk_batch_fn
         with open(args.speaker_dict, 'rt') as f:
             spk_id = [line.strip().split() for line in f.readlines()]
-        num_speakers = len(spk_id)
+        spk_num = len(spk_id)
         fields += ["spk_id"]
     elif args.voice_cloning:
         print("Training voice cloning!")
@@ -78,7 +78,7 @@ def train_sp(args, config):
     else:
         print("single speaker fastspeech2!")
         collate_fn = fastspeech2_single_spk_batch_fn
-    print("num_speakers:", num_speakers)
+    print("spk_num:", spk_num)
     # dataloader has been too verbose
     logging.getLogger("DataLoader").disabled = True
@@ -129,10 +129,7 @@ def train_sp(args, config):
     odim = config.n_mels
     model = FastSpeech2(
-        idim=vocab_size,
-        odim=odim,
-        num_speakers=num_speakers,
-        **config["model"])
+        idim=vocab_size, odim=odim, spk_num=spk_num, **config["model"])
     if world_size > 1:
         model = DataParallel(model)
     print("model done!")

@@ -96,7 +96,7 @@ class FastSpeech2(nn.Layer):
            pitch_embed_dropout: float=0.5,
            stop_gradient_from_pitch_predictor: bool=False,
            # spk emb
-           num_speakers: int=None,
+           spk_num: int=None,
            spk_embed_dim: int=None,
            spk_embed_integration_type: str="add",
            # tone emb
@@ -146,9 +146,9 @@
         # initialize parameters
         initialize(self, init_type)
-        if self.spk_embed_dim and num_speakers:
+        if spk_num and self.spk_embed_dim:
             self.spk_embedding_table = nn.Embedding(
-                num_embeddings=num_speakers,
+                num_embeddings=spk_num,
                 embedding_dim=self.spk_embed_dim,
                 padding_idx=self.padding_idx)
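
When both spk_num and spk_embed_dim are set, the constructor builds a plain lookup table from speaker IDs to embedding vectors. A toy illustration of that table in isolation (the sizes and the padding index are made up for the example; the model itself passes self.padding_idx):

import paddle
from paddle import nn

spk_num, spk_embed_dim = 10, 256  # assumed sizes
spk_embedding_table = nn.Embedding(
    num_embeddings=spk_num,
    embedding_dim=spk_embed_dim,
    padding_idx=0)

spk_id = paddle.to_tensor([3])         # one utterance from speaker 3
spk_emb = spk_embedding_table(spk_id)  # shape: [1, spk_embed_dim]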
