add multi-speaker support for speedyspeech

pull/1259/head
Jerryuhoo 4 years ago
parent f27d9d50e6
commit 11991b6d35

@ -0,0 +1,52 @@
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs: 24000 # Sampling rate.
n_fft: 2048 # FFT size (samples).
n_shift: 300 # Hop size (samples). 12.5ms
win_length: 1200 # Window length (samples). 50ms
# If set to null, it will be the same as n_fft.
window: "hann" # Window function.
n_mels: 80 # Number of mel basis.
fmin: 80 # Minimum freq in mel basis calculation.
fmax: 7600 # Maximum frequency in mel basis calculation.
###########################################################
# DATA SETTING #
###########################################################
batch_size: 32
num_workers: 4
###########################################################
# MODEL SETTING #
###########################################################
model:
    encoder_hidden_size: 128
    encoder_kernel_size: 3
    encoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
    duration_predictor_hidden_size: 128
    decoder_hidden_size: 128
    decoder_output_size: 80
    decoder_kernel_size: 3
    decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
    spk_embed_dim: 256
    spk_embed_integration_type: add # speaker embedding integration type ("add" or "concat")
###########################################################
# OPTIMIZER SETTING #
###########################################################
optimizer:
    optim: adam # optimizer type
    learning_rate: 0.002 # learning rate
    max_grad_norm: 1
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 100
num_snapshots: 5
###########################################################
# OTHER SETTING #
###########################################################
seed: 10086
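For reference (not part of the commit), a minimal sketch of how this config reaches the model: the training entry point reads the YAML with yacs and splats config["model"] into the SpeedySpeech constructor, so the two new spk_embed_* keys become constructor kwargs. The file path is illustrative.

# Minimal sketch, assuming this file is saved as conf/default.yaml;
# mirrors how the training script reads it with yacs.
import yaml
from yacs.config import CfgNode

with open("conf/default.yaml", "rt") as f:
    config = CfgNode(yaml.safe_load(f))

# The two new keys that enable multi-speaker training:
print(config.model.spk_embed_dim)               # 256
print(config.model.spk_embed_integration_type)  # "add"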

@ -17,7 +17,7 @@ import paddle
from paddlespeech.t2s.data.batch import batch_sequences
-def speedyspeech_batch_fn(examples):
+def speedyspeech_single_spk_batch_fn(examples):
    # fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
    phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
    tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
@ -54,6 +54,46 @@ def speedyspeech_batch_fn(examples):
    }
    return batch
def speedyspeech_multi_spk_batch_fn(examples):
    # fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
    phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
    tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
    feats = [np.array(item["feats"], dtype=np.float32) for item in examples]
    durations = [
        np.array(item["durations"], dtype=np.int64) for item in examples
    ]
    num_phones = [
        np.array(item["num_phones"], dtype=np.int64) for item in examples
    ]
    num_frames = [
        np.array(item["num_frames"], dtype=np.int64) for item in examples
    ]

    phones = batch_sequences(phones)
    tones = batch_sequences(tones)
    feats = batch_sequences(feats)
    durations = batch_sequences(durations)

    # convert each batch to paddle.Tensor
    phones = paddle.to_tensor(phones)
    tones = paddle.to_tensor(tones)
    feats = paddle.to_tensor(feats)
    durations = paddle.to_tensor(durations)
    num_phones = paddle.to_tensor(num_phones)
    num_frames = paddle.to_tensor(num_frames)

    batch = {
        "phones": phones,
        "tones": tones,
        "num_phones": num_phones,
        "num_frames": num_frames,
        "feats": feats,
        "durations": durations,
    }
    if "spk_id" in examples[0]:
        spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples]
        spk_id = paddle.to_tensor(spk_id)
        batch["spk_id"] = spk_id
    return batch
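A toy invocation of the new collate function (not part of the commit; all field values are made up) shows the extra "spk_id" entry in the resulting batch:

# Toy invocation of speedyspeech_multi_spk_batch_fn; all values are made up.
import numpy as np
from paddlespeech.t2s.datasets.am_batch_fn import speedyspeech_multi_spk_batch_fn

examples = [
    {"phones": [1, 2, 3], "tones": [0, 1, 0], "num_phones": 3, "num_frames": 5,
     "feats": np.zeros((5, 80), dtype=np.float32),  # (T, n_mels)
     "durations": [2, 2, 1], "spk_id": 0},
    {"phones": [4, 5], "tones": [1, 1], "num_phones": 2, "num_frames": 4,
     "feats": np.zeros((4, 80), dtype=np.float32),
     "durations": [2, 2], "spk_id": 1},
]
batch = speedyspeech_multi_spk_batch_fn(examples)
print(batch["phones"].shape)  # [2, 3], padded to the longest sequence
print(batch["spk_id"])        # Tensor holding [0, 1]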
def fastspeech2_single_spk_batch_fn(examples):
# fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]

@ -47,7 +47,8 @@ def main():
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--tones-dict", type=str, default=None, help="tone vocabulary file.")
parser.add_argument(
"--speaker-dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
"--verbose",
type=int,
@ -121,6 +122,12 @@ def main():
    for tone, id in tone_id:
        vocab_tones[tone] = int(id)

    vocab_speaker = {}
    with open(args.speaker_dict, 'rt') as f:
        spk_id = [line.strip().split() for line in f.readlines()]
    for spk, id in spk_id:
        vocab_speaker[spk] = int(id)
    # process each file
    output_metadata = []
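The speaker dict is a plain-text file with one "speaker id" pair per line; a small sketch of the parsing done above (speaker names are made up):

# Minimal sketch of the speaker dict format; speaker names are made up.
lines = ["SSB0005 0", "SSB0009 1", "SSB0011 2"]  # one "speaker id" pair per line
vocab_speaker = {spk: int(idx) for spk, idx in (ln.split() for ln in lines)}
print(vocab_speaker)  # {'SSB0005': 0, 'SSB0009': 1, 'SSB0011': 2}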
@ -135,11 +142,13 @@ def main():
        np.save(mel_path, mel.astype(np.float32), allow_pickle=False)
        phone_ids = [vocab_phones[p] for p in item['phones']]
        tone_ids = [vocab_tones[p] for p in item['tones']]
        spk_id = vocab_speaker[item["speaker"]]
        if args.use_relative_path:
            # convert absolute path to relative path:
            mel_path = mel_path.relative_to(dumpdir)
        output_metadata.append({
            'utt_id': utt_id,
            "spk_id": spk_id,
            'phones': phone_ids,
            'tones': tone_ids,
            'num_phones': item['num_phones'],

@ -13,6 +13,7 @@
# limitations under the License.
import argparse
import os
import re
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from pathlib import Path
@ -32,7 +33,7 @@ from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_
from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
from paddlespeech.t2s.datasets.preprocess_utils import get_phones_tones
from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
def process_sentence(config: Dict[str, Any],
                     fp: Path,
@ -101,6 +102,7 @@ def process_sentence(config: Dict[str, Any],
"utt_id": utt_id,
"phones": phones,
"tones": tones,
"speaker": speaker,
"num_phones": len(phones),
"num_frames": num_frames,
"durations": durations,
@ -229,6 +231,8 @@ def main():
    tone_id_map_path = dumpdir / "tone_id_map.txt"
    get_phones_tones(sentences, phone_id_map_path, tone_id_map_path,
                     args.dataset)
    speaker_id_map_path = dumpdir / "speaker_id_map.txt"
    get_spk_id_map(speaker_set, speaker_id_map_path)
if args.dataset == "baker":
wav_files = sorted(list((rootdir / "Wave").rglob("*.wav")))
@ -239,6 +243,28 @@ def main():
        dev_wav_files = wav_files[num_train:num_train + num_dev]
        test_wav_files = wav_files[num_train + num_dev:]
    elif args.dataset == "other":
        sub_num_dev = 100
        wav_dir = rootdir / "wav"
        train_wav_files = []
        dev_wav_files = []
        test_wav_files = []
        for speaker in os.listdir(wav_dir):
            if os.path.exists(os.path.join(wav_dir, speaker, "split")):
                wav_files = sorted(
                    list((wav_dir / speaker / "split").rglob("*.wav")))
            else:
                wav_files = sorted(list((wav_dir / speaker).rglob("*.wav")))
            # only speakers with more than 2 * sub_num_dev utterances
            # contribute to dev/test, so their train split is never empty
            if len(wav_files) > sub_num_dev * 2:
                train_wav_files += wav_files[:-sub_num_dev * 2]
                dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
                test_wav_files += wav_files[-sub_num_dev:]
            else:
                train_wav_files += wav_files
        print("len train_wav_files", len(train_wav_files))
        print("len dev_wav_files", len(dev_wav_files))
        print("len test_wav_files", len(test_wav_files))
    train_dump_dir = dumpdir / "train" / "raw"
    train_dump_dir.mkdir(parents=True, exist_ok=True)
    dev_dump_dir = dumpdir / "dev" / "raw"
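Each speaker with enough utterances keeps its last 2 * sub_num_dev files for dev and test (sub_num_dev each); everything before that goes to train. A toy run of the same slicing (not part of the commit; file names are made up):

# Toy demonstration of the per-speaker split above; file names are made up.
sub_num_dev = 100
wav_files = [f"utt_{i:04d}.wav" for i in range(250)]  # one speaker, 250 utterances

train = wav_files[:-sub_num_dev * 2]            # first 50 files
dev = wav_files[-sub_num_dev * 2:-sub_num_dev]  # next 100
test = wav_files[-sub_num_dev:]                 # last 100
print(len(train), len(dev), len(test))          # 50 100 100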

@ -27,7 +27,8 @@ from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode
-from paddlespeech.t2s.datasets.am_batch_fn import speedyspeech_batch_fn
+from paddlespeech.t2s.datasets.am_batch_fn import speedyspeech_single_spk_batch_fn
+from paddlespeech.t2s.datasets.am_batch_fn import speedyspeech_multi_spk_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.speedyspeech import SpeedySpeech
from paddlespeech.t2s.models.speedyspeech import SpeedySpeechEvaluator
@ -57,6 +58,21 @@ def train_sp(args, config):
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
)
fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
spk_num = None
if args.speaker_dict is not None:
print("multiple speaker speedyspeech!")
collate_fn = speedyspeech_multi_spk_batch_fn
with open(args.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
fields += ["spk_id"]
else:
print("single speaker speedyspeech!")
collate_fn = speedyspeech_single_spk_batch_fn
print("spk_num:", spk_num)
# dataloader has been too verbose
logging.getLogger("DataLoader").disabled = True
@ -71,9 +87,7 @@ def train_sp(args, config):
    train_dataset = DataTable(
        data=train_metadata,
-        fields=[
-            "phones", "tones", "num_phones", "num_frames", "feats", "durations"
-        ],
+        fields=fields,
        converters={
            "feats": np.load,
        }, )
@ -87,9 +101,7 @@ def train_sp(args, config):
    dev_dataset = DataTable(
        data=dev_metadata,
-        fields=[
-            "phones", "tones", "num_phones", "num_frames", "feats", "durations"
-        ],
+        fields=fields,
        converters={
            "feats": np.load,
        }, )
@ -105,14 +117,14 @@ def train_sp(args, config):
    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
-        collate_fn=speedyspeech_batch_fn,
+        collate_fn=collate_fn,
        num_workers=config.num_workers)

    dev_dataloader = DataLoader(
        dev_dataset,
        shuffle=False,
        drop_last=False,
        batch_size=config.batch_size,
-        collate_fn=speedyspeech_batch_fn,
+        collate_fn=collate_fn,
        num_workers=config.num_workers)
    print("dataloaders done!")

    with open(args.phones_dict, "r") as f:
@ -125,7 +137,7 @@ def train_sp(args, config):
print("tone_size:", tone_size)
model = SpeedySpeech(
vocab_size=vocab_size, tone_size=tone_size, **config["model"])
vocab_size=vocab_size, tone_size=tone_size, spk_num=spk_num, **config["model"])
if world_size > 1:
model = DataParallel(model)
print("model done!")
@ -184,6 +196,12 @@ def main():
    parser.add_argument(
        "--tones-dict", type=str, default=None, help="tone vocabulary file.")
    parser.add_argument(
        "--speaker-dict",
        type=str,
        default=None,
        help="speaker id map file for multiple speaker model.")
    # more overrides such as max_epoch can be passed in here
    args, rest = parser.parse_known_args()

@ -14,7 +14,7 @@
import numpy as np
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddlespeech.t2s.modules.positional_encoding import sinusoid_position_encoding
@ -171,7 +171,11 @@ class SpeedySpeech(nn.Layer):
                 decoder_output_size,
                 decoder_kernel_size,
                 decoder_dilations,
-                tone_size=None, ):
+                tone_size=None,
+                spk_num: int=None,
+                spk_embed_dim: int=None,
+                spk_embed_integration_type: str="add", ):
        super().__init__()
        encoder = SpeedySpeechEncoder(vocab_size, tone_size,
                                      encoder_hidden_size, encoder_kernel_size,
@ -183,14 +187,43 @@ class SpeedySpeech(nn.Layer):
        self.encoder = encoder
        self.duration_predictor = duration_predictor
        self.decoder = decoder

-    def forward(self, text, tones, durations):

        self.spk_embed_dim = spk_embed_dim
        # use idx 0 as padding idx
        self.padding_idx = 0
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = spk_embed_integration_type
        if spk_num and self.spk_embed_dim:
            # note: padding_idx=0 pins speaker id 0 to an all-zero,
            # non-trained embedding
            self.spk_embedding_table = nn.Embedding(
                num_embeddings=spk_num,
                embedding_dim=self.spk_embed_dim,
                padding_idx=self.padding_idx)
        self.encoder_hidden_size = encoder_hidden_size
        # define additional projection for speaker embedding
        if self.spk_embed_dim is not None:
            print("spk_embed_integration_type:", spk_embed_integration_type)
            if self.spk_embed_integration_type == "add":
                self.spk_projection = nn.Linear(self.spk_embed_dim,
                                                self.encoder_hidden_size)
            else:
                self.spk_projection = nn.Linear(
                    self.encoder_hidden_size + self.spk_embed_dim,
                    self.encoder_hidden_size)

+    def forward(self, text, tones, durations, spk_id: paddle.Tensor=None):
        # input of embedding must be int64
        text = paddle.cast(text, 'int64')
        tones = paddle.cast(tones, 'int64')
        if spk_id is not None:
            spk_id = paddle.cast(spk_id, 'int64')
        durations = paddle.cast(durations, 'int64')
        encodings = self.encoder(text, tones)
        # (B, T)
        if self.spk_embed_dim is not None and spk_id is not None:
            spk_emb = self.spk_embedding_table(spk_id)
            encodings = self._integrate_with_spk_embed(encodings, spk_emb)
        # detach: the duration loss should not backprop into the encoder
        pred_durations = self.duration_predictor(encodings.detach())

        # expand encodings
@ -204,7 +237,7 @@ class SpeedySpeech(nn.Layer):
        decoded = self.decoder(encodings)
        return decoded, pred_durations
-    def inference(self, text, tones=None):
+    def inference(self, text, tones=None, spk_id=None):
        # text: [T]
        # tones: [T]
        # input of embedding must be int64
@ -215,6 +248,11 @@ class SpeedySpeech(nn.Layer):
            tones = tones.unsqueeze(0)
        encodings = self.encoder(text, tones)

        if self.spk_embed_dim is not None and spk_id is not None:
            spk_emb = self.spk_embedding_table(spk_id)
            encodings = self._integrate_with_spk_embed(encodings, spk_emb)

        pred_durations = self.duration_predictor(encodings)  # (1, T)
        durations_to_expand = paddle.round(pred_durations.exp())
        durations_to_expand = durations_to_expand.astype(paddle.int64)
@ -240,6 +278,34 @@ class SpeedySpeech(nn.Layer):
        decoded = self.decoder(encodings)
        return decoded[0]
    def _integrate_with_spk_embed(self, hs, spk_emb):
        """Integrate speaker embedding with hidden states.

        Parameters
        ----------
        hs : Tensor
            Batch of hidden state sequences (B, Tmax, adim).
        spk_emb : Tensor
            Batch of speaker embeddings (B, spk_embed_dim).

        Returns
        -------
        Tensor
            Batch of integrated hidden state sequences (B, Tmax, adim).
        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spk_emb = self.spk_projection(F.normalize(spk_emb))
            hs = hs + spk_emb.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
                shape=[-1, hs.shape[1], -1])
            hs = self.spk_projection(paddle.concat([hs, spk_emb], axis=-1))
        else:
            raise NotImplementedError(
                "only 'add' and 'concat' integration types are supported.")
        return hs
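A quick shape sketch of the two integration types (not part of the commit), with sizes matching the config defaults above (encoder_hidden_size = 128, spk_embed_dim = 256):

# Shape sketch for the two integration types; sizes mirror the config defaults.
import paddle
import paddle.nn.functional as F

B, Tmax, adim, spk_embed_dim = 2, 7, 128, 256
hs = paddle.randn([B, Tmax, adim])          # stand-in for encoder outputs
spk_emb = paddle.randn([B, spk_embed_dim])  # stand-in for looked-up embeddings

# "add": project the speaker embedding down to adim, then broadcast-add
proj_add = paddle.nn.Linear(spk_embed_dim, adim)
out_add = hs + proj_add(F.normalize(spk_emb)).unsqueeze(1)

# "concat": tile along time, concatenate, then project back to adim
proj_cat = paddle.nn.Linear(adim + spk_embed_dim, adim)
tiled = F.normalize(spk_emb).unsqueeze(1).expand([-1, Tmax, -1])
out_cat = proj_cat(paddle.concat([hs, tiled], axis=-1))

print(out_add.shape, out_cat.shape)  # [2, 7, 128] [2, 7, 128]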
class SpeedySpeechInference(nn.Layer):
    def __init__(self, normalizer, speedyspeech_model):
@ -247,7 +313,7 @@ class SpeedySpeechInference(nn.Layer):
        self.normalizer = normalizer
        self.acoustic_model = speedyspeech_model

-    def forward(self, phones, tones):
-        normalized_mel = self.acoustic_model.inference(phones, tones)
+    def forward(self, phones, tones, spk_id=None):
+        normalized_mel = self.acoustic_model.inference(phones, tones, spk_id)
        logmel = self.normalizer.inverse(normalized_mel)
        return logmel
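End to end, the new argument lets one network synthesize different voices. A hedged usage sketch (not part of the commit): am_inference stands for a SpeedySpeechInference instance restored from a multi-speaker checkpoint, and the phone/tone/speaker ids are made up:

# Usage sketch; assumes `am_inference` is a SpeedySpeechInference instance
# restored from a multi-speaker checkpoint, with valid phone/tone ids.
import paddle

phones = paddle.to_tensor([2, 15, 33, 7], dtype="int64")  # [T], made-up ids
tones = paddle.to_tensor([0, 1, 0, 2], dtype="int64")     # [T]
spk_id = paddle.to_tensor([3], dtype="int64")             # synthesize speaker 3

with paddle.no_grad():
    logmel = am_inference(phones, tones, spk_id=spk_id)   # (num_frames, n_mels)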

@ -50,10 +50,15 @@ class SpeedySpeechUpdater(StandardUpdater):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
# spk_id!=None in multiple spk speedyspeech
spk_id = batch["spk_id"] if "spk_id" in batch else None
decoded, predicted_durations = self.model(
text=batch["phones"],
tones=batch["tones"],
durations=batch["durations"])
durations=batch["durations"],
spk_id=spk_id
)
target_mel = batch["feats"]
spec_mask = F.sequence_mask(
@ -112,10 +117,14 @@ class SpeedySpeechEvaluator(StandardEvaluator):
self.msg = "Evaluate: "
losses_dict = {}
spk_id = batch["spk_id"] if "spk_id" in batch else None
decoded, predicted_durations = self.model(
text=batch["phones"],
tones=batch["tones"],
durations=batch["durations"])
durations=batch["durations"],
spk_id=spk_id
)
target_mel = batch["feats"]
spec_mask = F.sequence_mask(
