add style_melgan and hifigan in tts cli, test=tts (#1241)

pull/1243/head
TianYuan 4 years ago committed by GitHub
parent a232cd8b12
commit fbe3c05137
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -178,6 +178,32 @@ pretrained_models = {
'speech_stats': 'speech_stats':
'feats_stats.npy', 'feats_stats.npy',
}, },
# style_melgan
"style_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'5de2d5348f396de0c966926b8c462755',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# hifigan
"hifigan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip',
'md5':
'dd40a3d88dfcf64513fba2f0f961ada6',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
} }
model_alias = { model_alias = {
@ -199,6 +225,14 @@ model_alias = {
"paddlespeech.t2s.models.melgan:MelGANGenerator", "paddlespeech.t2s.models.melgan:MelGANGenerator",
"mb_melgan_inference": "mb_melgan_inference":
"paddlespeech.t2s.models.melgan:MelGANInference", "paddlespeech.t2s.models.melgan:MelGANInference",
"style_melgan":
"paddlespeech.t2s.models.melgan:StyleMelGANGenerator",
"style_melgan_inference":
"paddlespeech.t2s.models.melgan:StyleMelGANInference",
"hifigan":
"paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
"hifigan_inference":
"paddlespeech.t2s.models.hifigan:HiFiGANInference",
} }
@ -266,7 +300,7 @@ class TTSExecutor(BaseExecutor):
default='pwgan_csmsc', default='pwgan_csmsc',
choices=[ choices=[
'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk', 'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
'mb_melgan_csmsc' 'mb_melgan_csmsc', 'style_melgan_csmsc', 'hifigan_csmsc'
], ],
help='Choose vocoder type of tts task.') help='Choose vocoder type of tts task.')
@ -504,37 +538,47 @@ class TTSExecutor(BaseExecutor):
am_name = am[:am.rindex('_')] am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:] am_dataset = am[am.rindex('_') + 1:]
get_tone_ids = False get_tone_ids = False
merge_sentences = False
if am_name == 'speedyspeech': if am_name == 'speedyspeech':
get_tone_ids = True get_tone_ids = True
if lang == 'zh': if lang == 'zh':
input_ids = self.frontend.get_input_ids( input_ids = self.frontend.get_input_ids(
text, merge_sentences=True, get_tone_ids=get_tone_ids) text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
phone_ids = phone_ids[0]
if get_tone_ids: if get_tone_ids:
tone_ids = input_ids["tone_ids"] tone_ids = input_ids["tone_ids"]
tone_ids = tone_ids[0]
elif lang == 'en': elif lang == 'en':
input_ids = self.frontend.get_input_ids(text) input_ids = self.frontend.get_input_ids(
text, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
else: else:
print("lang should in {'zh', 'en'}!") print("lang should in {'zh', 'en'}!")
# am flags = 0
if am_name == 'speedyspeech': for i in range(len(phone_ids)):
mel = self.am_inference(phone_ids, tone_ids) part_phone_ids = phone_ids[i]
# fastspeech2 # am
else: if am_name == 'speedyspeech':
# multi speaker part_tone_ids = tone_ids[i]
if am_dataset in {"aishell3", "vctk"}: mel = self.am_inference(part_phone_ids, part_tone_ids)
mel = self.am_inference( # fastspeech2
phone_ids, spk_id=paddle.to_tensor(spk_id))
else: else:
mel = self.am_inference(phone_ids) # multi speaker
if am_dataset in {"aishell3", "vctk"}:
# voc mel = self.am_inference(
wav = self.voc_inference(mel) part_phone_ids, spk_id=paddle.to_tensor(spk_id))
self._outputs['wav'] = wav else:
mel = self.am_inference(part_phone_ids)
# voc
wav = self.voc_inference(mel)
if flags == 0:
wav_all = wav
flags = 1
else:
wav_all = paddle.concat([wav_all, wav])
self._outputs['wav'] = wav_all
def postprocess(self, output: str='output.wav') -> Union[str, os.PathLike]: def postprocess(self, output: str='output.wav') -> Union[str, os.PathLike]:
""" """

@ -196,41 +196,47 @@ def evaluate(args):
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
merge_sentences = False
for utt_id, sentence in sentences: for utt_id, sentence in sentences:
get_tone_ids = False get_tone_ids = False
if am_name == 'speedyspeech': if am_name == 'speedyspeech':
get_tone_ids = True get_tone_ids = True
if args.lang == 'zh': if args.lang == 'zh':
input_ids = frontend.get_input_ids( input_ids = frontend.get_input_ids(
sentence, merge_sentences=True, get_tone_ids=get_tone_ids) sentence, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
phone_ids = phone_ids[0]
if get_tone_ids: if get_tone_ids:
tone_ids = input_ids["tone_ids"] tone_ids = input_ids["tone_ids"]
tone_ids = tone_ids[0]
elif args.lang == 'en': elif args.lang == 'en':
input_ids = frontend.get_input_ids(sentence) input_ids = frontend.get_input_ids(sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
else: else:
print("lang should in {'zh', 'en'}!") print("lang should in {'zh', 'en'}!")
with paddle.no_grad(): with paddle.no_grad():
# acoustic model flags = 0
if am_name == 'fastspeech2': for i in range(len(phone_ids)):
# multi speaker part_phone_ids = phone_ids[i]
if am_dataset in {"aishell3", "vctk"}: # acoustic model
spk_id = paddle.to_tensor(args.spk_id) if am_name == 'fastspeech2':
mel = am_inference(phone_ids, spk_id) # multi speaker
if am_dataset in {"aishell3", "vctk"}:
spk_id = paddle.to_tensor(args.spk_id)
mel = am_inference(part_phone_ids, spk_id)
else:
mel = am_inference(part_phone_ids)
elif am_name == 'speedyspeech':
part_tone_ids = tone_ids[i]
mel = am_inference(part_phone_ids, part_tone_ids)
# vocoder
wav = voc_inference(mel)
if flags == 0:
wav_all = wav
flags = 1
else: else:
mel = am_inference(phone_ids) wav_all = paddle.concat([wav_all, wav])
elif am_name == 'speedyspeech':
mel = am_inference(phone_ids, tone_ids)
# vocoder
wav = voc_inference(mel)
sf.write( sf.write(
str(output_dir / (utt_id + ".wav")), str(output_dir / (utt_id + ".wav")),
wav.numpy(), wav_all.numpy(),
samplerate=am_config.fs) samplerate=am_config.fs)
print(f"{utt_id} done!") print(f"{utt_id} done!")

@ -13,7 +13,9 @@
# limitations under the License. # limitations under the License.
from abc import ABC from abc import ABC
from abc import abstractmethod from abc import abstractmethod
from typing import List
import numpy as np
import paddle import paddle
from g2p_en import G2p from g2p_en import G2p
from g2pM import G2pM from g2pM import G2pM
@ -21,6 +23,7 @@ from g2pM import G2pM
from paddlespeech.t2s.frontend.normalizer.normalizer import normalize from paddlespeech.t2s.frontend.normalizer.normalizer import normalize
from paddlespeech.t2s.frontend.punctuation import get_punctuations from paddlespeech.t2s.frontend.punctuation import get_punctuations
from paddlespeech.t2s.frontend.vocab import Vocab from paddlespeech.t2s.frontend.vocab import Vocab
from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer
# discard opencc until we find an easy solution to install it on windows # discard opencc until we find an easy solution to install it on windows
# from opencc import OpenCC # from opencc import OpenCC
@ -53,6 +56,7 @@ class English(Phonetics):
self.vocab = Vocab(self.phonemes + self.punctuations) self.vocab = Vocab(self.phonemes + self.punctuations)
self.vocab_phones = {} self.vocab_phones = {}
self.punc = ":,;。?!“”‘’':,;.?!" self.punc = ":,;。?!“”‘’':,;.?!"
self.text_normalizer = TextNormalizer()
if phone_vocab_path: if phone_vocab_path:
with open(phone_vocab_path, 'rt') as f: with open(phone_vocab_path, 'rt') as f:
phn_id = [line.strip().split() for line in f.readlines()] phn_id = [line.strip().split() for line in f.readlines()]
@ -78,19 +82,42 @@ class English(Phonetics):
phonemes = [item for item in phonemes if item in self.vocab.stoi] phonemes = [item for item in phonemes if item in self.vocab.stoi]
return phonemes return phonemes
def get_input_ids(self, sentence: str) -> paddle.Tensor: def _p2id(self, phonemes: List[str]) -> np.array:
result = {} # replace unk phone with sp
phones = self.phoneticize(sentence) phonemes = [
# remove start_symbol and end_symbol
phones = phones[1:-1]
phones = [phn for phn in phones if not phn.isspace()]
phones = [
phn if (phn in self.vocab_phones and phn not in self.punc) else "sp" phn if (phn in self.vocab_phones and phn not in self.punc) else "sp"
for phn in phones for phn in phonemes
] ]
phone_ids = [self.vocab_phones[phn] for phn in phones] phone_ids = [self.vocab_phones[item] for item in phonemes]
phone_ids = paddle.to_tensor(phone_ids) return np.array(phone_ids, np.int64)
result["phone_ids"] = phone_ids
def get_input_ids(self, sentence: str,
merge_sentences: bool=False) -> paddle.Tensor:
result = {}
sentences = self.text_normalizer._split(sentence, lang="en")
phones_list = []
temp_phone_ids = []
for sentence in sentences:
phones = self.phoneticize(sentence)
# remove start_symbol and end_symbol
phones = phones[1:-1]
phones = [phn for phn in phones if not phn.isspace()]
phones_list.append(phones)
if merge_sentences:
merge_list = sum(phones_list, [])
# remove the trailing 'sp' to avoid noise at the end of the audio,
# because utterances in the training data never end with 'sp'
if merge_list[-1] == 'sp':
merge_list = merge_list[:-1]
phones_list = []
phones_list.append(merge_list)
for part_phones_list in phones_list:
phone_ids = self._p2id(part_phones_list)
phone_ids = paddle.to_tensor(phone_ids)
temp_phone_ids.append(phone_ids)
result["phone_ids"] = temp_phone_ids
return result return result
def numericalize(self, phonemes): def numericalize(self, phonemes):

@ -53,7 +53,7 @@ class TextNormalizer():
def __init__(self): def __init__(self):
self.SENTENCE_SPLITOR = re.compile(r'([:,;。?!,;?!][”’]?)') self.SENTENCE_SPLITOR = re.compile(r'([:,;。?!,;?!][”’]?)')
def _split(self, text: str) -> List[str]: def _split(self, text: str, lang="zh") -> List[str]:
"""Split long text into sentences with sentence-splitting punctuations. """Split long text into sentences with sentence-splitting punctuations.
Parameters Parameters
---------- ----------
@ -65,7 +65,8 @@ class TextNormalizer():
Sentences. Sentences.
""" """
# Only for pure Chinese here # Only for pure Chinese here
text = text.replace(" ", "") if lang == "zh":
text = text.replace(" ", "")
text = self.SENTENCE_SPLITOR.sub(r'\1\n', text) text = self.SENTENCE_SPLITOR.sub(r'\1\n', text)
text = text.strip() text = text.strip()
sentences = [sentence.strip() for sentence in re.split(r'\n+', text)] sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]

Loading…
Cancel
Save