Align transformer-based speech synthesis with the other models

pull/2449/head
吕志轩 3 years ago
parent 4ea647c50d
commit 6bba49df7c
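This commit registers TransformerTTS alongside the existing acoustic models: it adds "transformerTTS" / "transformerTTS_inference" entries to the model_alias table, a transformerTTS branch in the synthesis scripts, and an expanded preprocessing script; at the same time the diff rolls back some newer helpers that were in the base files (run_frontend, the 'mix' language handling, the ernie_sat aliases). The file headers were not captured here, but the hunks below appear to belong to the shared synthesis utilities (syn_utils.py), synthesize.py, synthesize_e2e.py, and a TransformerTTS preprocess script. As a reading aid, the alias values follow a "module:ClassName" convention; the resolver sketched below is illustrative only and is not PaddleSpeech's actual dynamic_import implementation.

    # Illustrative resolver for the "module:ClassName" alias strings used in
    # model_alias; PaddleSpeech's real dynamic_import may differ in detail.
    import importlib

    def resolve_alias(alias: str):
        module_name, class_name = alias.split(":")
        return getattr(importlib.import_module(module_name), class_name)

    # e.g. resolve_alias("paddlespeech.t2s.models.transformer_tts:TransformerTTS")
    # returns the TransformerTTS class once paddlespeech is importable.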

@@ -27,14 +27,11 @@ from paddle import jit
 from paddle.static import InputSpec
 from yacs.config import CfgNode
+from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.t2s.datasets.data_table import DataTable
-from paddlespeech.t2s.frontend import English
+from paddlespeech.t2s.frontend.phonectic import English
-from paddlespeech.t2s.frontend.mix_frontend import MixFrontend
 from paddlespeech.t2s.frontend.zh_frontend import Frontend
 from paddlespeech.t2s.modules.normalizer import ZScore
-from paddlespeech.utils.dynamic_import import dynamic_import
-# remove [W:onnxruntime: xxx] from ort
-ort.set_default_logger_severity(3)
 model_alias = {
     # acoustic model
@@ -50,6 +47,10 @@ model_alias = {
     "paddlespeech.t2s.models.tacotron2:Tacotron2",
     "tacotron2_inference":
     "paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
+    "transformerTTS":
+    "paddlespeech.t2s.models.transformer_tts:TransformerTTS",
+    "transformerTTS_inference":
+    "paddlespeech.t2s.models.transformer_tts:TransformerTTSInference",
     # voc
     "pwgan":
     "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
@@ -71,10 +72,6 @@ model_alias = {
     "paddlespeech.t2s.models.wavernn:WaveRNN",
     "wavernn_inference":
     "paddlespeech.t2s.models.wavernn:WaveRNNInference",
-    "erniesat":
-    "paddlespeech.t2s.models.ernie_sat:ErnieSAT",
-    "erniesat_inference":
-    "paddlespeech.t2s.models.ernie_sat:ErnieSATInference",
 }
@@ -82,17 +79,13 @@ def denorm(data, mean, std):
     return data * std + mean
-def norm(data, mean, std):
-    return (data - mean) / std
-def get_chunks(data, block_size: int, pad_size: int):
+def get_chunks(data, chunk_size: int, pad_size: int):
     data_len = data.shape[1]
     chunks = []
-    n = math.ceil(data_len / block_size)
+    n = math.ceil(data_len / chunk_size)
     for i in range(n):
-        start = max(0, i * block_size - pad_size)
-        end = min((i + 1) * block_size + pad_size, data_len)
+        start = max(0, i * chunk_size - pad_size)
+        end = min((i + 1) * chunk_size + pad_size, data_len)
         chunks.append(data[:, start:end, :])
     return chunks
@@ -109,8 +102,6 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
             sentence = "".join(items[1:])
         elif lang == 'en':
            sentence = " ".join(items[1:])
-        elif lang == 'mix':
-            sentence = " ".join(items[1:])
         sentences.append((utt_id, sentence))
     return sentences
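A quick sanity check of the get_chunks logic a few lines above (illustrative; it assumes the get_chunks defined in this patch is in scope and uses a made-up array shape):

    import math  # used by get_chunks itself
    import numpy as np

    data = np.zeros((1, 10, 80))  # (batch, frames, mels), made-up shape
    chunks = get_chunks(data, chunk_size=4, pad_size=1)
    # 10 frames with chunk_size=4 and pad_size=1 give slices [0:5], [3:9], [7:10]
    print([c.shape[1] for c in chunks])  # -> [5, 6, 3]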
@@ -122,11 +113,9 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
     # model: {model_name}_{dataset}
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
-    converters = {}
     if am_name == 'fastspeech2':
         fields = ["utt_id", "text"]
-        if am_dataset in {"aishell3", "vctk",
-                          "mix"} and speaker_dict is not None:
+        if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
             print("multiple speaker fastspeech2!")
             fields += ["spk_id"]
         elif voice_cloning:
@@ -141,17 +130,8 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
         if voice_cloning:
             print("voice cloning!")
             fields += ["spk_emb"]
-    elif am_name == 'erniesat':
-        fields = [
-            "utt_id", "text", "text_lengths", "speech", "speech_lengths",
-            "align_start", "align_end"
-        ]
-        converters = {"speech": np.load}
-    else:
-        print("wrong am, please input right am!!!")
-    test_dataset = DataTable(
-        data=test_metadata, fields=fields, converters=converters)
+    test_dataset = DataTable(data=test_metadata, fields=fields)
     return test_dataset
@@ -164,73 +144,48 @@ def get_frontend(lang: str='zh',
             phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
     elif lang == 'en':
         frontend = English(phone_vocab_path=phones_dict)
-    elif lang == 'mix':
-        frontend = MixFrontend(
-            phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
     else:
         print("wrong lang!")
+    print("frontend done!")
     return frontend
-def run_frontend(frontend: object,
-                 text: str,
-                 merge_sentences: bool=False,
-                 get_tone_ids: bool=False,
-                 lang: str='zh',
-                 to_tensor: bool=True):
-    outs = dict()
-    if lang == 'zh':
-        input_ids = frontend.get_input_ids(
-            text,
-            merge_sentences=merge_sentences,
-            get_tone_ids=get_tone_ids,
-            to_tensor=to_tensor)
-        phone_ids = input_ids["phone_ids"]
-        if get_tone_ids:
-            tone_ids = input_ids["tone_ids"]
-            outs.update({'tone_ids': tone_ids})
-    elif lang == 'en':
-        input_ids = frontend.get_input_ids(
-            text, merge_sentences=merge_sentences, to_tensor=to_tensor)
-        phone_ids = input_ids["phone_ids"]
-    elif lang == 'mix':
-        input_ids = frontend.get_input_ids(
-            text, merge_sentences=merge_sentences, to_tensor=to_tensor)
-        phone_ids = input_ids["phone_ids"]
-    else:
-        print("lang should in {'zh', 'en', 'mix'}!")
-    outs.update({'phone_ids': phone_ids})
-    return outs
 # dygraph
-def get_am_inference(am: str='fastspeech2_csmsc',
-                     am_config: CfgNode=None,
-                     am_ckpt: Optional[os.PathLike]=None,
-                     am_stat: Optional[os.PathLike]=None,
-                     phones_dict: Optional[os.PathLike]=None,
-                     tones_dict: Optional[os.PathLike]=None,
-                     speaker_dict: Optional[os.PathLike]=None,
-                     return_am: bool=False):
+def get_am_inference(
+        am: str='fastspeech2_csmsc',
+        am_config: CfgNode=None,
+        am_ckpt: Optional[os.PathLike]=None,
+        am_stat: Optional[os.PathLike]=None,
+        phones_dict: Optional[os.PathLike]=None,
+        tones_dict: Optional[os.PathLike]=None,
+        speaker_dict: Optional[os.PathLike]=None, ):
     with open(phones_dict, "r") as f:
         phn_id = [line.strip().split() for line in f.readlines()]
     vocab_size = len(phn_id)
+    print("vocab_size:", vocab_size)
     tone_size = None
     if tones_dict is not None:
         with open(tones_dict, "r") as f:
             tone_id = [line.strip().split() for line in f.readlines()]
         tone_size = len(tone_id)
+        print("tone_size:", tone_size)
     spk_num = None
     if speaker_dict is not None:
         with open(speaker_dict, 'rt') as f:
             spk_id = [line.strip().split() for line in f.readlines()]
         spk_num = len(spk_id)
+        print("spk_num:", spk_num)
     odim = am_config.n_mels
     # model: {model_name}_{dataset}
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
     am_class = dynamic_import(am_name, model_alias)
     am_inference_class = dynamic_import(am_name + '_inference', model_alias)
     if am_name == 'fastspeech2':
         am = am_class(
             idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
@@ -242,11 +197,8 @@ def get_am_inference(am: str='fastspeech2_csmsc',
             **am_config["model"])
     elif am_name == 'tacotron2':
         am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
-    elif am_name == 'erniesat':
+    elif am_name == 'transformerTTS':
         am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
-    else:
-        print("wrong am, please input right am!!!")
     am.set_state_dict(paddle.load(am_ckpt)["main_params"])
     am.eval()
     am_mu, am_std = np.load(am_stat)
@@ -255,10 +207,8 @@ def get_am_inference(am: str='fastspeech2_csmsc',
     am_normalizer = ZScore(am_mu, am_std)
     am_inference = am_inference_class(am_normalizer, am)
     am_inference.eval()
-    if return_am:
-        return am_inference, am
-    else:
-        return am_inference
+    print("acoustic model done!")
+    return am_inference
 def get_voc_inference(
@@ -286,6 +236,7 @@ def get_voc_inference(
     voc_normalizer = ZScore(voc_mu, voc_std)
     voc_inference = voc_inference_class(voc_normalizer, voc)
     voc_inference.eval()
+    print("voc done!")
     return voc_inference
@@ -298,8 +249,7 @@ def am_to_static(am_inference,
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
     if am_name == 'fastspeech2':
-        if am_dataset in {"aishell3", "vctk",
-                          "mix"} and speaker_dict is not None:
+        if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
             am_inference = jit.to_static(
                 am_inference,
                 input_spec=[
@@ -311,8 +261,7 @@ def am_to_static(am_inference,
                 am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
     elif am_name == 'speedyspeech':
-        if am_dataset in {"aishell3", "vctk",
-                          "mix"} and speaker_dict is not None:
+        if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
             am_inference = jit.to_static(
                 am_inference,
                 input_spec=[
@@ -369,9 +318,9 @@ def get_predictor(model_dir: Optional[os.PathLike]=None,
 def get_am_output(
         input: str,
-        am_predictor: paddle.nn.Layer,
-        am: str,
-        frontend: object,
+        am_predictor,
+        am,
+        frontend,
         lang: str='zh',
         merge_sentences: bool=True,
         speaker_dict: Optional[os.PathLike]=None,
@@ -379,23 +328,26 @@ def get_am_output(
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
     am_input_names = am_predictor.get_input_names()
-    get_spk_id = False
     get_tone_ids = False
+    get_spk_id = False
     if am_name == 'speedyspeech':
         get_tone_ids = True
-    if am_dataset in {"aishell3", "vctk", "mix"} and speaker_dict:
+    if am_dataset in {"aishell3", "vctk"} and speaker_dict:
         get_spk_id = True
         spk_id = np.array([spk_id])
-    frontend_dict = run_frontend(
-        frontend=frontend,
-        text=input,
-        merge_sentences=merge_sentences,
-        get_tone_ids=get_tone_ids,
-        lang=lang)
+    if lang == 'zh':
+        input_ids = frontend.get_input_ids(
+            input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
+        phone_ids = input_ids["phone_ids"]
+    elif lang == 'en':
+        input_ids = frontend.get_input_ids(
+            input, merge_sentences=merge_sentences)
+        phone_ids = input_ids["phone_ids"]
+    else:
+        print("lang should in {'zh', 'en'}!")
     if get_tone_ids:
-        tone_ids = frontend_dict['tone_ids']
+        tone_ids = input_ids["tone_ids"]
         tones = tone_ids[0].numpy()
         tones_handle = am_predictor.get_input_handle(am_input_names[1])
         tones_handle.reshape(tones.shape)
@@ -404,7 +356,6 @@ def get_am_output(
         spk_id_handle = am_predictor.get_input_handle(am_input_names[1])
         spk_id_handle.reshape(spk_id.shape)
         spk_id_handle.copy_from_cpu(spk_id)
-    phone_ids = frontend_dict['phone_ids']
     phones = phone_ids[0].numpy()
     phones_handle = am_predictor.get_input_handle(am_input_names[0])
     phones_handle.reshape(phones.shape)
@@ -453,13 +404,13 @@ def get_streaming_am_output(input: str,
                             lang: str='zh',
                             merge_sentences: bool=True):
     get_tone_ids = False
-    frontend_dict = run_frontend(
-        frontend=frontend,
-        text=input,
-        merge_sentences=merge_sentences,
-        get_tone_ids=get_tone_ids,
-        lang=lang)
-    phone_ids = frontend_dict['phone_ids']
+    if lang == 'zh':
+        input_ids = frontend.get_input_ids(
+            input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
+        phone_ids = input_ids["phone_ids"]
+    else:
+        print("lang should be 'zh' here!")
     phones = phone_ids[0].numpy()
     am_encoder_infer_output = get_am_sublayer_output(
         am_encoder_infer_predictor, input=phones)
@@ -476,25 +427,26 @@ def get_streaming_am_output(input: str,
 # onnx
-def get_sess(model_path: Optional[os.PathLike],
+def get_sess(model_dir: Optional[os.PathLike]=None,
+             model_file: Optional[os.PathLike]=None,
              device: str='cpu',
              cpu_threads: int=1,
              use_trt: bool=False):
+    model_dir = str(Path(model_dir) / model_file)
     sess_options = ort.SessionOptions()
     sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
     sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
-    if 'gpu' in device.lower():
-        device_id = int(device.split(':')[1]) if len(
-            device.split(':')) == 2 else 0
+    if device == "gpu":
         # fastspeech2/mb_melgan can't use trt now!
         if use_trt:
-            provider_name = 'TensorrtExecutionProvider'
+            providers = ['TensorrtExecutionProvider']
         else:
-            provider_name = 'CUDAExecutionProvider'
-        providers = [(provider_name, {'device_id': device_id})]
-    elif device.lower() == 'cpu':
+            providers = ['CUDAExecutionProvider']
+    elif device == "cpu":
         providers = ['CPUExecutionProvider']
         sess_options.intra_op_num_threads = cpu_threads
     sess = ort.InferenceSession(
-        model_path, providers=providers, sess_options=sess_options)
+        model_dir, providers=providers, sess_options=sess_options)
     return sess
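The reworked get_sess() above joins model_dir and model_file itself before creating the onnxruntime session. A hypothetical call site (the directory and file names are placeholders, not paths from this repo):

    sess = get_sess(
        model_dir="exp/default/inference_onnx",  # placeholder directory
        model_file="fastspeech2_csmsc.onnx",     # joined onto model_dir inside get_sess
        device="cpu",
        cpu_threads=2)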

@@ -107,6 +107,13 @@ def evaluate(args):
                     if args.voice_cloning and "spk_emb" in datum:
                         spk_emb = paddle.to_tensor(np.load(datum["spk_emb"]))
                     mel = am_inference(phone_ids, spk_emb=spk_emb)
+                elif am_name == 'transformerTTS':
+                    phone_ids = paddle.to_tensor(datum["text"])
+                    spk_emb = None
+                    # multi speaker
+                    if args.voice_cloning and "spk_emb" in datum:
+                        spk_emb = paddle.to_tensor(np.load(datum["spk_emb"]))
+                    mel = am_inference(phone_ids, spk_emb=spk_emb)
                 # vocoder
                 wav = voc_inference(mel)
@@ -136,7 +143,7 @@ def parse_args():
         choices=[
             'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
             'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc',
-            'tacotron2_ljspeech', 'tacotron2_aishell3', 'fastspeech2_mix'
+            'tacotron2_ljspeech', 'tacotron2_aishell3', 'transformerTTS_csmsc'
         ],
         help='Choose acoustic model type of tts task.')
     parser.add_argument(
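The new choice string has to split cleanly on its last underscore, because both this script and the syn_utils helpers derive the alias key and dataset name that way. A small check (plain Python, nothing else needed):

    am = "transformerTTS_csmsc"
    am_name = am[:am.rindex('_')]         # "transformerTTS" -> model_alias key
    am_dataset = am[am.rindex('_') + 1:]  # "csmsc"
    assert (am_name, am_dataset) == ("transformerTTS", "csmsc")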

@@ -25,7 +25,6 @@ from paddlespeech.t2s.exps.syn_utils import get_am_inference
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.exps.syn_utils import get_voc_inference
-from paddlespeech.t2s.exps.syn_utils import run_frontend
 from paddlespeech.t2s.exps.syn_utils import voc_to_static
@@ -50,7 +49,6 @@ def evaluate(args):
         lang=args.lang,
         phones_dict=args.phones_dict,
         tones_dict=args.tones_dict)
-    print("frontend done!")
     # acoustic model
     am_name = args.am[:args.am.rindex('_')]
@@ -64,14 +62,13 @@ def evaluate(args):
         phones_dict=args.phones_dict,
         tones_dict=args.tones_dict,
         speaker_dict=args.speaker_dict)
-    print("acoustic model done!")
     # vocoder
     voc_inference = get_voc_inference(
         voc=args.voc,
         voc_config=voc_config,
         voc_ckpt=args.voc_ckpt,
         voc_stat=args.voc_stat)
-    print("voc done!")
     # whether dygraph to static
     if args.inference_dir:
@@ -81,6 +78,7 @@ def evaluate(args):
             am=args.am,
             inference_dir=args.inference_dir,
             speaker_dict=args.speaker_dict)
         # vocoder
         voc_inference = voc_to_static(
             voc_inference=voc_inference,
@@ -103,13 +101,24 @@ def evaluate(args):
     T = 0
     for utt_id, sentence in sentences:
         with timer() as t:
-            frontend_dict = run_frontend(
-                frontend=frontend,
-                text=sentence,
-                merge_sentences=merge_sentences,
-                get_tone_ids=get_tone_ids,
-                lang=args.lang)
-            phone_ids = frontend_dict['phone_ids']
+            if args.lang == 'zh':
+                input_ids = frontend.get_input_ids(
+                    sentence,
+                    merge_sentences=merge_sentences,
+                    get_tone_ids=get_tone_ids)
+                phone_ids = input_ids["phone_ids"]
+                if get_tone_ids:
+                    tone_ids = input_ids["tone_ids"]
+            elif args.lang == 'en':
+                input_ids = frontend.get_input_ids(
+                    sentence, merge_sentences=merge_sentences)
+                phone_ids = input_ids["phone_ids"]
+            elif args.lang == 'mix':
+                input_ids = frontend.get_input_ids(
+                    sentence, merge_sentences=merge_sentences)
+                phone_ids = input_ids["phone_ids"]
+            else:
+                print("lang should in {'zh', 'en', 'mix'}!")
             with paddle.no_grad():
                 flags = 0
                 for i in range(len(phone_ids)):
@@ -123,8 +132,8 @@
                     else:
                         mel = am_inference(part_phone_ids)
                 elif am_name == 'speedyspeech':
-                    part_tone_ids = frontend_dict['tone_ids'][i]
-                    if am_dataset in {"aishell3", "vctk", "mix"}:
+                    part_tone_ids = tone_ids[i]
+                    if am_dataset in {"aishell3", "vctk"}:
                         spk_id = paddle.to_tensor(args.spk_id)
                         mel = am_inference(part_phone_ids, part_tone_ids,
                                            spk_id)
@@ -132,6 +141,8 @@
                         mel = am_inference(part_phone_ids, part_tone_ids)
                 elif am_name == 'tacotron2':
                     mel = am_inference(part_phone_ids)
+                elif am_name == 'transformerTTS':
+                    mel = am_inference(part_phone_ids)
                 # vocoder
                 wav = voc_inference(mel)
                 if flags == 0:
@@ -165,7 +176,8 @@ def parse_args():
         choices=[
             'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc',
             'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk',
-            'tacotron2_csmsc', 'tacotron2_ljspeech', 'fastspeech2_mix'
+            'tacotron2_csmsc', 'tacotron2_ljspeech', 'fastspeech2_mix',
+            'transformerTTS_csmsc'
         ],
         help='Choose acoustic model type of tts task.')
     parser.add_argument(
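For reference, the language branches above all consume the same frontend output shape: get_input_ids() returns a dict whose "phone_ids" entry (and "tone_ids", when requested) is a list with one element per sentence, which is why the loop later indexes phone_ids[i]. A sketch, assuming a zh Frontend built from the same phone/tone dictionaries used above (the file names here are placeholders):

    from paddlespeech.t2s.frontend.zh_frontend import Frontend

    frontend = Frontend(phone_vocab_path="phone_id_map.txt",  # placeholder paths
                        tone_vocab_path="tone_id_map.txt")
    input_ids = frontend.get_input_ids("你好。", merge_sentences=False)
    phone_ids = input_ids["phone_ids"]  # list of id tensors, one per sentence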

@@ -13,28 +13,151 @@
 # limitations under the License.
 import argparse
 import os
+from concurrent.futures import ThreadPoolExecutor
+from operator import itemgetter
 from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+import jsonlines
+import librosa
+import numpy as np
+import tqdm
 import yaml
 from yacs.config import CfgNode
 from paddlespeech.t2s.datasets.get_feats import LogMelFBank
+from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
 from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
 from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
 from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
 from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
 from paddlespeech.t2s.utils import str2bool
-#from concurrent.futures import ThreadPoolExecutor
-#from operator import itemgetter
-#from typing import Any
-#from typing import Dict
-#from typing import List
-#import jsonlines
-#import librosa
-#import numpy as np
-#import tqdm
-#from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
+def process_sentence(config: Dict[str, Any],
+                     fp: Path,
+                     sentences: Dict,
+                     output_dir: Path,
+                     mel_extractor=None,
+                     cut_sil: bool=True,
+                     spk_emb_dir: Path=None):
+    utt_id = fp.stem
+    # for vctk
+    if utt_id.endswith("_mic2"):
+        utt_id = utt_id[:-5]
+    record = None
+    if utt_id in sentences:
+        # reading, resampling may occur
+        wav, _ = librosa.load(str(fp), sr=config.fs)
+        if len(wav.shape) != 1:
+            return record
+        max_value = np.abs(wav).max()
+        if max_value > 1.0:
+            wav = wav / max_value
+        assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
+        assert np.abs(wav).max(
+        ) <= 1.0, f"{utt_id} is seems to be different that 16 bit PCM."
+        phones = sentences[utt_id][0]
+        durations = sentences[utt_id][1]
+        speaker = sentences[utt_id][2]
+        d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
+        # little imprecise than use *.TextGrid directly
+        times = librosa.frames_to_time(
+            d_cumsum, sr=config.fs, hop_length=config.n_shift)
+        if cut_sil:
+            start = 0
+            end = d_cumsum[-1]
+            if phones[0] == "sil" and len(durations) > 1:
+                start = times[1]
+                durations = durations[1:]
+                phones = phones[1:]
+            if phones[-1] == 'sil' and len(durations) > 1:
+                end = times[-2]
+                durations = durations[:-1]
+                phones = phones[:-1]
+            sentences[utt_id][0] = phones
+            sentences[utt_id][1] = durations
+            start, end = librosa.time_to_samples([start, end], sr=config.fs)
+            wav = wav[start:end]
+        # extract mel feats
+        logmel = mel_extractor.get_log_mel_fbank(wav)
+        # change duration according to mel_length
+        compare_duration_and_mel_length(sentences, utt_id, logmel)
+        # utt_id may be popped in compare_duration_and_mel_length
+        if utt_id not in sentences:
+            return None
+        phones = sentences[utt_id][0]
+        durations = sentences[utt_id][1]
+        num_frames = logmel.shape[0]
+        assert sum(durations) == num_frames
+        mel_dir = output_dir / "data_speech"
+        mel_dir.mkdir(parents=True, exist_ok=True)
+        mel_path = mel_dir / (utt_id + "_speech.npy")
+        np.save(mel_path, logmel)
+        record = {
+            "utt_id": utt_id,
+            "phones": phones,
+            "text_lengths": len(phones),
+            "speech_lengths": num_frames,
+            "speech": str(mel_path),
+            "speaker": speaker
+        }
+        if spk_emb_dir:
+            if speaker in os.listdir(spk_emb_dir):
+                embed_name = utt_id + ".npy"
+                embed_path = spk_emb_dir / speaker / embed_name
+                if embed_path.is_file():
+                    record["spk_emb"] = str(embed_path)
+                else:
+                    return None
+    return record
+def process_sentences(config,
+                      fps: List[Path],
+                      sentences: Dict,
+                      output_dir: Path,
+                      mel_extractor=None,
+                      nprocs: int=1,
+                      cut_sil: bool=True,
+                      spk_emb_dir: Path=None):
+    if nprocs == 1:
+        results = []
+        for fp in tqdm.tqdm(fps, total=len(fps)):
+            record = process_sentence(
+                config=config,
+                fp=fp,
+                sentences=sentences,
+                output_dir=output_dir,
+                mel_extractor=mel_extractor,
+                cut_sil=cut_sil,
+                spk_emb_dir=spk_emb_dir)
+            if record:
+                results.append(record)
+    else:
+        with ThreadPoolExecutor(nprocs) as pool:
+            futures = []
+            with tqdm.tqdm(total=len(fps)) as progress:
+                for fp in fps:
+                    future = pool.submit(process_sentence, config, fp,
+                                         sentences, output_dir, mel_extractor,
+                                         cut_sil, spk_emb_dir)
+                    future.add_done_callback(lambda p: progress.update())
+                    futures.append(future)
+                results = []
+                for ft in futures:
+                    record = ft.result()
+                    if record:
+                        results.append(record)
+    results.sort(key=itemgetter("utt_id"))
+    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
+        for item in results:
+            writer.write(item)
+    print("Done")
 def main():
@@ -59,7 +182,7 @@ def main():
     parser.add_argument(
         "--dur-file", default=None, type=str, help="path to durations.txt.")
     parser.add_argument(
-        "--config", type=str, help="transformer config file.")
+        "--config", type=str, help="fastspeech2 config file.")
     parser.add_argument(
         "--num-cpu", type=int, default=1, help="number of process.")
