diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py
index 15d8dfb78..e5afea800 100644
--- a/paddlespeech/t2s/exps/syn_utils.py
+++ b/paddlespeech/t2s/exps/syn_utils.py
@@ -27,14 +27,11 @@ from paddle import jit
 from paddle.static import InputSpec
 from yacs.config import CfgNode
 
+from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.t2s.datasets.data_table import DataTable
-from paddlespeech.t2s.frontend import English
-from paddlespeech.t2s.frontend.mix_frontend import MixFrontend
+from paddlespeech.t2s.frontend.phonectic import English
 from paddlespeech.t2s.frontend.zh_frontend import Frontend
 from paddlespeech.t2s.modules.normalizer import ZScore
-from paddlespeech.utils.dynamic_import import dynamic_import
-# remove [W:onnxruntime: xxx] from ort
-ort.set_default_logger_severity(3)
 
 model_alias = {
     # acoustic model
@@ -50,6 +47,10 @@ model_alias = {
     "paddlespeech.t2s.models.tacotron2:Tacotron2",
     "tacotron2_inference":
     "paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
+    "transformerTTS":
+    "paddlespeech.t2s.models.transformer_tts:TransformerTTS",
+    "transformerTTS_inference":
+    "paddlespeech.t2s.models.transformer_tts:TransformerTTSInference",
     # voc
     "pwgan":
     "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
@@ -71,10 +72,6 @@ model_alias = {
     "paddlespeech.t2s.models.wavernn:WaveRNN",
     "wavernn_inference":
     "paddlespeech.t2s.models.wavernn:WaveRNNInference",
-    "erniesat":
-    "paddlespeech.t2s.models.ernie_sat:ErnieSAT",
-    "erniesat_inference":
-    "paddlespeech.t2s.models.ernie_sat:ErnieSATInference",
 }
 
 
@@ -82,17 +79,13 @@ def denorm(data, mean, std):
     return data * std + mean
 
 
-def norm(data, mean, std):
-    return (data - mean) / std
-
-
-def get_chunks(data, block_size: int, pad_size: int):
+def get_chunks(data, chunk_size: int, pad_size: int):
     data_len = data.shape[1]
     chunks = []
-    n = math.ceil(data_len / block_size)
+    n = math.ceil(data_len / chunk_size)
     for i in range(n):
-        start = max(0, i * block_size - pad_size)
-        end = min((i + 1) * block_size + pad_size, data_len)
+        start = max(0, i * chunk_size - pad_size)
+        end = min((i + 1) * chunk_size + pad_size, data_len)
         chunks.append(data[:, start:end, :])
     return chunks
 
@@ -109,8 +102,6 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
                 sentence = "".join(items[1:])
             elif lang == 'en':
                 sentence = " ".join(items[1:])
-            elif lang == 'mix':
-                sentence = " ".join(items[1:])
             sentences.append((utt_id, sentence))
     return sentences
 
@@ -122,11 +113,9 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
     # model: {model_name}_{dataset}
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
-    converters = {}
     if am_name == 'fastspeech2':
         fields = ["utt_id", "text"]
-        if am_dataset in {"aishell3", "vctk",
-                          "mix"} and speaker_dict is not None:
+        if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
             print("multiple speaker fastspeech2!")
             fields += ["spk_id"]
         elif voice_cloning:
@@ -141,17 +130,8 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
         if voice_cloning:
             print("voice cloning!")
             fields += ["spk_emb"]
-    elif am_name == 'erniesat':
-        fields = [
-            "utt_id", "text", "text_lengths", "speech", "speech_lengths",
-            "align_start", "align_end"
-        ]
-        converters = {"speech": np.load}
-    else:
-        print("wrong am, please input right am!!!")
 
-    test_dataset = DataTable(
-        data=test_metadata, fields=fields, converters=converters)
+    test_dataset = DataTable(data=test_metadata, fields=fields)
     return test_dataset
 
 
@@ -164,73 +144,48 @@ def get_frontend(lang: str='zh',
             phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
     elif lang == 'en':
         frontend = English(phone_vocab_path=phones_dict)
-    elif lang == 'mix':
-        frontend = MixFrontend(
-            phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
     else:
         print("wrong lang!")
+    print("frontend done!")
     return frontend
 
 
-def run_frontend(frontend: object,
-                 text: str,
-                 merge_sentences: bool=False,
-                 get_tone_ids: bool=False,
-                 lang: str='zh',
-                 to_tensor: bool=True):
-    outs = dict()
-    if lang == 'zh':
-        input_ids = frontend.get_input_ids(
-            text,
-            merge_sentences=merge_sentences,
-            get_tone_ids=get_tone_ids,
-            to_tensor=to_tensor)
-        phone_ids = input_ids["phone_ids"]
-        if get_tone_ids:
-            tone_ids = input_ids["tone_ids"]
-            outs.update({'tone_ids': tone_ids})
-    elif lang == 'en':
-        input_ids = frontend.get_input_ids(
-            text, merge_sentences=merge_sentences, to_tensor=to_tensor)
-        phone_ids = input_ids["phone_ids"]
-    elif lang == 'mix':
-        input_ids = frontend.get_input_ids(
-            text, merge_sentences=merge_sentences, to_tensor=to_tensor)
-        phone_ids = input_ids["phone_ids"]
-    else:
-        print("lang should in {'zh', 'en', 'mix'}!")
-    outs.update({'phone_ids': phone_ids})
-    return outs
-
-
 # dygraph
-def get_am_inference(am: str='fastspeech2_csmsc',
-                     am_config: CfgNode=None,
-                     am_ckpt: Optional[os.PathLike]=None,
-                     am_stat: Optional[os.PathLike]=None,
-                     phones_dict: Optional[os.PathLike]=None,
-                     tones_dict: Optional[os.PathLike]=None,
-                     speaker_dict: Optional[os.PathLike]=None,
-                     return_am: bool=False):
+def get_am_inference(
+        am: str='fastspeech2_csmsc',
+        am_config: CfgNode=None,
+        am_ckpt: Optional[os.PathLike]=None,
+        am_stat: Optional[os.PathLike]=None,
+        phones_dict: Optional[os.PathLike]=None,
+        tones_dict: Optional[os.PathLike]=None,
+        speaker_dict: Optional[os.PathLike]=None, ):
     with open(phones_dict, "r") as f:
         phn_id = [line.strip().split() for line in f.readlines()]
     vocab_size = len(phn_id)
+    print("vocab_size:", vocab_size)
+
     tone_size = None
     if tones_dict is not None:
         with open(tones_dict, "r") as f:
             tone_id = [line.strip().split() for line in f.readlines()]
         tone_size = len(tone_id)
+        print("tone_size:", tone_size)
+
     spk_num = None
     if speaker_dict is not None:
         with open(speaker_dict, 'rt') as f:
             spk_id = [line.strip().split() for line in f.readlines()]
         spk_num = len(spk_id)
+        print("spk_num:", spk_num)
+
     odim = am_config.n_mels
     # model: {model_name}_{dataset}
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
+
     am_class = dynamic_import(am_name, model_alias)
     am_inference_class = dynamic_import(am_name + '_inference', model_alias)
+
     if am_name == 'fastspeech2':
         am = am_class(
             idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
@@ -242,11 +197,8 @@ def get_am_inference(am: str='fastspeech2_csmsc',
             **am_config["model"])
     elif am_name == 'tacotron2':
         am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
-    elif am_name == 'erniesat':
+    elif am_name == 'transformerTTS':
         am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
-    else:
-        print("wrong am, please input right am!!!")
-
     am.set_state_dict(paddle.load(am_ckpt)["main_params"])
     am.eval()
     am_mu, am_std = np.load(am_stat)
@@ -255,10 +207,8 @@ def get_am_inference(am: str='fastspeech2_csmsc',
     am_normalizer = ZScore(am_mu, am_std)
     am_inference = am_inference_class(am_normalizer, am)
     am_inference.eval()
-    if return_am:
-        return am_inference, am
-    else:
-        return am_inference
+    print("acoustic model done!")
+    return am_inference
 
 
 def get_voc_inference(
@@ -286,6 +236,7 @@ def get_voc_inference(
     voc_normalizer = ZScore(voc_mu, voc_std)
     voc_inference = voc_inference_class(voc_normalizer, voc)
     voc_inference.eval()
+    print("voc done!")
     return voc_inference
 
 
@@ -298,8 +249,7 @@ def am_to_static(am_inference,
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
     if am_name == 'fastspeech2':
-        if am_dataset in {"aishell3", "vctk",
-                          "mix"} and speaker_dict is not None:
+        if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
             am_inference = jit.to_static(
                 am_inference,
                 input_spec=[
@@ -311,8 +261,7 @@ def am_to_static(am_inference,
                 am_inference,
                 input_spec=[InputSpec([-1], dtype=paddle.int64)])
     elif am_name == 'speedyspeech':
-        if am_dataset in {"aishell3", "vctk",
-                          "mix"} and speaker_dict is not None:
+        if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
             am_inference = jit.to_static(
                 am_inference,
                 input_spec=[
@@ -369,9 +318,9 @@ def get_predictor(model_dir: Optional[os.PathLike]=None,
 
 def get_am_output(
         input: str,
-        am_predictor: paddle.nn.Layer,
-        am: str,
-        frontend: object,
+        am_predictor,
+        am,
+        frontend,
         lang: str='zh',
         merge_sentences: bool=True,
         speaker_dict: Optional[os.PathLike]=None,
@@ -379,23 +328,26 @@ def get_am_output(
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
     am_input_names = am_predictor.get_input_names()
-    get_spk_id = False
     get_tone_ids = False
+    get_spk_id = False
     if am_name == 'speedyspeech':
         get_tone_ids = True
-    if am_dataset in {"aishell3", "vctk", "mix"} and speaker_dict:
+    if am_dataset in {"aishell3", "vctk"} and speaker_dict:
         get_spk_id = True
         spk_id = np.array([spk_id])
-
-    frontend_dict = run_frontend(
-        frontend=frontend,
-        text=input,
-        merge_sentences=merge_sentences,
-        get_tone_ids=get_tone_ids,
-        lang=lang)
+    if lang == 'zh':
+        input_ids = frontend.get_input_ids(
+            input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
+        phone_ids = input_ids["phone_ids"]
+    elif lang == 'en':
+        input_ids = frontend.get_input_ids(
+            input, merge_sentences=merge_sentences)
+        phone_ids = input_ids["phone_ids"]
+    else:
+        print("lang should in {'zh', 'en'}!")
     if get_tone_ids:
-        tone_ids = frontend_dict['tone_ids']
+        tone_ids = input_ids["tone_ids"]
         tones = tone_ids[0].numpy()
         tones_handle = am_predictor.get_input_handle(am_input_names[1])
         tones_handle.reshape(tones.shape)
@@ -404,7 +356,6 @@ def get_am_output(
         spk_id_handle = am_predictor.get_input_handle(am_input_names[1])
         spk_id_handle.reshape(spk_id.shape)
         spk_id_handle.copy_from_cpu(spk_id)
-    phone_ids = frontend_dict['phone_ids']
     phones = phone_ids[0].numpy()
     phones_handle = am_predictor.get_input_handle(am_input_names[0])
     phones_handle.reshape(phones.shape)
@@ -453,13 +404,13 @@ def get_streaming_am_output(input: str,
                             lang: str='zh',
                             merge_sentences: bool=True):
     get_tone_ids = False
-    frontend_dict = run_frontend(
-        frontend=frontend,
-        text=input,
-        merge_sentences=merge_sentences,
-        get_tone_ids=get_tone_ids,
-        lang=lang)
-    phone_ids = frontend_dict['phone_ids']
+    if lang == 'zh':
+        input_ids = frontend.get_input_ids(
+            input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
+        phone_ids = input_ids["phone_ids"]
+    else:
+        print("lang should be 'zh' here!")
+
     phones = phone_ids[0].numpy()
     am_encoder_infer_output = get_am_sublayer_output(
         am_encoder_infer_predictor, input=phones)
@@ -476,25 +427,26 @@ def get_streaming_am_output(input: str,
 
 
 # onnx
-def get_sess(model_path: Optional[os.PathLike],
+def get_sess(model_dir: Optional[os.PathLike]=None,
+             model_file: Optional[os.PathLike]=None,
              device: str='cpu',
             cpu_threads: int=1,
             use_trt: bool=False):
+
+    model_dir = str(Path(model_dir) / model_file)
     sess_options = ort.SessionOptions()
     sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
     sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
-    if 'gpu' in device.lower():
-        device_id = int(device.split(':')[1]) if len(
-            device.split(':')) == 2 else 0
+
+    if device == "gpu":
         # fastspeech2/mb_melgan can't use trt now!
         if use_trt:
-            provider_name = 'TensorrtExecutionProvider'
+            providers = ['TensorrtExecutionProvider']
         else:
-            provider_name = 'CUDAExecutionProvider'
-        providers = [(provider_name, {'device_id': device_id})]
-    elif device.lower() == 'cpu':
+            providers = ['CUDAExecutionProvider']
+    elif device == "cpu":
         providers = ['CPUExecutionProvider']
         sess_options.intra_op_num_threads = cpu_threads
     sess = ort.InferenceSession(
-        model_path, providers=providers, sess_options=sess_options)
+        model_dir, providers=providers, sess_options=sess_options)
     return sess
diff --git a/paddlespeech/t2s/exps/synthesize.py b/paddlespeech/t2s/exps/synthesize.py
index a8e18150e..260953334 100644
--- a/paddlespeech/t2s/exps/synthesize.py
+++ b/paddlespeech/t2s/exps/synthesize.py
@@ -107,6 +107,13 @@ def evaluate(args):
                     if args.voice_cloning and "spk_emb" in datum:
                         spk_emb = paddle.to_tensor(np.load(datum["spk_emb"]))
                     mel = am_inference(phone_ids, spk_emb=spk_emb)
+                elif am_name == 'transformerTTS':
+                    phone_ids = paddle.to_tensor(datum["text"])
+                    spk_emb = None
+                    # multi speaker
+                    if args.voice_cloning and "spk_emb" in datum:
+                        spk_emb = paddle.to_tensor(np.load(datum["spk_emb"]))
+                    mel = am_inference(phone_ids, spk_emb=spk_emb)
 
                 # vocoder
                 wav = voc_inference(mel)
@@ -136,7 +143,7 @@ def parse_args():
         choices=[
             'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
             'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc',
-            'tacotron2_ljspeech', 'tacotron2_aishell3', 'fastspeech2_mix'
+            'tacotron2_ljspeech', 'tacotron2_aishell3', 'transformerTTS_csmsc'
         ],
         help='Choose acoustic model type of tts task.')
     parser.add_argument(
diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py
index 9ce8286fb..677e34e28 100644
--- a/paddlespeech/t2s/exps/synthesize_e2e.py
+++ b/paddlespeech/t2s/exps/synthesize_e2e.py
@@ -25,7 +25,6 @@ from paddlespeech.t2s.exps.syn_utils import get_am_inference
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.exps.syn_utils import get_voc_inference
-from paddlespeech.t2s.exps.syn_utils import run_frontend
 from paddlespeech.t2s.exps.syn_utils import voc_to_static
 
 
@@ -50,7 +49,6 @@ def evaluate(args):
         lang=args.lang,
         phones_dict=args.phones_dict,
         tones_dict=args.tones_dict)
-    print("frontend done!")
 
     # acoustic model
     am_name = args.am[:args.am.rindex('_')]
@@ -64,14 +62,13 @@ def evaluate(args):
         phones_dict=args.phones_dict,
         tones_dict=args.tones_dict,
         speaker_dict=args.speaker_dict)
-    print("acoustic model done!")
+
     # vocoder
     voc_inference = get_voc_inference(
         voc=args.voc,
         voc_config=voc_config,
         voc_ckpt=args.voc_ckpt,
         voc_stat=args.voc_stat)
-    print("voc done!")
 
     # whether dygraph to static
     if args.inference_dir:
@@ -81,6 +78,7 @@ def evaluate(args):
             am=args.am,
             inference_dir=args.inference_dir,
             speaker_dict=args.speaker_dict)
+
         # vocoder
         voc_inference = voc_to_static(
             voc_inference=voc_inference,
@@ -103,13 +101,24 @@ def evaluate(args):
     T = 0
     for utt_id, sentence in sentences:
         with timer() as t:
-            frontend_dict = run_frontend(
-                frontend=frontend,
-                text=sentence,
-                merge_sentences=merge_sentences,
-                get_tone_ids=get_tone_ids,
-                lang=args.lang)
-            phone_ids = frontend_dict['phone_ids']
+            if args.lang == 'zh':
+                input_ids = frontend.get_input_ids(
+                    sentence,
+                    merge_sentences=merge_sentences,
+                    get_tone_ids=get_tone_ids)
+                phone_ids = input_ids["phone_ids"]
+                if get_tone_ids:
+                    tone_ids = input_ids["tone_ids"]
+            elif args.lang == 'en':
+                input_ids = frontend.get_input_ids(
+                    sentence, merge_sentences=merge_sentences)
+                phone_ids = input_ids["phone_ids"]
+            elif args.lang == 'mix':
+                input_ids = frontend.get_input_ids(
+                    sentence, merge_sentences=merge_sentences)
+                phone_ids = input_ids["phone_ids"]
+            else:
+                print("lang should in {'zh', 'en', 'mix'}!")
             with paddle.no_grad():
                 flags = 0
                 for i in range(len(phone_ids)):
@@ -123,8 +132,8 @@ def evaluate(args):
                         else:
                             mel = am_inference(part_phone_ids)
                     elif am_name == 'speedyspeech':
-                        part_tone_ids = frontend_dict['tone_ids'][i]
-                        if am_dataset in {"aishell3", "vctk", "mix"}:
+                        part_tone_ids = tone_ids[i]
+                        if am_dataset in {"aishell3", "vctk"}:
                             spk_id = paddle.to_tensor(args.spk_id)
                             mel = am_inference(part_phone_ids, part_tone_ids,
                                                spk_id)
@@ -132,6 +141,8 @@ def evaluate(args):
                             mel = am_inference(part_phone_ids, part_tone_ids)
                     elif am_name == 'tacotron2':
                         mel = am_inference(part_phone_ids)
+                    elif am_name == 'transformerTTS':
+                        mel = am_inference(part_phone_ids)
                     # vocoder
                     wav = voc_inference(mel)
                     if flags == 0:
@@ -165,7 +176,8 @@ def parse_args():
         choices=[
             'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc',
             'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk',
-            'tacotron2_csmsc', 'tacotron2_ljspeech', 'fastspeech2_mix'
+            'tacotron2_csmsc', 'tacotron2_ljspeech', 'fastspeech2_mix',
+            'transformerTTS_csmsc'
         ],
         help='Choose acoustic model type of tts task.')
     parser.add_argument(
diff --git a/paddlespeech/t2s/exps/transformer_tts/preprocess_new.py b/paddlespeech/t2s/exps/transformer_tts/preprocess_new.py
index 5f9fd9215..c27b9769b 100644
--- a/paddlespeech/t2s/exps/transformer_tts/preprocess_new.py
+++ b/paddlespeech/t2s/exps/transformer_tts/preprocess_new.py
@@ -13,28 +13,151 @@
 # limitations under the License.
 import argparse
 import os
+from concurrent.futures import ThreadPoolExecutor
+from operator import itemgetter
 from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
 
+import jsonlines
+import librosa
+import numpy as np
+import tqdm
 import yaml
 from yacs.config import CfgNode
 
 from paddlespeech.t2s.datasets.get_feats import LogMelFBank
+from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
 from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
 from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
 from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
 from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
 from paddlespeech.t2s.utils import str2bool
 
-#from concurrent.futures import ThreadPoolExecutor
-#from operator import itemgetter
-#from typing import Any
-#from typing import Dict
-#from typing import List
-#import jsonlines
-#import librosa
-#import numpy as np
-#import tqdm
-#from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
+
+def process_sentence(config: Dict[str, Any],
+                     fp: Path,
+                     sentences: Dict,
+                     output_dir: Path,
+                     mel_extractor=None,
+                     cut_sil: bool=True,
+                     spk_emb_dir: Path=None):
+    utt_id = fp.stem
+    # for vctk
+    if utt_id.endswith("_mic2"):
+        utt_id = utt_id[:-5]
+    record = None
+    if utt_id in sentences:
+        # reading, resampling may occur
+        wav, _ = librosa.load(str(fp), sr=config.fs)
+        if len(wav.shape) != 1:
+            return record
+        max_value = np.abs(wav).max()
+        if max_value > 1.0:
+            wav = wav / max_value
+        assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
+        assert np.abs(wav).max(
+        ) <= 1.0, f"{utt_id} seems to be different from 16 bit PCM."
+        phones = sentences[utt_id][0]
+        durations = sentences[utt_id][1]
+        speaker = sentences[utt_id][2]
+        d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
+        # slightly less precise than using *.TextGrid directly
+        times = librosa.frames_to_time(
+            d_cumsum, sr=config.fs, hop_length=config.n_shift)
+        if cut_sil:
+            start = 0
+            end = d_cumsum[-1]
+            if phones[0] == "sil" and len(durations) > 1:
+                start = times[1]
+                durations = durations[1:]
+                phones = phones[1:]
+            if phones[-1] == 'sil' and len(durations) > 1:
+                end = times[-2]
+                durations = durations[:-1]
+                phones = phones[:-1]
+            sentences[utt_id][0] = phones
+            sentences[utt_id][1] = durations
+            start, end = librosa.time_to_samples([start, end], sr=config.fs)
+            wav = wav[start:end]
+        # extract mel feats
+        logmel = mel_extractor.get_log_mel_fbank(wav)
+        # change duration according to mel_length
+        compare_duration_and_mel_length(sentences, utt_id, logmel)
+        # utt_id may be popped in compare_duration_and_mel_length
+        if utt_id not in sentences:
+            return None
+        phones = sentences[utt_id][0]
+        durations = sentences[utt_id][1]
+        num_frames = logmel.shape[0]
+        assert sum(durations) == num_frames
+        mel_dir = output_dir / "data_speech"
+        mel_dir.mkdir(parents=True, exist_ok=True)
+        mel_path = mel_dir / (utt_id + "_speech.npy")
+        np.save(mel_path, logmel)
+        record = {
+            "utt_id": utt_id,
+            "phones": phones,
+            "text_lengths": len(phones),
+            "speech_lengths": num_frames,
+            "speech": str(mel_path),
+            "speaker": speaker
+        }
+        if spk_emb_dir:
+            if speaker in os.listdir(spk_emb_dir):
+                embed_name = utt_id + ".npy"
+                embed_path = spk_emb_dir / speaker / embed_name
+                if embed_path.is_file():
+                    record["spk_emb"] = str(embed_path)
+                else:
+                    return None
+    return record
+
+
+def process_sentences(config,
+                      fps: List[Path],
+                      sentences: Dict,
+                      output_dir: Path,
+                      mel_extractor=None,
+                      nprocs: int=1,
+                      cut_sil: bool=True,
+                      spk_emb_dir: Path=None):
+    if nprocs == 1:
+        results = []
+        for fp in tqdm.tqdm(fps, total=len(fps)):
+            record = process_sentence(
+                config=config,
+                fp=fp,
+                sentences=sentences,
+                output_dir=output_dir,
+                mel_extractor=mel_extractor,
+                cut_sil=cut_sil,
+                spk_emb_dir=spk_emb_dir)
+            if record:
+                results.append(record)
+    else:
+        with ThreadPoolExecutor(nprocs) as pool:
+            futures = []
+            with tqdm.tqdm(total=len(fps)) as progress:
+                for fp in fps:
+                    future = pool.submit(process_sentence, config, fp,
+                                         sentences, output_dir, mel_extractor,
+                                         cut_sil, spk_emb_dir)
+                    future.add_done_callback(lambda p: progress.update())
+                    futures.append(future)
+
+                results = []
+                for ft in futures:
+                    record = ft.result()
+                    if record:
+                        results.append(record)
+
+    results.sort(key=itemgetter("utt_id"))
+    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
+        for item in results:
+            writer.write(item)
+    print("Done")
 
 
 def main():
@@ -59,7 +182,7 @@ def main():
     parser.add_argument(
         "--dur-file", default=None, type=str, help="path to durations.txt.")
-    parser.add_argument("--config", type=str, help="transformer config file.")
+    parser.add_argument("--config", type=str, help="transformerTTS config file.")
     parser.add_argument(
         "--num-cpu", type=int, default=1, help="number of process.")