diff --git a/examples/csmsc/tts2/local/preprocess.sh b/examples/csmsc/tts2/local/preprocess.sh index f7f5ea74..c44f075d 100755 --- a/examples/csmsc/tts2/local/preprocess.sh +++ b/examples/csmsc/tts2/local/preprocess.sh @@ -45,6 +45,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --stats=dump/train/feats_stats.npy \ --phones-dict=dump/phone_id_map.txt \ --tones-dict=dump/tone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt \ --use-relative-path=True python3 ${BIN_DIR}/normalize.py \ @@ -53,6 +54,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --stats=dump/train/feats_stats.npy \ --phones-dict=dump/phone_id_map.txt \ --tones-dict=dump/tone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt \ --use-relative-path=True python3 ${BIN_DIR}/normalize.py \ @@ -61,6 +63,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --stats=dump/train/feats_stats.npy \ --phones-dict=dump/phone_id_map.txt \ --tones-dict=dump/tone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt \ --use-relative-path=True fi diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py index 9470f923..2d772bf3 100644 --- a/paddlespeech/t2s/datasets/am_batch_fn.py +++ b/paddlespeech/t2s/datasets/am_batch_fn.py @@ -17,7 +17,7 @@ import paddle from paddlespeech.t2s.data.batch import batch_sequences -def speedyspeech_batch_fn(examples): +def speedyspeech_single_spk_batch_fn(examples): # fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"] phones = [np.array(item["phones"], dtype=np.int64) for item in examples] tones = [np.array(item["tones"], dtype=np.int64) for item in examples] @@ -54,6 +54,46 @@ def speedyspeech_batch_fn(examples): } return batch +def speedyspeech_multi_spk_batch_fn(examples): + # fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"] + phones = [np.array(item["phones"], dtype=np.int64) for item in examples] + tones = [np.array(item["tones"], dtype=np.int64) for item in examples] + feats = [np.array(item["feats"], dtype=np.float32) for item in examples] + durations = [ + np.array(item["durations"], dtype=np.int64) for item in examples + ] + num_phones = [ + np.array(item["num_phones"], dtype=np.int64) for item in examples + ] + num_frames = [ + np.array(item["num_frames"], dtype=np.int64) for item in examples + ] + + phones = batch_sequences(phones) + tones = batch_sequences(tones) + feats = batch_sequences(feats) + durations = batch_sequences(durations) + + # convert each batch to paddle.Tensor + phones = paddle.to_tensor(phones) + tones = paddle.to_tensor(tones) + feats = paddle.to_tensor(feats) + durations = paddle.to_tensor(durations) + num_phones = paddle.to_tensor(num_phones) + num_frames = paddle.to_tensor(num_frames) + batch = { + "phones": phones, + "tones": tones, + "num_phones": num_phones, + "num_frames": num_frames, + "feats": feats, + "durations": durations, + } + if "spk_id" in examples[0]: + spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples] + spk_id = paddle.to_tensor(spk_id) + batch["spk_id"] = spk_id + return batch def fastspeech2_single_spk_batch_fn(examples): # fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"] diff --git a/paddlespeech/t2s/exps/speedyspeech/normalize.py b/paddlespeech/t2s/exps/speedyspeech/normalize.py index 91d15c40..a427c469 100644 --- a/paddlespeech/t2s/exps/speedyspeech/normalize.py +++ b/paddlespeech/t2s/exps/speedyspeech/normalize.py @@ -47,7 +47,8 @@ def main(): "--phones-dict", type=str, default=None, help="phone vocabulary file.") parser.add_argument( "--tones-dict", type=str, default=None, help="tone vocabulary file.") - + parser.add_argument( + "--speaker-dict", type=str, default=None, help="speaker id map file.") parser.add_argument( "--verbose", type=int, @@ -121,6 +122,12 @@ def main(): for tone, id in tone_id: vocab_tones[tone] = int(id) + vocab_speaker = {} + with open(args.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + for spk, id in spk_id: + vocab_speaker[spk] = int(id) + # process each file output_metadata = [] @@ -135,11 +142,13 @@ def main(): np.save(mel_path, mel.astype(np.float32), allow_pickle=False) phone_ids = [vocab_phones[p] for p in item['phones']] tone_ids = [vocab_tones[p] for p in item['tones']] + spk_id = vocab_speaker[item["speaker"]] if args.use_relative_path: # convert absolute path to relative path: mel_path = mel_path.relative_to(dumpdir) output_metadata.append({ 'utt_id': utt_id, + "spk_id": spk_id, 'phones': phone_ids, 'tones': tone_ids, 'num_phones': item['num_phones'], diff --git a/paddlespeech/t2s/exps/speedyspeech/preprocess.py b/paddlespeech/t2s/exps/speedyspeech/preprocess.py index aa589d5a..6003d140 100644 --- a/paddlespeech/t2s/exps/speedyspeech/preprocess.py +++ b/paddlespeech/t2s/exps/speedyspeech/preprocess.py @@ -13,6 +13,7 @@ # limitations under the License. import argparse import re +import os from concurrent.futures import ThreadPoolExecutor from operator import itemgetter from pathlib import Path @@ -32,7 +33,7 @@ from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_ from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur from paddlespeech.t2s.datasets.preprocess_utils import get_phones_tones from paddlespeech.t2s.datasets.preprocess_utils import merge_silence - +from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map def process_sentence(config: Dict[str, Any], fp: Path, @@ -101,6 +102,7 @@ def process_sentence(config: Dict[str, Any], "utt_id": utt_id, "phones": phones, "tones": tones, + "speaker": speaker, "num_phones": len(phones), "num_frames": num_frames, "durations": durations, @@ -229,6 +231,8 @@ def main(): tone_id_map_path = dumpdir / "tone_id_map.txt" get_phones_tones(sentences, phone_id_map_path, tone_id_map_path, args.dataset) + speaker_id_map_path = dumpdir / "speaker_id_map.txt" + get_spk_id_map(speaker_set, speaker_id_map_path) if args.dataset == "baker": wav_files = sorted(list((rootdir / "Wave").rglob("*.wav"))) diff --git a/paddlespeech/t2s/exps/speedyspeech/train.py b/paddlespeech/t2s/exps/speedyspeech/train.py index aaa71b64..cf3741a0 100644 --- a/paddlespeech/t2s/exps/speedyspeech/train.py +++ b/paddlespeech/t2s/exps/speedyspeech/train.py @@ -27,7 +27,8 @@ from paddle.io import DataLoader from paddle.io import DistributedBatchSampler from yacs.config import CfgNode -from paddlespeech.t2s.datasets.am_batch_fn import speedyspeech_batch_fn +from paddlespeech.t2s.datasets.am_batch_fn import speedyspeech_single_spk_batch_fn +from paddlespeech.t2s.datasets.am_batch_fn import speedyspeech_multi_spk_batch_fn from paddlespeech.t2s.datasets.data_table import DataTable from paddlespeech.t2s.models.speedyspeech import SpeedySpeech from paddlespeech.t2s.models.speedyspeech import SpeedySpeechEvaluator @@ -57,6 +58,21 @@ def train_sp(args, config): f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}", ) + fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"] + + spk_num = None + if args.speaker_dict is not None: + print("multiple speaker speedyspeech!") + collate_fn = speedyspeech_multi_spk_batch_fn + with open(args.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + spk_num = len(spk_id) + fields += ["spk_id"] + else: + print("single speaker speedyspeech!") + collate_fn = speedyspeech_single_spk_batch_fn + print("spk_num:", spk_num) + # dataloader has been too verbose logging.getLogger("DataLoader").disabled = True @@ -71,9 +87,7 @@ def train_sp(args, config): train_dataset = DataTable( data=train_metadata, - fields=[ - "phones", "tones", "num_phones", "num_frames", "feats", "durations" - ], + fields=fields, converters={ "feats": np.load, }, ) @@ -87,9 +101,7 @@ def train_sp(args, config): dev_dataset = DataTable( data=dev_metadata, - fields=[ - "phones", "tones", "num_phones", "num_frames", "feats", "durations" - ], + fields=fields, converters={ "feats": np.load, }, ) @@ -105,14 +117,14 @@ def train_sp(args, config): train_dataloader = DataLoader( train_dataset, batch_sampler=train_sampler, - collate_fn=speedyspeech_batch_fn, + collate_fn=collate_fn, num_workers=config.num_workers) dev_dataloader = DataLoader( dev_dataset, shuffle=False, drop_last=False, batch_size=config.batch_size, - collate_fn=speedyspeech_batch_fn, + collate_fn=collate_fn, num_workers=config.num_workers) print("dataloaders done!") with open(args.phones_dict, "r") as f: @@ -125,7 +137,7 @@ def train_sp(args, config): print("tone_size:", tone_size) model = SpeedySpeech( - vocab_size=vocab_size, tone_size=tone_size, **config["model"]) + vocab_size=vocab_size, tone_size=tone_size, spk_num=spk_num, **config["model"]) if world_size > 1: model = DataParallel(model) print("model done!") @@ -184,6 +196,12 @@ def main(): parser.add_argument( "--tones-dict", type=str, default=None, help="tone vocabulary file.") + parser.add_argument( + "--speaker-dict", + type=str, + default=None, + help="speaker id map file for multiple speaker model.") + # 这里可以多传入 max_epoch 等 args, rest = parser.parse_known_args() diff --git a/paddlespeech/t2s/models/speedyspeech/speedyspeech.py b/paddlespeech/t2s/models/speedyspeech/speedyspeech.py index ece5c279..ed085dfd 100644 --- a/paddlespeech/t2s/models/speedyspeech/speedyspeech.py +++ b/paddlespeech/t2s/models/speedyspeech/speedyspeech.py @@ -14,7 +14,7 @@ import numpy as np import paddle from paddle import nn - +import paddle.nn.functional as F from paddlespeech.t2s.modules.positional_encoding import sinusoid_position_encoding @@ -96,7 +96,7 @@ class TextEmbedding(nn.Layer): class SpeedySpeechEncoder(nn.Layer): def __init__(self, vocab_size, tone_size, hidden_size, kernel_size, - dilations): + dilations, spk_num=None): super().__init__() self.embedding = TextEmbedding( vocab_size, @@ -104,6 +104,15 @@ class SpeedySpeechEncoder(nn.Layer): tone_size, padding_idx=0, tone_padding_idx=0) + + if spk_num: + self.spk_emb = nn.Embedding( + num_embeddings=spk_num, + embedding_dim=hidden_size, + padding_idx=0) + else: + self.spk_emb = None + self.prenet = nn.Sequential( nn.Linear(hidden_size, hidden_size), nn.ReLU(), ) @@ -118,8 +127,10 @@ class SpeedySpeechEncoder(nn.Layer): nn.BatchNorm1D(hidden_size, data_format="NLC"), nn.Linear(hidden_size, hidden_size), ) - def forward(self, text, tones): + def forward(self, text, tones, spk_id=None): embedding = self.embedding(text, tones) + if self.spk_emb: + embedding += self.spk_emb(spk_id).unsqueeze(1) embedding = self.prenet(embedding) x = self.res_blocks(embedding) x = embedding + self.postnet1(x) @@ -171,11 +182,12 @@ class SpeedySpeech(nn.Layer): decoder_output_size, decoder_kernel_size, decoder_dilations, - tone_size=None, ): + tone_size=None, + spk_num=None): super().__init__() encoder = SpeedySpeechEncoder(vocab_size, tone_size, encoder_hidden_size, encoder_kernel_size, - encoder_dilations) + encoder_dilations, spk_num) duration_predictor = DurationPredictor(duration_predictor_hidden_size) decoder = SpeedySpeechDecoder(decoder_hidden_size, decoder_output_size, decoder_kernel_size, decoder_dilations) @@ -184,13 +196,15 @@ class SpeedySpeech(nn.Layer): self.duration_predictor = duration_predictor self.decoder = decoder - def forward(self, text, tones, durations): + def forward(self, text, tones, durations, spk_id: paddle.Tensor=None): # input of embedding must be int64 text = paddle.cast(text, 'int64') tones = paddle.cast(tones, 'int64') + if spk_id is not None: + spk_id = paddle.cast(spk_id, 'int64') durations = paddle.cast(durations, 'int64') - encodings = self.encoder(text, tones) - # (B, T) + encodings = self.encoder(text, tones, spk_id) + pred_durations = self.duration_predictor(encodings.detach()) # expand encodings @@ -204,7 +218,7 @@ class SpeedySpeech(nn.Layer): decoded = self.decoder(encodings) return decoded, pred_durations - def inference(self, text, tones=None): + def inference(self, text, tones=None, spk_id=None): # text: [T] # tones: [T] # input of embedding must be int64 @@ -214,7 +228,8 @@ class SpeedySpeech(nn.Layer): tones = paddle.cast(tones, 'int64') tones = tones.unsqueeze(0) - encodings = self.encoder(text, tones) + encodings = self.encoder(text, tones, spk_id) + pred_durations = self.duration_predictor(encodings) # (1, T) durations_to_expand = paddle.round(pred_durations.exp()) durations_to_expand = (durations_to_expand).astype(paddle.int64) @@ -240,14 +255,13 @@ class SpeedySpeech(nn.Layer): decoded = self.decoder(encodings) return decoded[0] - class SpeedySpeechInference(nn.Layer): def __init__(self, normalizer, speedyspeech_model): super().__init__() self.normalizer = normalizer self.acoustic_model = speedyspeech_model - def forward(self, phones, tones): - normalized_mel = self.acoustic_model.inference(phones, tones) + def forward(self, phones, tones, spk_id=None): + normalized_mel = self.acoustic_model.inference(phones, tones, spk_id) logmel = self.normalizer.inverse(normalized_mel) return logmel diff --git a/paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py b/paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py index 6f9937a5..6b94ff9b 100644 --- a/paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py +++ b/paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py @@ -50,10 +50,15 @@ class SpeedySpeechUpdater(StandardUpdater): self.msg = "Rank: {}, ".format(dist.get_rank()) losses_dict = {} + # spk_id!=None in multiple spk speedyspeech + spk_id = batch["spk_id"] if "spk_id" in batch else None + decoded, predicted_durations = self.model( text=batch["phones"], tones=batch["tones"], - durations=batch["durations"]) + durations=batch["durations"], + spk_id=spk_id + ) target_mel = batch["feats"] spec_mask = F.sequence_mask( @@ -112,10 +117,14 @@ class SpeedySpeechEvaluator(StandardEvaluator): self.msg = "Evaluate: " losses_dict = {} + spk_id = batch["spk_id"] if "spk_id" in batch else None + decoded, predicted_durations = self.model( text=batch["phones"], tones=batch["tones"], - durations=batch["durations"]) + durations=batch["durations"], + spk_id=spk_id + ) target_mel = batch["feats"] spec_mask = F.sequence_mask(