Align transformer-based speech synthesis with the other models

pull/2449/head
吕志轩 3 years ago
parent 4ea647c50d
commit 6bba49df7c

@@ -27,14 +27,11 @@ from paddle import jit
from paddle.static import InputSpec
from yacs.config import CfgNode
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.mix_frontend import MixFrontend
from paddlespeech.t2s.frontend.phonectic import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.utils.dynamic_import import dynamic_import
# remove [W:onnxruntime: xxx] from ort
ort.set_default_logger_severity(3)
model_alias = {
# acoustic model
@@ -50,6 +47,10 @@ model_alias = {
"paddlespeech.t2s.models.tacotron2:Tacotron2",
"tacotron2_inference":
"paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
"transformerTTS":
"paddlespeech.t2s.models.transformer_tts:TransformerTTS",
"transformerTTS_inference":
"paddlespeech.t2s.models.transformer_tts:TransformerTTSInference",
# voc
"pwgan":
"paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
@@ -71,10 +72,6 @@ model_alias = {
"paddlespeech.t2s.models.wavernn:WaveRNN",
"wavernn_inference":
"paddlespeech.t2s.models.wavernn:WaveRNNInference",
"erniesat":
"paddlespeech.t2s.models.ernie_sat:ErnieSAT",
"erniesat_inference":
"paddlespeech.t2s.models.ernie_sat:ErnieSATInference",
}
@@ -82,17 +79,13 @@ def denorm(data, mean, std):
return data * std + mean
def norm(data, mean, std):
return (data - mean) / std
def get_chunks(data, block_size: int, pad_size: int):
def get_chunks(data, chunk_size: int, pad_size: int):
data_len = data.shape[1]
chunks = []
n = math.ceil(data_len / block_size)
n = math.ceil(data_len / chunk_size)
for i in range(n):
start = max(0, i * block_size - pad_size)
end = min((i + 1) * block_size + pad_size, data_len)
start = max(0, i * chunk_size - pad_size)
end = min((i + 1) * chunk_size + pad_size, data_len)
chunks.append(data[:, start:end, :])
return chunks
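The chunking above is what the streaming vocoder path consumes chunk by chunk; a toy sketch of the boundaries it produces (shapes and sizes here are made up, not from the patch):

import math
import numpy as np

def get_chunks(data, chunk_size: int, pad_size: int):
    # same logic as above, repeated so the sketch runs standalone
    data_len = data.shape[1]
    chunks = []
    n = math.ceil(data_len / chunk_size)
    for i in range(n):
        start = max(0, i * chunk_size - pad_size)
        end = min((i + 1) * chunk_size + pad_size, data_len)
        chunks.append(data[:, start:end, :])
    return chunks

data = np.zeros((1, 10, 80))          # (batch, frames, mel bins)
print([c.shape[1] for c in get_chunks(data, chunk_size=4, pad_size=1)])
# [5, 6, 3]: interior chunks carry pad_size extra frames on each side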
@@ -109,8 +102,6 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
sentence = "".join(items[1:])
elif lang == 'en':
sentence = " ".join(items[1:])
elif lang == 'mix':
sentence = " ".join(items[1:])
sentences.append((utt_id, sentence))
return sentences
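get_sentences expects one "utt_id sentence" pair per line of the text file; a minimal usage sketch (file name and contents are placeholders):

# sentences.txt:
#   001 今天天气很好。
#   002 欢迎使用语音合成。
sentences = get_sentences("sentences.txt", lang='zh')
print(sentences[0])   # ('001', '今天天气很好。')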
@@ -122,11 +113,9 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
# model: {model_name}_{dataset}
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
converters = {}
if am_name == 'fastspeech2':
fields = ["utt_id", "text"]
if am_dataset in {"aishell3", "vctk",
"mix"} and speaker_dict is not None:
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
print("multiple speaker fastspeech2!")
fields += ["spk_id"]
elif voice_cloning:
@@ -141,17 +130,8 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
if voice_cloning:
print("voice cloning!")
fields += ["spk_emb"]
elif am_name == 'erniesat':
fields = [
"utt_id", "text", "text_lengths", "speech", "speech_lengths",
"align_start", "align_end"
]
converters = {"speech": np.load}
else:
print("wrong am, please input right am!!!")
test_dataset = DataTable(
data=test_metadata, fields=fields, converters=converters)
test_dataset = DataTable(data=test_metadata, fields=fields)
return test_dataset
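DataTable simply exposes the selected fields of each metadata record (applying converters such as np.load where given); a call sketch, with the metadata path assumed:

import jsonlines

with jsonlines.open("dump/test/norm/metadata.jsonl") as reader:   # placeholder path
    test_metadata = list(reader)
test_dataset = get_test_dataset(test_metadata, am='fastspeech2_csmsc')
print(test_dataset[0]["utt_id"])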
@@ -164,73 +144,48 @@ def get_frontend(lang: str='zh',
phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
elif lang == 'en':
frontend = English(phone_vocab_path=phones_dict)
elif lang == 'mix':
frontend = MixFrontend(
phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
else:
print("wrong lang!")
print("frontend done!")
return frontend
def run_frontend(frontend: object,
text: str,
merge_sentences: bool=False,
get_tone_ids: bool=False,
lang: str='zh',
to_tensor: bool=True):
outs = dict()
if lang == 'zh':
input_ids = frontend.get_input_ids(
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids,
to_tensor=to_tensor)
phone_ids = input_ids["phone_ids"]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
outs.update({'tone_ids': tone_ids})
elif lang == 'en':
input_ids = frontend.get_input_ids(
text, merge_sentences=merge_sentences, to_tensor=to_tensor)
phone_ids = input_ids["phone_ids"]
elif lang == 'mix':
input_ids = frontend.get_input_ids(
text, merge_sentences=merge_sentences, to_tensor=to_tensor)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en', 'mix'}!")
outs.update({'phone_ids': phone_ids})
return outs
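run_frontend bundles the per-language get_input_ids calls behind a single entry point; a usage sketch for the Chinese frontend (vocab paths are placeholders):

from paddlespeech.t2s.frontend.zh_frontend import Frontend

frontend = Frontend(phone_vocab_path="phone_id_map.txt",
                    tone_vocab_path="tone_id_map.txt")
outs = run_frontend(frontend=frontend,
                    text="今天天气很好。",
                    merge_sentences=True,
                    get_tone_ids=True,
                    lang='zh')
phone_ids = outs['phone_ids']   # list of int64 tensors, one per (merged) sentence
tone_ids = outs['tone_ids']     # present only because get_tone_ids=True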
# dygraph
def get_am_inference(am: str='fastspeech2_csmsc',
am_config: CfgNode=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None,
return_am: bool=False):
def get_am_inference(
am: str='fastspeech2_csmsc',
am_config: CfgNode=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None, ):
with open(phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
tone_size = None
if tones_dict is not None:
with open(tones_dict, "r") as f:
tone_id = [line.strip().split() for line in f.readlines()]
tone_size = len(tone_id)
print("tone_size:", tone_size)
spk_num = None
if speaker_dict is not None:
with open(speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
print("spk_num:", spk_num)
odim = am_config.n_mels
# model: {model_name}_{dataset}
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
am_class = dynamic_import(am_name, model_alias)
am_inference_class = dynamic_import(am_name + '_inference', model_alias)
if am_name == 'fastspeech2':
am = am_class(
idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
@@ -242,11 +197,8 @@ def get_am_inference(am: str='fastspeech2_csmsc',
**am_config["model"])
elif am_name == 'tacotron2':
am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
elif am_name == 'erniesat':
elif am_name == 'transformerTTS':
am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
else:
print("wrong am, please input right am!!!")
am.set_state_dict(paddle.load(am_ckpt)["main_params"])
am.eval()
am_mu, am_std = np.load(am_stat)
@@ -255,10 +207,8 @@ def get_am_inference(am: str='fastspeech2_csmsc',
am_normalizer = ZScore(am_mu, am_std)
am_inference = am_inference_class(am_normalizer, am)
am_inference.eval()
if return_am:
return am_inference, am
else:
return am_inference
print("acoustic model done!")
return am_inference
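get_am_inference resolves the class pair from model_alias via dynamic_import, restores the checkpoint, and wraps the model with a ZScore normalizer built from the *_stats.npy file; a call sketch using the transformerTTS alias this patch registers (all paths are placeholders):

import yaml
from yacs.config import CfgNode

with open("transformer_tts_csmsc_ckpt/default.yaml") as f:
    am_config = CfgNode(yaml.safe_load(f))
am_inference = get_am_inference(
    am='transformerTTS_csmsc',
    am_config=am_config,
    am_ckpt="transformer_tts_csmsc_ckpt/snapshot_iter_1118.pdz",
    am_stat="transformer_tts_csmsc_ckpt/speech_stats.npy",
    phones_dict="transformer_tts_csmsc_ckpt/phone_id_map.txt")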
def get_voc_inference(
@@ -286,6 +236,7 @@ def get_voc_inference(
voc_normalizer = ZScore(voc_mu, voc_std)
voc_inference = voc_inference_class(voc_normalizer, voc)
voc_inference.eval()
print("voc done!")
return voc_inference
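With both halves loaded, synthesis is two tensor calls; a condensed sketch, assuming the am_inference, voc_inference, and phone_ids objects produced by the helpers above (output path and sample rate are assumptions):

import paddle
import soundfile as sf

with paddle.no_grad():
    mel = am_inference(phone_ids[0])   # (frames, n_mels)
    wav = voc_inference(mel)           # waveform tensor
sf.write("output.wav", wav.numpy(), samplerate=24000)   # 24 kHz assumed for csmsc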
@@ -298,8 +249,7 @@ def am_to_static(am_inference,
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
if am_name == 'fastspeech2':
if am_dataset in {"aishell3", "vctk",
"mix"} and speaker_dict is not None:
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
am_inference = jit.to_static(
am_inference,
input_spec=[
@@ -311,8 +261,7 @@ def am_to_static(am_inference,
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
elif am_name == 'speedyspeech':
if am_dataset in {"aishell3", "vctk",
"mix"} and speaker_dict is not None:
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
am_inference = jit.to_static(
am_inference,
input_spec=[
@@ -369,9 +318,9 @@ def get_predictor(model_dir: Optional[os.PathLike]=None,
def get_am_output(
input: str,
am_predictor: paddle.nn.Layer,
am: str,
frontend: object,
am_predictor,
am,
frontend,
lang: str='zh',
merge_sentences: bool=True,
speaker_dict: Optional[os.PathLike]=None,
@@ -379,23 +328,26 @@ def get_am_output(
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
am_input_names = am_predictor.get_input_names()
get_spk_id = False
get_tone_ids = False
get_spk_id = False
if am_name == 'speedyspeech':
get_tone_ids = True
if am_dataset in {"aishell3", "vctk", "mix"} and speaker_dict:
if am_dataset in {"aishell3", "vctk"} and speaker_dict:
get_spk_id = True
spk_id = np.array([spk_id])
frontend_dict = run_frontend(
frontend=frontend,
text=input,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids,
lang=lang)
if lang == 'zh':
input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
elif lang == 'en':
input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")
if get_tone_ids:
tone_ids = frontend_dict['tone_ids']
tone_ids = input_ids["tone_ids"]
tones = tone_ids[0].numpy()
tones_handle = am_predictor.get_input_handle(am_input_names[1])
tones_handle.reshape(tones.shape)
@@ -404,7 +356,6 @@ def get_am_output(
spk_id_handle = am_predictor.get_input_handle(am_input_names[1])
spk_id_handle.reshape(spk_id.shape)
spk_id_handle.copy_from_cpu(spk_id)
phone_ids = frontend_dict['phone_ids']
phones = phone_ids[0].numpy()
phones_handle = am_predictor.get_input_handle(am_input_names[0])
phones_handle.reshape(phones.shape)
@@ -453,13 +404,13 @@ def get_streaming_am_output(input: str,
lang: str='zh',
merge_sentences: bool=True):
get_tone_ids = False
frontend_dict = run_frontend(
frontend=frontend,
text=input,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids,
lang=lang)
phone_ids = frontend_dict['phone_ids']
if lang == 'zh':
input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
else:
print("lang should be 'zh' here!")
phones = phone_ids[0].numpy()
am_encoder_infer_output = get_am_sublayer_output(
am_encoder_infer_predictor, input=phones)
@@ -476,25 +427,26 @@ def get_streaming_am_output(input: str,
# onnx
def get_sess(model_path: Optional[os.PathLike],
def get_sess(model_dir: Optional[os.PathLike]=None,
model_file: Optional[os.PathLike]=None,
device: str='cpu',
cpu_threads: int=1,
use_trt: bool=False):
model_dir = str(Path(model_dir) / model_file)
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
if 'gpu' in device.lower():
device_id = int(device.split(':')[1]) if len(
device.split(':')) == 2 else 0
if device == "gpu":
# fastspeech2/mb_melgan can't use trt now!
if use_trt:
provider_name = 'TensorrtExecutionProvider'
providers = ['TensorrtExecutionProvider']
else:
provider_name = 'CUDAExecutionProvider'
providers = [(provider_name, {'device_id': device_id})]
elif device.lower() == 'cpu':
providers = ['CUDAExecutionProvider']
elif device == "cpu":
providers = ['CPUExecutionProvider']
sess_options.intra_op_num_threads = cpu_threads
sess = ort.InferenceSession(
model_path, providers=providers, sess_options=sess_options)
model_dir, providers=providers, sess_options=sess_options)
return sess
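get_sess is a thin wrapper over onnxruntime; the equivalent direct session setup and a dummy run look like this (the model path and the input name "text" are assumptions):

import numpy as np
import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.intra_op_num_threads = 2
sess = ort.InferenceSession(
    "inference_onnx/fastspeech2_csmsc.onnx",
    providers=['CPUExecutionProvider'],
    sess_options=sess_options)
phone_ids = np.random.randint(1, 100, size=(32, )).astype(np.int64)
mel = sess.run(None, {"text": phone_ids})[0]
print(mel.shape)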

@@ -107,6 +107,13 @@ def evaluate(args):
if args.voice_cloning and "spk_emb" in datum:
spk_emb = paddle.to_tensor(np.load(datum["spk_emb"]))
mel = am_inference(phone_ids, spk_emb=spk_emb)
elif am_name == 'transformerTTS':
phone_ids = paddle.to_tensor(datum["text"])
spk_emb = None
# multi speaker
if args.voice_cloning and "spk_emb" in datum:
spk_emb = paddle.to_tensor(np.load(datum["spk_emb"]))
mel = am_inference(phone_ids, spk_emb=spk_emb)
# vocoder
wav = voc_inference(mel)
@@ -136,7 +143,7 @@ def parse_args():
choices=[
'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc',
'tacotron2_ljspeech', 'tacotron2_aishell3', 'fastspeech2_mix'
'tacotron2_ljspeech', 'tacotron2_aishell3', 'transformerTTS_csmsc'
],
help='Choose acoustic model type of tts task.')
parser.add_argument(

@@ -25,7 +25,6 @@ from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.exps.syn_utils import run_frontend
from paddlespeech.t2s.exps.syn_utils import voc_to_static
@@ -50,7 +49,6 @@ def evaluate(args):
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
print("frontend done!")
# acoustic model
am_name = args.am[:args.am.rindex('_')]
@@ -64,14 +62,13 @@ def evaluate(args):
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict)
print("acoustic model done!")
# vocoder
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
print("voc done!")
# whether dygraph to static
if args.inference_dir:
@@ -81,6 +78,7 @@ def evaluate(args):
am=args.am,
inference_dir=args.inference_dir,
speaker_dict=args.speaker_dict)
# vocoder
voc_inference = voc_to_static(
voc_inference=voc_inference,
@@ -103,13 +101,24 @@ def evaluate(args):
T = 0
for utt_id, sentence in sentences:
with timer() as t:
frontend_dict = run_frontend(
frontend=frontend,
text=sentence,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids,
lang=args.lang)
phone_ids = frontend_dict['phone_ids']
if args.lang == 'zh':
input_ids = frontend.get_input_ids(
sentence,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
elif args.lang == 'en':
input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
elif args.lang == 'mix':
input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en', 'mix'}!")
with paddle.no_grad():
flags = 0
for i in range(len(phone_ids)):
@@ -123,8 +132,8 @@ def evaluate(args):
else:
mel = am_inference(part_phone_ids)
elif am_name == 'speedyspeech':
part_tone_ids = frontend_dict['tone_ids'][i]
if am_dataset in {"aishell3", "vctk", "mix"}:
part_tone_ids = tone_ids[i]
if am_dataset in {"aishell3", "vctk"}:
spk_id = paddle.to_tensor(args.spk_id)
mel = am_inference(part_phone_ids, part_tone_ids,
spk_id)
@@ -132,6 +141,8 @@ def evaluate(args):
mel = am_inference(part_phone_ids, part_tone_ids)
elif am_name == 'tacotron2':
mel = am_inference(part_phone_ids)
elif am_name == 'transformerTTS':
mel = am_inference(part_phone_ids)
# vocoder
wav = voc_inference(mel)
if flags == 0:
@@ -165,7 +176,8 @@ def parse_args():
choices=[
'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc',
'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk',
'tacotron2_csmsc', 'tacotron2_ljspeech', 'fastspeech2_mix'
'tacotron2_csmsc', 'tacotron2_ljspeech', 'fastspeech2_mix',
'transformerTTS_csmsc'
],
help='Choose acoustic model type of tts task.')
parser.add_argument(

@@ -13,28 +13,151 @@
# limitations under the License.
import argparse
import os
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
import jsonlines
import librosa
import numpy as np
import tqdm
import yaml
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
from paddlespeech.t2s.utils import str2bool
#from concurrent.futures import ThreadPoolExecutor
#from operator import itemgetter
#from typing import Any
#from typing import Dict
#from typing import List
#import jsonlines
#import librosa
#import numpy as np
#import tqdm
#from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
def process_sentence(config: Dict[str, Any],
fp: Path,
sentences: Dict,
output_dir: Path,
mel_extractor=None,
cut_sil: bool=True,
spk_emb_dir: Path=None):
utt_id = fp.stem
# for vctk
if utt_id.endswith("_mic2"):
utt_id = utt_id[:-5]
record = None
if utt_id in sentences:
# reading, resampling may occur
wav, _ = librosa.load(str(fp), sr=config.fs)
if len(wav.shape) != 1:
return record
max_value = np.abs(wav).max()
if max_value > 1.0:
wav = wav / max_value
assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
assert np.abs(wav).max() <= 1.0, f"{utt_id} seems to be different from 16 bit PCM."
phones = sentences[utt_id][0]
durations = sentences[utt_id][1]
speaker = sentences[utt_id][2]
d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
# slightly less precise than using *.TextGrid directly
times = librosa.frames_to_time(
d_cumsum, sr=config.fs, hop_length=config.n_shift)
if cut_sil:
start = 0
end = d_cumsum[-1]
if phones[0] == "sil" and len(durations) > 1:
start = times[1]
durations = durations[1:]
phones = phones[1:]
if phones[-1] == 'sil' and len(durations) > 1:
end = times[-2]
durations = durations[:-1]
phones = phones[:-1]
sentences[utt_id][0] = phones
sentences[utt_id][1] = durations
start, end = librosa.time_to_samples([start, end], sr=config.fs)
wav = wav[start:end]
# extract mel feats
logmel = mel_extractor.get_log_mel_fbank(wav)
# change duration according to mel_length
compare_duration_and_mel_length(sentences, utt_id, logmel)
# utt_id may be popped in compare_duration_and_mel_length
if utt_id not in sentences:
return None
phones = sentences[utt_id][0]
durations = sentences[utt_id][1]
num_frames = logmel.shape[0]
assert sum(durations) == num_frames
mel_dir = output_dir / "data_speech"
mel_dir.mkdir(parents=True, exist_ok=True)
mel_path = mel_dir / (utt_id + "_speech.npy")
np.save(mel_path, logmel)
record = {
"utt_id": utt_id,
"phones": phones,
"text_lengths": len(phones),
"speech_lengths": num_frames,
"speech": str(mel_path),
"speaker": speaker
}
if spk_emb_dir:
if speaker in os.listdir(spk_emb_dir):
embed_name = utt_id + ".npy"
embed_path = spk_emb_dir / speaker / embed_name
if embed_path.is_file():
record["spk_emb"] = str(embed_path)
else:
return None
return record
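process_sentence expects a LogMelFBank extractor built from the same config the script loads; a wiring sketch, assuming the config CfgNode and sentences dict that main() prepares (the wav path and output dir are placeholders):

mel_extractor = LogMelFBank(
    sr=config.fs,
    n_fft=config.n_fft,
    hop_length=config.n_shift,
    win_length=config.win_length,
    window=config.window,
    n_mels=config.n_mels,
    fmin=config.fmin,
    fmax=config.fmax)
record = process_sentence(
    config=config,
    fp=Path("BZNSYP/Wave/000001.wav"),
    sentences=sentences,
    output_dir=Path("dump/train/raw"),
    mel_extractor=mel_extractor,
    cut_sil=True)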
def process_sentences(config,
fps: List[Path],
sentences: Dict,
output_dir: Path,
mel_extractor=None,
nprocs: int=1,
cut_sil: bool=True,
spk_emb_dir: Path=None):
if nprocs == 1:
results = []
for fp in tqdm.tqdm(fps, total=len(fps)):
record = process_sentence(
config=config,
fp=fp,
sentences=sentences,
output_dir=output_dir,
mel_extractor=mel_extractor,
cut_sil=cut_sil,
spk_emb_dir=spk_emb_dir)
if record:
results.append(record)
else:
with ThreadPoolExecutor(nprocs) as pool:
futures = []
with tqdm.tqdm(total=len(fps)) as progress:
for fp in fps:
future = pool.submit(process_sentence, config, fp,
sentences, output_dir, mel_extractor,
cut_sil, spk_emb_dir)
future.add_done_callback(lambda p: progress.update())
futures.append(future)
results = []
for ft in futures:
record = ft.result()
if record:
results.append(record)
results.sort(key=itemgetter("utt_id"))
with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
for item in results:
writer.write(item)
print("Done")
def main():
@@ -59,7 +182,7 @@ def main():
parser.add_argument(
"--dur-file", default=None, type=str, help="path to durations.txt.")
parser.add_argument("--config", type=str, help="transformer config file.")
parser.add_argument("--config", type=str, help="fastspeech2 config file.")
parser.add_argument(
"--num-cpu", type=int, default=1, help="number of process.")
