add multi-speaker support for speedyspeech

pull/1259/head
Jerryuhoo 4 years ago
parent f27d9d50e6
commit 11991b6d35

@@ -0,0 +1,52 @@
+###########################################################
+#               FEATURE EXTRACTION SETTING                #
+###########################################################
+fs: 24000          # Sampling rate.
+n_fft: 2048        # FFT size (samples).
+n_shift: 300       # Hop size (samples). 12.5 ms
+win_length: 1200   # Window length (samples). 50 ms
+                   # If set to null, it will be the same as n_fft.
+window: "hann"     # Window function.
+n_mels: 80         # Number of mel basis.
+fmin: 80           # Minimum frequency in mel basis calculation. (Hz)
+fmax: 7600         # Maximum frequency in mel basis calculation. (Hz)
+
+###########################################################
+#                       DATA SETTING                      #
+###########################################################
+batch_size: 32
+num_workers: 4
+
+###########################################################
+#                      MODEL SETTING                      #
+###########################################################
+model:
+  encoder_hidden_size: 128
+  encoder_kernel_size: 3
+  encoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
+  duration_predictor_hidden_size: 128
+  decoder_hidden_size: 128
+  decoder_output_size: 80
+  decoder_kernel_size: 3
+  decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
+  spk_embed_dim: 256
+  spk_embed_integration_type: add   # speaker embedding integration type ("add" or "concat")
+
+###########################################################
+#                    OPTIMIZER SETTING                    #
+###########################################################
+optimizer:
+  optim: adam             # optimizer type
+  learning_rate: 0.002    # learning rate
+  max_grad_norm: 1
+
+###########################################################
+#                     TRAINING SETTING                    #
+###########################################################
+max_epoch: 100
+num_snapshots: 5
+
+###########################################################
+#                      OTHER SETTING                      #
+###########################################################
+seed: 10086
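
For orientation, a minimal sketch of how a config like this is consumed (the file name below is hypothetical; the CfgNode pattern matches the train script later in this commit). The model sub-node is unpacked directly into the SpeedySpeech constructor, so spk_embed_dim and spk_embed_integration_type become keyword arguments.

import yaml
from yacs.config import CfgNode

# hypothetical path; the commit does not name the config file
with open("conf/default_multi_spk.yaml") as f:
    config = CfgNode(yaml.safe_load(f))

print(config.model.spk_embed_dim)               # 256
print(config.model.spk_embed_integration_type)  # "add"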

@@ -17,7 +17,7 @@ import paddle
 from paddlespeech.t2s.data.batch import batch_sequences


-def speedyspeech_batch_fn(examples):
+def speedyspeech_single_spk_batch_fn(examples):
     # fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
     phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
     tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
@@ -54,6 +54,46 @@ def speedyspeech_batch_fn(examples):
     }
     return batch


+def speedyspeech_multi_spk_batch_fn(examples):
+    # fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
+    phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
+    tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
+    feats = [np.array(item["feats"], dtype=np.float32) for item in examples]
+    durations = [
+        np.array(item["durations"], dtype=np.int64) for item in examples
+    ]
+    num_phones = [
+        np.array(item["num_phones"], dtype=np.int64) for item in examples
+    ]
+    num_frames = [
+        np.array(item["num_frames"], dtype=np.int64) for item in examples
+    ]
+
+    phones = batch_sequences(phones)
+    tones = batch_sequences(tones)
+    feats = batch_sequences(feats)
+    durations = batch_sequences(durations)
+
+    # convert each batch to paddle.Tensor
+    phones = paddle.to_tensor(phones)
+    tones = paddle.to_tensor(tones)
+    feats = paddle.to_tensor(feats)
+    durations = paddle.to_tensor(durations)
+    num_phones = paddle.to_tensor(num_phones)
+    num_frames = paddle.to_tensor(num_frames)
+
+    batch = {
+        "phones": phones,
+        "tones": tones,
+        "num_phones": num_phones,
+        "num_frames": num_frames,
+        "feats": feats,
+        "durations": durations,
+    }
+    if "spk_id" in examples[0]:
+        spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples]
+        spk_id = paddle.to_tensor(spk_id)
+        batch["spk_id"] = spk_id
+    return batch


 def fastspeech2_single_spk_batch_fn(examples):
     # fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]

@@ -47,7 +47,8 @@ def main():
         "--phones-dict", type=str, default=None, help="phone vocabulary file.")
     parser.add_argument(
         "--tones-dict", type=str, default=None, help="tone vocabulary file.")
+    parser.add_argument(
+        "--speaker-dict", type=str, default=None, help="speaker id map file.")
     parser.add_argument(
         "--verbose",
         type=int,

@@ -121,6 +122,12 @@ def main():
     for tone, id in tone_id:
         vocab_tones[tone] = int(id)

+    vocab_speaker = {}
+    with open(args.speaker_dict, 'rt') as f:
+        spk_id = [line.strip().split() for line in f.readlines()]
+    for spk, id in spk_id:
+        vocab_speaker[spk] = int(id)
+
     # process each file
     output_metadata = []

@@ -135,11 +142,13 @@ def main():
         np.save(mel_path, mel.astype(np.float32), allow_pickle=False)
         phone_ids = [vocab_phones[p] for p in item['phones']]
         tone_ids = [vocab_tones[p] for p in item['tones']]
+        spk_id = vocab_speaker[item["speaker"]]
         if args.use_relative_path:
             # convert absolute path to relative path:
             mel_path = mel_path.relative_to(dumpdir)

         output_metadata.append({
             'utt_id': utt_id,
+            "spk_id": spk_id,
             'phones': phone_ids,
             'tones': tone_ids,
             'num_phones': item['num_phones'],
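
The parsing above (line.strip().split()) implies a two-column speaker map, one "<speaker> <id>" pair per line; hypothetical contents (actual speaker names depend on the corpus):

spk_a 0
spk_b 1
spk_c 2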

@@ -13,6 +13,7 @@
 # limitations under the License.
 import argparse
 import re
+import os
 from concurrent.futures import ThreadPoolExecutor
 from operator import itemgetter
 from pathlib import Path

@@ -32,7 +33,7 @@ from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_
 from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
 from paddlespeech.t2s.datasets.preprocess_utils import get_phones_tones
 from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
+from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map


 def process_sentence(config: Dict[str, Any],
                      fp: Path,

@@ -101,6 +102,7 @@ def process_sentence(config: Dict[str, Any],
             "utt_id": utt_id,
             "phones": phones,
             "tones": tones,
+            "speaker": speaker,
             "num_phones": len(phones),
             "num_frames": num_frames,
             "durations": durations,

@@ -229,6 +231,8 @@ def main():
     tone_id_map_path = dumpdir / "tone_id_map.txt"
     get_phones_tones(sentences, phone_id_map_path, tone_id_map_path,
                      args.dataset)
+    speaker_id_map_path = dumpdir / "speaker_id_map.txt"
+    get_spk_id_map(speaker_set, speaker_id_map_path)

     if args.dataset == "baker":
         wav_files = sorted(list((rootdir / "Wave").rglob("*.wav")))
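
get_spk_id_map is imported above but defined elsewhere in this PR; a minimal sketch of what it presumably does, writing the two-column format that the normalization script parses (sorted for a stable speaker-to-id assignment):

def get_spk_id_map(speaker_set, output_path):
    # one "<speaker> <id>" pair per line, ids assigned in sorted order
    with open(output_path, 'wt') as f:
        for i, spk in enumerate(sorted(speaker_set)):
            f.write("{} {}\n".format(spk, i))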
@@ -239,6 +243,28 @@ def main():
         dev_wav_files = wav_files[num_train:num_train + num_dev]
         test_wav_files = wav_files[num_train + num_dev:]
+    elif args.dataset == "other":
+        sub_num_dev = 100
+        wav_dir = rootdir / "wav"
+        train_wav_files = []
+        dev_wav_files = []
+        test_wav_files = []
+        for speaker in os.listdir(wav_dir):
+            if os.path.exists(os.path.join(wav_dir, speaker, "split")):
+                wav_files = sorted(
+                    list((wav_dir / speaker / "split").rglob("*.wav")))
+            else:
+                wav_files = sorted(list((wav_dir / speaker).rglob("*.wav")))
+            # hold out sub_num_dev utterances each for dev and test, but
+            # only when the speaker has more than both held-out slices
+            if len(wav_files) > sub_num_dev * 2:
+                train_wav_files += wav_files[:-sub_num_dev * 2]
+                dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
+                test_wav_files += wav_files[-sub_num_dev:]
+            else:
+                train_wav_files += wav_files
+        print("len train_wav_files", len(train_wav_files))
+        print("len dev_wav_files", len(dev_wav_files))
+        print("len test_wav_files", len(test_wav_files))

     train_dump_dir = dumpdir / "train" / "raw"
     train_dump_dir.mkdir(parents=True, exist_ok=True)
     dev_dump_dir = dumpdir / "dev" / "raw"
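
The "other" branch above assumes a speaker-per-directory corpus under rootdir/wav, with an optional per-speaker "split" subfolder; a hypothetical layout:

rootdir/
    wav/
        spk_a/
            split/
                utt_0001.wav
                utt_0002.wav
        spk_b/
            utt_0001.wav
            utt_0002.wav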

@@ -27,7 +27,8 @@ from paddle.io import DataLoader
 from paddle.io import DistributedBatchSampler
 from yacs.config import CfgNode

-from paddlespeech.t2s.datasets.am_batch_fn import speedyspeech_batch_fn
+from paddlespeech.t2s.datasets.am_batch_fn import speedyspeech_single_spk_batch_fn
+from paddlespeech.t2s.datasets.am_batch_fn import speedyspeech_multi_spk_batch_fn
 from paddlespeech.t2s.datasets.data_table import DataTable
 from paddlespeech.t2s.models.speedyspeech import SpeedySpeech
 from paddlespeech.t2s.models.speedyspeech import SpeedySpeechEvaluator

@@ -57,6 +58,21 @@ def train_sp(args, config):
         f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
     )

+    fields = [
+        "phones", "tones", "num_phones", "num_frames", "feats", "durations"
+    ]
+    spk_num = None
+    if args.speaker_dict is not None:
+        print("multiple speaker speedyspeech!")
+        collate_fn = speedyspeech_multi_spk_batch_fn
+        with open(args.speaker_dict, 'rt') as f:
+            spk_id = [line.strip().split() for line in f.readlines()]
+        spk_num = len(spk_id)
+        fields += ["spk_id"]
+    else:
+        print("single speaker speedyspeech!")
+        collate_fn = speedyspeech_single_spk_batch_fn
+    print("spk_num:", spk_num)

     # dataloader has been too verbose
     logging.getLogger("DataLoader").disabled = True

@@ -71,9 +87,7 @@ def train_sp(args, config):
     train_dataset = DataTable(
         data=train_metadata,
-        fields=[
-            "phones", "tones", "num_phones", "num_frames", "feats", "durations"
-        ],
+        fields=fields,
         converters={
             "feats": np.load,
         }, )

@@ -87,9 +101,7 @@ def train_sp(args, config):
     dev_dataset = DataTable(
         data=dev_metadata,
-        fields=[
-            "phones", "tones", "num_phones", "num_frames", "feats", "durations"
-        ],
+        fields=fields,
         converters={
             "feats": np.load,
         }, )

@@ -105,14 +117,14 @@ def train_sp(args, config):
     train_dataloader = DataLoader(
         train_dataset,
         batch_sampler=train_sampler,
-        collate_fn=speedyspeech_batch_fn,
+        collate_fn=collate_fn,
         num_workers=config.num_workers)

     dev_dataloader = DataLoader(
         dev_dataset,
         shuffle=False,
         drop_last=False,
         batch_size=config.batch_size,
-        collate_fn=speedyspeech_batch_fn,
+        collate_fn=collate_fn,
         num_workers=config.num_workers)
     print("dataloaders done!")

     with open(args.phones_dict, "r") as f:

@@ -125,7 +137,7 @@ def train_sp(args, config):
     print("tone_size:", tone_size)

     model = SpeedySpeech(
-        vocab_size=vocab_size, tone_size=tone_size, **config["model"])
+        vocab_size=vocab_size, tone_size=tone_size, spk_num=spk_num, **config["model"])
     if world_size > 1:
         model = DataParallel(model)
     print("model done!")

@@ -184,6 +196,12 @@ def main():
     parser.add_argument(
         "--tones-dict", type=str, default=None, help="tone vocabulary file.")
+    parser.add_argument(
+        "--speaker-dict",
+        type=str,
+        default=None,
+        help="speaker id map file for multiple speaker model.")

     # more options such as max_epoch can also be passed in here
     args, rest = parser.parse_known_args()
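
A hypothetical multi-speaker launch of this script; --speaker-dict is the only flag this commit adds, and the remaining flags are assumed from the surrounding script, so their names and paths may differ:

python3 train.py \
    --config=conf/default_multi_spk.yaml \
    --train-metadata=dump/train/norm/metadata.jsonl \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --output-dir=exp/multi_spk \
    --phones-dict=dump/phone_id_map.txt \
    --tones-dict=dump/tone_id_map.txt \
    --speaker-dict=dump/speaker_id_map.txt

Omitting --speaker-dict falls back to the single-speaker batch function and spk_num=None.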

@@ -14,7 +14,7 @@
 import numpy as np
 import paddle
 from paddle import nn
+import paddle.nn.functional as F

 from paddlespeech.t2s.modules.positional_encoding import sinusoid_position_encoding
@@ -171,7 +171,11 @@ class SpeedySpeech(nn.Layer):
                  decoder_output_size,
                  decoder_kernel_size,
                  decoder_dilations,
-                 tone_size=None, ):
+                 tone_size=None,
+                 spk_num: int=None,
+                 spk_embed_dim: int=None,
+                 spk_embed_integration_type: str="add", ):
         super().__init__()
         encoder = SpeedySpeechEncoder(vocab_size, tone_size,
                                       encoder_hidden_size, encoder_kernel_size,
@@ -183,14 +187,43 @@ class SpeedySpeech(nn.Layer):
         self.encoder = encoder
         self.duration_predictor = duration_predictor
         self.decoder = decoder
+        self.spk_embed_dim = spk_embed_dim
+        # use idx 0 as padding idx
+        self.padding_idx = 0
+        if self.spk_embed_dim is not None:
+            self.spk_embed_integration_type = spk_embed_integration_type
+        if spk_num and self.spk_embed_dim:
+            self.spk_embedding_table = nn.Embedding(
+                num_embeddings=spk_num,
+                embedding_dim=self.spk_embed_dim,
+                padding_idx=self.padding_idx)
+        self.encoder_hidden_size = encoder_hidden_size
+        # define additional projection for speaker embedding
+        if self.spk_embed_dim is not None:
+            if self.spk_embed_integration_type == "add":
+                self.spk_projection = nn.Linear(self.spk_embed_dim,
+                                                self.encoder_hidden_size)
+            else:
+                self.spk_projection = nn.Linear(
+                    self.encoder_hidden_size + self.spk_embed_dim,
+                    self.encoder_hidden_size)

-    def forward(self, text, tones, durations):
+    def forward(self, text, tones, durations, spk_id: paddle.Tensor=None):
         # input of embedding must be int64
         text = paddle.cast(text, 'int64')
         tones = paddle.cast(tones, 'int64')
+        if spk_id is not None:
+            spk_id = paddle.cast(spk_id, 'int64')
         durations = paddle.cast(durations, 'int64')

         encodings = self.encoder(text, tones)
         # (B, T)
+        if self.spk_embed_dim is not None and spk_id is not None:
+            spk_emb = self.spk_embedding_table(spk_id)
+            encodings = self._integrate_with_spk_embed(encodings, spk_emb)
         pred_durations = self.duration_predictor(encodings.detach())

         # expand encodings
@@ -204,7 +237,7 @@ class SpeedySpeech(nn.Layer):
         decoded = self.decoder(encodings)
         return decoded, pred_durations

-    def inference(self, text, tones=None):
+    def inference(self, text, tones=None, spk_id=None):
         # text: [T]
         # tones: [T]
         # input of embedding must be int64
@@ -215,6 +248,11 @@ class SpeedySpeech(nn.Layer):
             tones = tones.unsqueeze(0)

         encodings = self.encoder(text, tones)
+        if self.spk_embed_dim is not None and spk_id is not None:
+            spk_emb = self.spk_embedding_table(spk_id)
+            encodings = self._integrate_with_spk_embed(encodings, spk_emb)

         pred_durations = self.duration_predictor(encodings)  # (1, T)
         durations_to_expand = paddle.round(pred_durations.exp())
         durations_to_expand = durations_to_expand.astype(paddle.int64)
@@ -240,6 +278,34 @@ class SpeedySpeech(nn.Layer):
         decoded = self.decoder(encodings)
         return decoded[0]

+    def _integrate_with_spk_embed(self, hs, spk_emb):
+        """Integrate speaker embedding with hidden states.
+
+        Parameters
+        ----------
+        hs : Tensor
+            Batch of hidden state sequences (B, Tmax, adim).
+        spk_emb : Tensor
+            Batch of speaker embeddings (B, spk_embed_dim).
+
+        Returns
+        -------
+        Tensor
+            Batch of integrated hidden state sequences (B, Tmax, adim).
+        """
+        if self.spk_embed_integration_type == "add":
+            # apply projection and then add to hidden states
+            spk_emb = self.spk_projection(F.normalize(spk_emb))
+            hs = hs + spk_emb.unsqueeze(1)
+        elif self.spk_embed_integration_type == "concat":
+            # concat hidden states with spk embeds and then apply projection
+            spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
+                shape=[-1, hs.shape[1], -1])
+            hs = self.spk_projection(paddle.concat([hs, spk_emb], axis=-1))
+        else:
+            raise NotImplementedError("support only add or concat.")
+        return hs


 class SpeedySpeechInference(nn.Layer):
     def __init__(self, normalizer, speedyspeech_model):
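
A hedged shape walkthrough of the two integration types above, with toy sizes (B=2, Tmax=4, adim=128, spk_embed_dim=256); both paths return to the encoder width, so the rest of the network is unchanged:

import paddle
import paddle.nn.functional as F

B, T, adim, sdim = 2, 4, 128, 256
hs = paddle.randn([B, T, adim])    # encoder outputs
spk_emb = paddle.randn([B, sdim])  # looked-up speaker embeddings

# "add": project sdim -> adim, broadcast over time, add
add_proj = paddle.nn.Linear(sdim, adim)
out_add = hs + add_proj(F.normalize(spk_emb)).unsqueeze(1)

# "concat": tile over time, concat on channels, project back to adim
cat_proj = paddle.nn.Linear(adim + sdim, adim)
tiled = F.normalize(spk_emb).unsqueeze(1).expand([B, T, sdim])
out_cat = cat_proj(paddle.concat([hs, tiled], axis=-1))

assert out_add.shape == [B, T, adim] and out_cat.shape == [B, T, adim]

One caveat worth noting: the embedding table in __init__ uses padding_idx=0 while real speaker ids also start at 0, so speaker 0 keeps a zero embedding that receives no gradient; whether that is intended is not clear from the commit.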
@@ -247,7 +313,7 @@ class SpeedySpeechInference(nn.Layer):
         self.normalizer = normalizer
         self.acoustic_model = speedyspeech_model

-    def forward(self, phones, tones):
-        normalized_mel = self.acoustic_model.inference(phones, tones)
+    def forward(self, phones, tones, spk_id=None):
+        normalized_mel = self.acoustic_model.inference(phones, tones, spk_id)
         logmel = self.normalizer.inverse(normalized_mel)
         return logmel
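
A hedged end-to-end sketch of the extended inference wrapper (checkpoint loading and the normalizer are omitted; the ids are toy values, and real ones come from phone_id_map.txt, tone_id_map.txt and speaker_id_map.txt):

import paddle

phones = paddle.to_tensor([1, 2, 3], dtype='int64')
tones = paddle.to_tensor([0, 1, 0], dtype='int64')
spk_id = paddle.to_tensor([1], dtype='int64')

# model = SpeedySpeechInference(normalizer, speedyspeech_model)
# logmel = model(phones, tones, spk_id=spk_id)  # (T_frames, n_mels)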

@@ -50,10 +50,15 @@ class SpeedySpeechUpdater(StandardUpdater):
         self.msg = "Rank: {}, ".format(dist.get_rank())
         losses_dict = {}

+        # spk_id is not None only for multi-speaker speedyspeech
+        spk_id = batch["spk_id"] if "spk_id" in batch else None
+
         decoded, predicted_durations = self.model(
             text=batch["phones"],
             tones=batch["tones"],
-            durations=batch["durations"])
+            durations=batch["durations"],
+            spk_id=spk_id)

         target_mel = batch["feats"]
         spec_mask = F.sequence_mask(

@@ -112,10 +117,14 @@ class SpeedySpeechEvaluator(StandardEvaluator):
         self.msg = "Evaluate: "
         losses_dict = {}

+        spk_id = batch["spk_id"] if "spk_id" in batch else None
+
         decoded, predicted_durations = self.model(
             text=batch["phones"],
             tones=batch["tones"],
-            durations=batch["durations"])
+            durations=batch["durations"],
+            spk_id=spk_id)

         target_mel = batch["feats"]
         spec_mask = F.sequence_mask(
