more comment on tts frontend

pull/3316/head
Hui Zhang 1 year ago
parent 8aa9790c75
commit 42f2186d71

@ -99,14 +99,23 @@ def norm(data, mean, std):
return (data - mean) / std
def get_chunks(data, block_size: int, pad_size: int):
data_len = data.shape[1]
def get_chunks(mel, chunk_size: int, pad_size: int):
"""
Split mel by chunk size with left and right context.
Args:
mel (paddle.Tensor): mel spectrogram, shape (B, T, D)
chunk_size (int): chunk size
pad_size (int): size for left and right context.
"""
T = mel.shape[1]
n = math.ceil(T / chunk_size)
chunks = []
n = math.ceil(data_len / block_size)
for i in range(n):
start = max(0, i * block_size - pad_size)
end = min((i + 1) * block_size + pad_size, data_len)
chunks.append(data[:, start:end, :])
start = max(0, i * chunk_size - pad_size)
end = min((i + 1) * chunk_size + pad_size, T)
chunks.append(mel[:, start:end, :])
return chunks
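As a quick, hedged illustration of the chunking logic above (random tensor, only the shapes matter; the import path is assumed to be the module this function lives in):

import paddle
from paddlespeech.t2s.exps.syn_utils import get_chunks  # assumed module path

mel = paddle.randn([1, 10, 80])                  # (B, T, D), T = 10 frames
chunks = get_chunks(mel, chunk_size=4, pad_size=1)
# n = ceil(10 / 4) = 3 chunks, each padded with 1 frame of context on both sides:
#   chunk 0: frames [0, 5)  -> shape [1, 5, 80]
#   chunk 1: frames [3, 9)  -> shape [1, 6, 80]
#   chunk 2: frames [7, 10) -> shape [1, 3, 80]
print([tuple(c.shape) for c in chunks])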
@ -117,14 +126,10 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
with open(text_file, 'rt', encoding='utf-8') as f:
for line in f:
if line.strip() != "":
items = re.split(r"\s+", line.strip(), 1)
items = re.split(r"\s+", line.strip(), maxsplit=1)
assert len(items) == 2
utt_id = items[0]
if lang in {'zh', 'canton'}:
sentence = "".join(items[1:])
elif lang == 'en':
sentence = " ".join(items[1:])
elif lang == 'mix':
sentence = " ".join(items[1:])
sentence = items[1]
sentences.append((utt_id, sentence))
return sentences
@ -319,6 +324,7 @@ def run_frontend(
input_ids = {}
if text.strip() != "" and re.match(r".*?<speak>.*?</speak>.*", text,
re.DOTALL):
# using SSML
input_ids = frontend.get_input_ids_ssml(
text,
merge_sentences=merge_sentences,
@ -359,6 +365,7 @@ def run_frontend(
outs.update({'is_slurs': is_slurs})
else:
print("lang should in {'zh', 'en', 'mix', 'canton', 'sing'}!")
outs.update({'phone_ids': phone_ids})
return outs
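The SSML branch above is selected purely by that regex check; a minimal sketch of which inputs take the SSML path (SSML_PATTERN is just a local name for the pattern shown above):

import re

SSML_PATTERN = r".*?<speak>.*?</speak>.*"

for text in ["你好。",
             "<speak>你<say-as pinyin='hao3'>好</say-as></speak>"]:
    use_ssml = text.strip() != "" and re.match(SSML_PATTERN, text, re.DOTALL) is not None
    print(use_ssml, text)
# False -> frontend.get_input_ids(...)
# True  -> frontend.get_input_ids_ssml(...)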

@ -13,6 +13,7 @@
# limitations under the License.
import argparse
from pathlib import Path
from pprint import pprint
import paddle
import soundfile as sf
@ -78,6 +79,7 @@ def evaluate(args):
# whether dygraph to static
if args.inference_dir:
print("convert am and voc to static model.")
# acoustic model
am_inference = am_to_static(
am_inference=am_inference,
@ -92,6 +94,7 @@ def evaluate(args):
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
merge_sentences = False
# Avoid tacotron2_ljspeech failing to stop at the end of a sub-sentence after dygraph-to-static conversion,
# though it may still not stop at the very end (NOTE by yuantian01 Feb 9 2022)
@ -102,12 +105,18 @@ def evaluate(args):
if am_name == 'speedyspeech':
get_tone_ids = True
# wav samples
N = 0
# inference time cost
T = 0
# [(uid, text), ]
if am_name == 'diffsinger':
sentences = get_sentences_svs(text_file=args.text)
else:
sentences = get_sentences(text_file=args.text, lang=args.lang)
pprint(f"inputs: {sentences}")
for utt_id, sentence in sentences:
with timer() as t:
if am_name == "diffsinger":
@ -116,6 +125,8 @@ def evaluate(args):
else:
text = sentence
svs_input = None
# frontend
frontend_dict = run_frontend(
frontend=frontend,
text=text,
@ -124,25 +135,33 @@ def evaluate(args):
lang=args.lang,
svs_input=svs_input)
phone_ids = frontend_dict['phone_ids']
# pprint(f"process: {utt_id} {phone_ids}")
with paddle.no_grad():
flags = 0
for i in range(len(phone_ids)):
# sub phone, split by `sp` or punctuation.
part_phone_ids = phone_ids[i]
# acoustic model
if am_name == 'fastspeech2':
# multi speaker
if am_dataset in {"aishell3", "vctk", "mix", "canton"}:
# multi-speaker
spk_id = paddle.to_tensor(args.spk_id)
mel = am_inference(part_phone_ids, spk_id)
else:
# single-speaker
mel = am_inference(part_phone_ids)
elif am_name == 'speedyspeech':
part_tone_ids = frontend_dict['tone_ids'][i]
if am_dataset in {"aishell3", "vctk", "mix"}:
# multi-speaker
spk_id = paddle.to_tensor(args.spk_id)
mel = am_inference(part_phone_ids, part_tone_ids,
spk_id)
else:
# single-speaker
mel = am_inference(part_phone_ids, part_tone_ids)
elif am_name == 'tacotron2':
mel = am_inference(part_phone_ids)
@ -155,6 +174,7 @@ def evaluate(args):
note=part_note_ids,
note_dur=part_note_durs,
is_slur=part_is_slurs, )
# vocoder
wav = voc_inference(mel)
if flags == 0:
@ -162,17 +182,23 @@ def evaluate(args):
flags = 1
else:
wav_all = paddle.concat([wav_all, wav])
wav = wav_all.numpy()
N += wav.size
T += t.elapse
# samples per second
speed = wav.size / t.elapse
# generating one second of wav takes `RTF` seconds
rtf = am_config.fs / speed
print(
f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
)
sf.write(
str(output_dir / (utt_id + ".wav")), wav, samplerate=am_config.fs)
print(f"{utt_id} done!")
print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }")

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlespeech.t2s.frontend.phonectic import Phonetics
"""
A phonology system with ARPABET symbols and limited punctuation. The G2P
conversion is done by g2p_en.
@ -19,55 +18,68 @@ conversion is done by g2p_en.
Note that g2p_en does not handle words with hyphens well. So make sure the input
sentence is first normalized.
"""
from paddlespeech.t2s.frontend.vocab import Vocab
from g2p_en import G2p
from paddlespeech.t2s.frontend.phonectic import Phonetics
from paddlespeech.t2s.frontend.vocab import Vocab
class ARPABET(Phonetics):
"""A phonology for English that uses ARPABET as the phoneme vocabulary.
"""A phonology for English that uses ARPABET without stress as the phoneme vocabulary.
47 symbols = 39 phonemes + 4 punctuation marks + 4 special tokens (<pad> <unk> <s> </s>)
The current phoneme set contains 39 phonemes, vowels carry a lexical stress marker:
0 No stress
1 Primary stress
2 Secondary stress
Phoneme Set:
Phoneme Example Translation
------- ------- -----------
AA odd AA D
AE at AE T
AH hut HH AH T
AO ought AO T
AW cow K AW
AY hide HH AY D
B be B IY
CH cheese CH IY Z
D dee D IY
DH thee DH IY
EH Ed EH D
ER hurt HH ER T
EY ate EY T
F fee F IY
G green G R IY N
HH he HH IY
IH it IH T
IY eat IY T
JH gee JH IY
K key K IY
L lee L IY
M me M IY
N knee N IY
NG ping P IH NG
OW oat OW T
OY toy T OY
P pee P IY
R read R IY D
S sea S IY
SH she SH IY
T tea T IY
TH theta TH EY T AH
UH hood HH UH D
UW two T UW
V vee V IY
W we W IY
Y yield Y IY L D
Z zee Z IY
ZH seizure S IY ZH ER
See http://www.speech.cs.cmu.edu/cgi-bin/cmudict for more details.
Phoneme Example Translation
------- ------- -----------
AA odd AA D
AE at AE T
AH hut HH AH T
AO ought AO T
AW cow K AW
AY hide HH AY D
B be B IY
CH cheese CH IY Z
D dee D IY
DH thee DH IY
EH Ed EH D
ER hurt HH ER T
EY ate EY T
F fee F IY
G green G R IY N
HH he HH IY
IH it IH T
IY eat IY T
JH gee JH IY
K key K IY
L lee L IY
M me M IY
N knee N IY
NG ping P IH NG
OW oat OW T
OY toy T OY
P pee P IY
R read R IY D
S sea S IY
SH she SH IY
T tea T IY
TH theta TH EY T AH
UH hood HH UH D
UW two T UW
V vee V IY
W we W IY
Y yield Y IY L D
Z zee Z IY
ZH seizure S IY ZH ER
"""
# 39 phonemes
phonemes = [
'AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER',
'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW',
@ -76,6 +88,8 @@ class ARPABET(Phonetics):
]
punctuations = [',', '.', '?', '!']
symbols = phonemes + punctuations
# vowels carry a lexical stress marker
# 0 unstressed, 1 primary stress, 2 secondary stress
_stress_to_no_stress_ = {
'AA0': 'AA',
'AA1': 'AA',
@ -124,7 +138,12 @@ class ARPABET(Phonetics):
'UW2': 'UW'
}
def __repr__(self):
fmt = "ARPABETWithoutStress(phonemes: {}, punctuations: {})"
return fmt.format(len(self.phonemes), self.punctuations)
def __init__(self):
# https://github.com/Kyubyong/g2p/blob/master/g2p_en/g2p.py
self.backend = G2p()
self.vocab = Vocab(self.phonemes + self.punctuations)
@ -139,6 +158,7 @@ class ARPABET(Phonetics):
Returns:
List[str]: The list of pronunciation sequence.
"""
# g2p and remove vowel stress
phonemes = [
self._remove_vowels(item) for item in self.backend(sentence)
]
@ -158,6 +178,7 @@ class ARPABET(Phonetics):
Returns:
List[int]: The list of pronunciation id sequence.
"""
# phonemes to ids
ids = [self.vocab.lookup(item) for item in phonemes]
return ids
@ -189,11 +210,16 @@ class ARPABET(Phonetics):
def vocab_size(self):
""" Vocab size.
"""
# 47 = 39 phones + 4 punctuations + 4 special tokens
# 47 = 39 phonemes + 4 punctuation marks + 4 special tokens (<pad> <unk> <s> </s>)
return len(self.vocab)
class ARPABETWithStress(Phonetics):
"""
A phonology for English that uses ARPABET with stress as the phoneme vocabulary.
77 symbols = 69 phonemes + 4 punctuation marks + 4 special tokens (<pad> <unk> <s> </s>)
"""
phonemes = [
'AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2', 'AO0',
'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2', 'B', 'CH', 'D',
@ -206,6 +232,10 @@ class ARPABETWithStress(Phonetics):
punctuations = [',', '.', '?', '!']
symbols = phonemes + punctuations
def __repr__(self):
fmt = "ARPABETWithStress(phonemes: {}, punctuations: {})"
return fmt.format(len(self.phonemes), self.punctuations)
def __init__(self):
self.backend = G2p()
self.vocab = Vocab(self.phonemes + self.punctuations)
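A minimal usage sketch of the two vocabularies; the import path and the exact g2p_en output are assumptions, the point is that only the stress digits differ:

from paddlespeech.t2s.frontend.arpabet import ARPABET, ARPABETWithStress  # assumed path

no_stress = ARPABET()
with_stress = ARPABETWithStress()
# g2p_en typically produces ['HH', 'AH0', 'L', 'OW1'] for "hello";
# ARPABET strips the stress digit, ARPABETWithStress keeps it.
print(no_stress.phoneticize("hello"))    # e.g. ['HH', 'AH', 'L', 'OW']
print(with_stress.phoneticize("hello"))  # e.g. ['HH', 'AH0', 'L', 'OW1']
print(len(no_stress.vocab), len(with_stress.vocab))  # 47, 77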

@ -47,4 +47,5 @@ polyphonic:
恶行: ['e4','xing2']
: ['ai4']
扎实: ['zha1','shi2']
干将: ['gan4','jiang4']
干将: ['gan4','jiang4']
陈威行: ['chen2', 'wei1', 'hang2']

@ -90,13 +90,14 @@ class MixTextProcessor():
dom = DomXml(in_xml)
tags = dom.get_text_and_sayas_tags()
ctlist.extend(tags)
ctlist.append(after_xml)
return ctlist
else:
ctlist.append(mixstr)
return ctlist
class DomXml():
def __init__(self, xmlstr):
self.tdom = parseString(xmlstr) #Document

@ -20,6 +20,9 @@ from pypinyin import Style
class ToneSandhi():
def __repr__(self):
return "MandarinToneSandhi"
def __init__(self):
self.must_neural_tone_words = {
'麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝',
@ -69,6 +72,19 @@ class ToneSandhi():
}
self.punc = ":,;。?!“”‘’':,;.?!"
def _split_word(self, word: str) -> List[str]:
word_list = jieba.cut_for_search(word)
word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
first_subword = word_list[0]
first_begin_idx = word.find(first_subword)
if first_begin_idx == 0:
second_subword = word[len(first_subword):]
new_word_list = [first_subword, second_subword]
else:
second_subword = word[:-len(first_subword)]
new_word_list = [second_subword, first_subword]
return new_word_list
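A hedged sketch of _split_word; the jieba output in the comment is an assumption, the point is that the word is cut in two around its shortest sub-word:

from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi

ts = ToneSandhi()
# Suppose jieba.cut_for_search("电影院") yields ['电影', '影院', '电影院'] (assumed).
# The shortest sub-word starting at index 0 is '电影', so the word is split into
# that head plus the remaining tail:
print(ts._split_word("电影院"))   # expected ['电影', '院']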
# the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
# e.g.
# word: "家里"
@ -154,18 +170,8 @@ class ToneSandhi():
finals[i] = finals[i][:-1] + "4"
return finals
def _split_word(self, word: str) -> List[str]:
word_list = jieba.cut_for_search(word)
word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
first_subword = word_list[0]
first_begin_idx = word.find(first_subword)
if first_begin_idx == 0:
second_subword = word[len(first_subword):]
new_word_list = [first_subword, second_subword]
else:
second_subword = word[:-len(first_subword)]
new_word_list = [second_subword, first_subword]
return new_word_list
def _all_tone_three(self, finals: List[str]) -> bool:
return all(x[-1] == "3" for x in finals)
def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
@ -207,9 +213,6 @@ class ToneSandhi():
return finals
def _all_tone_three(self, finals: List[str]) -> bool:
return all(x[-1] == "3" for x in finals)
# merge "不" and the word behind it
# if not merged, "不" sometimes appears alone according to jieba, which may cause sandhi errors
def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
@ -336,6 +339,9 @@ class ToneSandhi():
def pre_merge_for_modify(
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
"""
seg: [(word, pos), ...]
"""
seg = self._merge_bu(seg)
seg = self._merge_yi(seg)
seg = self._merge_reduplication(seg)
@ -346,7 +352,11 @@ class ToneSandhi():
def modified_tone(self, word: str, pos: str,
finals: List[str]) -> List[str]:
"""
word: the word from word segmentation
pos: part-of-speech tag of the word
finals: finals with tone, [final1, ..., finaln]
"""
finals = self._bu_sandhi(word, finals)
finals = self._yi_sandhi(word, finals)
finals = self._neural_sandhi(word, pos, finals)
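For reference, a small sandhi sketch through the modified_tone entry point above; the pos tags and pinyin finals are illustrative:

from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi

ts = ToneSandhi()
# "不" before a tone-4 syllable is read with tone 2 (bu2 shi4), applied by _bu_sandhi:
print(ts.modified_tone("不是", "d", ["bu4", "shi4"]))   # expected ['bu2', 'shi4']
# two adjacent tone-3 syllables: the first becomes tone 2, applied by _three_sandhi:
print(ts.modified_tone("你好", "l", ["i3", "ao3"]))     # expected ['i2', 'ao3']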

@ -31,9 +31,9 @@ from pypinyin_dict.phrase_pinyin_data import large_pinyin
from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter
from paddlespeech.t2s.frontend.generate_lexicon import generate_lexicon
from paddlespeech.t2s.frontend.rhy_prediction.rhy_predictor import RhyPredictor
from paddlespeech.t2s.frontend.ssml.xml_processor import MixTextProcessor
from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi
from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer
from paddlespeech.t2s.ssml.xml_processor import MixTextProcessor
INITIALS = [
'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'zh', 'ch', 'sh',
@ -49,13 +49,18 @@ def intersperse(lst, item):
def insert_after_character(lst, item):
"""
Insert `item` after each final (i.e. after every character).
"""
result = [item]
for phone in lst:
result.append(phone)
if phone not in INITIALS:
# finals carry tones
# assert phone[-1] in "12345"
result.append(item)
return result
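A quick trace of insert_after_character above; the phone list is a plausible initials/finals sequence for 你好, used only for illustration:

from paddlespeech.t2s.frontend.zh_frontend import insert_after_character  # assumed path

phones = ['n', 'i3', 'h', 'ao3']
print(insert_after_character(phones, '<pad>'))
# ['<pad>', 'n', 'i3', '<pad>', 'h', 'ao3', '<pad>']
# i.e. the blank token goes at the start and after every final (one per character)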
@ -85,9 +90,7 @@ class Frontend():
phone_vocab_path=None,
tone_vocab_path=None,
use_rhy=False):
self.mix_ssml_processor = MixTextProcessor()
self.tone_modifier = ToneSandhi()
self.text_normalizer = TextNormalizer()
self.punc = ":,;。?!“”‘’':,;.?!"
self.rhy_phns = ['sp1', 'sp2', 'sp3', 'sp4']
self.phrases_dict = {
@ -108,28 +111,7 @@ class Frontend():
'': [['lei5']],
'掺和': [['chan1'], ['huo5']]
}
self.use_rhy = use_rhy
if use_rhy:
self.rhy_predictor = RhyPredictor()
print("Rhythm predictor loaded.")
# g2p_model can be pypinyin and g2pM and g2pW
self.g2p_model = g2p_model
if self.g2p_model == "g2pM":
self.g2pM_model = G2pM()
self.pinyin2phone = generate_lexicon(
with_tone=True, with_erhua=False)
elif self.g2p_model == "g2pW":
# use pypinyin as backup for non polyphonic characters in g2pW
self._init_pypinyin()
self.corrector = Polyphonic()
self.g2pM_model = G2pM()
self.g2pW_model = G2PWOnnxConverter(
style='pinyin', enable_non_tradional_chinese=True)
self.pinyin2phone = generate_lexicon(
with_tone=True, with_erhua=False)
else:
self._init_pypinyin()
self.must_erhua = {
"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"
}
@ -154,13 +136,51 @@ class Frontend():
for tone, id in tone_id:
self.vocab_tones[tone] = int(id)
# SSML
self.mix_ssml_processor = MixTextProcessor()
# tone sandhi
self.tone_modifier = ToneSandhi()
# TN
self.text_normalizer = TextNormalizer()
# prosody
self.use_rhy = use_rhy
if use_rhy:
self.rhy_predictor = RhyPredictor()
print("Rhythm predictor loaded.")
# g2p
assert g2p_model in ('pypinyin', 'g2pM', 'g2pW')
self.g2p_model = g2p_model
if self.g2p_model == "g2pM":
self.g2pM_model = G2pM()
self.pinyin2phone = generate_lexicon(
with_tone=True, with_erhua=False)
elif self.g2p_model == "g2pW":
# use pypinyin as a backup for non-polyphonic characters in g2pW
self._init_pypinyin()
self.corrector = Polyphonic()
self.g2pM_model = G2pM()
self.g2pW_model = G2PWOnnxConverter(
style='pinyin', enable_non_tradional_chinese=True)
self.pinyin2phone = generate_lexicon(
with_tone=True, with_erhua=False)
else:
self._init_pypinyin()
def _init_pypinyin(self):
"""
Load pypinyin G2P module.
"""
large_pinyin.load()
load_phrases_dict(self.phrases_dict)
# adjust the pinyin order for the character loaded below (de, di4)
load_single_dict({ord(u''): u'de,di4'})
def _get_initials_finals(self, word: str) -> List[List[str]]:
"""
Get the initials and finals of a word by pypinyin or g2pM.
"""
initials = []
finals = []
if self.g2p_model == "pypinyin":
@ -171,11 +191,14 @@ class Frontend():
for c, v in zip(orig_initials, orig_finals):
if re.match(r'i\d', v):
if c in ['z', 'c', 's']:
# zi, ci, si
v = re.sub('i', 'ii', v)
elif c in ['zh', 'ch', 'sh', 'r']:
# zhi, chi, shi
v = re.sub('i', 'iii', v)
initials.append(c)
finals.append(v)
elif self.g2p_model == "g2pM":
pinyins = self.g2pM_model(word, tone=True, char_split=False)
for pinyin in pinyins:
@ -192,58 +215,123 @@ class Frontend():
# If it's not pinyin (possibly punctuation) or no conversion is required
initials.append(pinyin)
finals.append(pinyin)
return initials, finals
def _merge_erhua(self,
initials: List[str],
finals: List[str],
word: str,
pos: str) -> List[List[str]]:
"""
Do erhua (儿化): merge a trailing 儿 into the preceding final.
"""
# fix er1
for i, phn in enumerate(finals):
if i == len(finals) - 1 and word[i] == "儿" and phn == 'er1':
finals[i] = 'er2'
# 儿 pronounced as a separate syllable: no erhua merge, return as-is
if word not in self.must_erhua and (word in self.not_erhua or
pos in {"a", "j", "nr"}):
return initials, finals
# "……" 等情况直接返回
if len(finals) != len(word):
return initials, finals
assert len(finals) == len(word)
# 儿 not pronounced separately: merge it into the preceding final
new_initials = []
new_finals = []
for i, phn in enumerate(finals):
if i == len(finals) - 1 and word[i] == "儿" and phn in {
"er2", "er5"
} and word[-2:] not in self.not_erhua and new_finals:
new_finals[-1] = new_finals[-1][:-1] + "r" + new_finals[-1][-1]
else:
new_initials.append(initials[i])
new_finals.append(phn)
return new_initials, new_finals
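A hedged trace of the erhua merge above; the pypinyin output and the assumption that 花儿 sits in neither must_erhua nor not_erhua are illustrative:

from paddlespeech.t2s.frontend.zh_frontend import Frontend  # assumed module path

frontend = Frontend(g2p_model='pypinyin')   # light init, no g2pW model needed
# "花儿": roughly initials ['h', ''] and finals ['ua1', 'er2'] from pypinyin.
# The trailing 'er2' is dropped and an 'r' is folded into the previous final,
# keeping the tone digit last:
print(frontend._merge_erhua(['h', ''], ['ua1', 'er2'], "花儿", "n"))
# expected (['h'], ['uar1'])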
# if merge_sentences, merge all sentences into one phone sequence
def _g2p(self,
sentences: List[str],
merge_sentences: bool=True,
with_erhua: bool=True) -> List[List[str]]:
"""
Returns: a list of phoneme lists, e.g.
[['w', 'o3', 'm', 'en2', 'sp'], ...]
"""
segments = sentences
phones_list = []
# split by punctuation
for seg in segments:
if self.use_rhy:
seg = self.rhy_predictor._clean_text(seg)
phones = []
# Replace all English words in the sentence
# remove all English words in the sentence
seg = re.sub('[a-zA-Z]+', '', seg)
# add prosody mark
if self.use_rhy:
seg = self.rhy_predictor.get_prediction(seg)
# [(word, pos), ...]
seg_cut = psg.lcut(seg)
initials = []
finals = []
# fix word segmentation bad cases for tone sandhi
seg_cut = self.tone_modifier.pre_merge_for_modify(seg_cut)
# for better results on polyphonic words, predict over the whole sentence here
phones = []
initials = []
finals = []
if self.g2p_model == "g2pW":
try:
# undo prosody
if self.use_rhy:
seg = self.rhy_predictor._clean_text(seg)
# g2p
pinyins = self.g2pW_model(seg)[0]
except Exception:
# the g2pW model takes Traditional Chinese input; for Simplified words it cannot cover, fall back to g2pM
print("[%s] not in g2pW dict,use g2pM" % seg)
pinyins = self.g2pM_model(seg, tone=True, char_split=False)
# do prosody
if self.use_rhy:
rhy_text = self.rhy_predictor.get_prediction(seg)
final_py = self.rhy_predictor.pinyin_align(pinyins,
rhy_text)
pinyins = final_py
pre_word_length = 0
for word, pos in seg_cut:
sub_initials = []
sub_finals = []
now_word_length = pre_word_length + len(word)
# skip english word
if pos == 'eng':
pre_word_length = now_word_length
continue
word_pinyins = pinyins[pre_word_length:now_word_length]
# correct the pronunciation
# polyphone disambiguation
word_pinyins = self.corrector.correct_pronunciation(
word, word_pinyins)
for pinyin, char in zip(word_pinyins, word):
if pinyin is None:
pinyin = char
pinyin = pinyin.replace("u:", "v")
if pinyin in self.pinyin2phone:
initial_final_list = self.pinyin2phone[
pinyin].split(" ")
@ -257,28 +345,41 @@ class Frontend():
# If it's not pinyin (possibly punctuation) or no conversion is required
sub_initials.append(pinyin)
sub_finals.append(pinyin)
pre_word_length = now_word_length
# tone sandhi
sub_finals = self.tone_modifier.modified_tone(word, pos,
sub_finals)
# er hua
if with_erhua:
sub_initials, sub_finals = self._merge_erhua(
sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
# assert len(sub_initials) == len(sub_finals) == len(word)
else:
# pypinyin, g2pM
for word, pos in seg_cut:
if pos == 'eng':
# skip english word
continue
# g2p
sub_initials, sub_finals = self._get_initials_finals(word)
# tone sandhi
sub_finals = self.tone_modifier.modified_tone(word, pos,
sub_finals)
# er hua
if with_erhua:
sub_initials, sub_finals = self._merge_erhua(
sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
# assert len(sub_initials) == len(sub_finals) == len(word)
# sum(iterable[, start])
initials = sum(initials, [])
finals = sum(finals, [])
@ -287,111 +388,34 @@ class Frontend():
# we discriminate i, ii and iii
if c and c not in self.punc:
phones.append(c)
# replace punctuation by `sp`
if c and c in self.punc:
phones.append('sp')
if v and v not in self.punc and v not in self.rhy_phns:
phones.append(v)
phones_list.append(phones)
if merge_sentences:
merge_list = sum(phones_list, [])
# remove the last 'sp' to avoid noise at the end,
# because there is no 'sp' at the end of the training data
if merge_list[-1] == 'sp':
merge_list = merge_list[:-1]
phones_list = []
phones_list.append(merge_list)
return phones_list
def _split_word_to_char(self, words):
res = []
for x in words:
res.append(x)
return res
# if using SSML with pinyin specified, assign that pinyin to the words
def _g2p_assign(self,
words: List[str],
pinyin_spec: List[str],
merge_sentences: bool=True) -> List[List[str]]:
phones_list = []
initials = []
finals = []
words = self._split_word_to_char(words[0])
for pinyin, char in zip(pinyin_spec, words):
sub_initials = []
sub_finals = []
pinyin = pinyin.replace("u:", "v")
# self.pinyin2phone: a dict mapping every pinyin syllable to its sheng_mu (initial) and yun_mu (final)
if pinyin in self.pinyin2phone:
initial_final_list = self.pinyin2phone[pinyin].split(" ")
if len(initial_final_list) == 2:
sub_initials.append(initial_final_list[0])
sub_finals.append(initial_final_list[1])
elif len(initial_final_list) == 1:
sub_initials.append('')
sub_finals.append(initial_final_list[1])
else:
# If it's not pinyin (possibly punctuation) or no conversion is required
sub_initials.append(pinyin)
sub_finals.append(pinyin)
initials.append(sub_initials)
finals.append(sub_finals)
phones_list.append(phones)
initials = sum(initials, [])
finals = sum(finals, [])
phones = []
for c, v in zip(initials, finals):
# NOTE: post process for pypinyin outputs
# we discriminate i, ii and iii
if c and c not in self.punc:
phones.append(c)
if c and c in self.punc:
phones.append('sp')
if v and v not in self.punc and v not in self.rhy_phns:
phones.append(v)
phones_list.append(phones)
# merge split sub-sentences into one sentence.
if merge_sentences:
# sub sentence phonemes
merge_list = sum(phones_list, [])
# remove the last 'sp' to avoid noise at the end,
# because there is no 'sp' at the end of the training data
if merge_list[-1] == 'sp':
merge_list = merge_list[:-1]
# sentence phonemes
phones_list = []
phones_list.append(merge_list)
return phones_list
def _merge_erhua(self,
initials: List[str],
finals: List[str],
word: str,
pos: str) -> List[List[str]]:
# fix er1
for i, phn in enumerate(finals):
if i == len(finals) - 1 and word[i] == "儿" and phn == 'er1':
finals[i] = 'er2'
if word not in self.must_erhua and (word in self.not_erhua or
pos in {"a", "j", "nr"}):
return initials, finals
# "……" 等情况直接返回
if len(finals) != len(word):
return initials, finals
assert len(finals) == len(word)
new_initials = []
new_finals = []
for i, phn in enumerate(finals):
if i == len(finals) - 1 and word[i] == "儿" and phn in {
"er2", "er5"
} and word[-2:] not in self.not_erhua and new_finals:
new_finals[-1] = new_finals[-1][:-1] + "r" + new_finals[-1][-1]
else:
new_finals.append(phn)
new_initials.append(initials[i])
return new_initials, new_finals
return phones_list
def _p2id(self, phonemes: List[str]) -> np.ndarray:
"""
Phoneme to Index
"""
# replace unk phone with sp
phonemes = [
phn if phn in self.vocab_phones else "sp" for phn in phonemes
@ -400,6 +424,9 @@ class Frontend():
return np.array(phone_ids, np.int64)
def _t2id(self, tones: List[str]) -> np.ndarray:
"""
Tone to Index.
"""
# replace unk tone with `0`
tones = [tone if tone in self.vocab_tones else "0" for tone in tones]
tone_ids = [self.vocab_tones[item] for item in tones]
@ -407,6 +434,9 @@ class Frontend():
def _get_phone_tone(self, phonemes: List[str],
get_tone_ids: bool=False) -> List[List[str]]:
"""
Separate tones from the tonal phonemes.
"""
phones = []
tones = []
if get_tone_ids and self.vocab_tones:
@ -423,13 +453,14 @@ class Frontend():
-1] == 'r' and phone not in self.vocab_phones and phone[:
-1] in self.vocab_phones:
phones.append(phone[:-1])
phones.append("er")
tones.append(tone)
phones.append("er")
tones.append("2")
else:
phones.append(phone)
tones.append(tone)
else:
# initials get tone 0.
phones.append(full_phone)
tones.append('0')
else:
@ -443,6 +474,7 @@ class Frontend():
phones.append("er2")
else:
phones.append(phone)
return phones, tones
def get_phonemes(self,
@ -451,10 +483,16 @@ class Frontend():
with_erhua: bool=True,
robot: bool=False,
print_info: bool=False) -> List[List[str]]:
"""
Main function to do G2P
"""
# TN & Text Segmentation
sentences = self.text_normalizer.normalize(sentence)
# Prosody & WS & g2p & tone sandhi
phonemes = self._g2p(
sentences, merge_sentences=merge_sentences, with_erhua=with_erhua)
# change all tones to `1`
# simulate robot pronunciation, change all tones to `1`
if robot:
new_phonemes = []
for sentence in phonemes:
@ -466,6 +504,7 @@ class Frontend():
new_sentence.append(item)
new_phonemes.append(new_sentence)
phonemes = new_phonemes
if print_info:
print("----------------------------")
print("text norm results:")
@ -476,25 +515,101 @@ class Frontend():
print("----------------------------")
return phonemes
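Putting the pipeline together (TN, word segmentation, g2p, tone sandhi, erhua), a minimal usage sketch with the pypinyin backend; the exact phones depend on the lexicon:

from paddlespeech.t2s.frontend.zh_frontend import Frontend  # assumed module path

frontend = Frontend(g2p_model='pypinyin')
phonemes = frontend.get_phonemes("你好。", merge_sentences=True, print_info=True)
print(phonemes)   # e.g. [['n', 'i3', 'h', 'ao3']]  (trailing 'sp' removed on merge)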
#@an added for ssml pinyin
def _split_word_to_char(self, words):
res = []
for x in words:
res.append(x)
return res
# if using SSML with pinyin specified, assign that pinyin to the words
def _g2p_assign(self,
words: List[str],
pinyin_spec: List[str],
merge_sentences: bool=True) -> List[List[str]]:
"""
Replace phonemes with the pinyin specified by the SSML <say-as> tag.
"""
phones_list = []
initials = []
finals = []
# to character list
words = self._split_word_to_char(words[0])
for pinyin, char in zip(pinyin_spec, words):
sub_initials = []
sub_finals = []
pinyin = pinyin.replace("u:", "v")
# self.pinyin2phone: a dict mapping every pinyin syllable to its sheng_mu (initial) and yun_mu (final)
if pinyin in self.pinyin2phone:
initial_final_list = self.pinyin2phone[pinyin].split(" ")
if len(initial_final_list) == 2:
sub_initials.append(initial_final_list[0])
sub_finals.append(initial_final_list[1])
elif len(initial_final_list) == 1:
sub_initials.append('')
sub_finals.append(initial_final_list[1])
else:
# If it's not pinyin (possibly punctuation) or no conversion is required
sub_initials.append(pinyin)
sub_finals.append(pinyin)
initials.append(sub_initials)
finals.append(sub_finals)
initials = sum(initials, [])
finals = sum(finals, [])
phones = []
for c, v in zip(initials, finals):
# NOTE: post process for pypinyin outputs
# we discriminate i, ii and iii
if c and c not in self.punc:
phones.append(c)
# replace punctuation with `sp`
if c and c in self.punc:
phones.append('sp')
if v and v not in self.punc and v not in self.rhy_phns:
phones.append(v)
phones_list.append(phones)
if merge_sentences:
merge_list = sum(phones_list, [])
# remove the last 'sp' to avoid noise at the end,
# because there is no 'sp' at the end of the training data
if merge_list[-1] == 'sp':
merge_list = merge_list[:-1]
phones_list = []
phones_list.append(merge_list)
return phones_list
def get_phonemes_ssml(self,
ssml_inputs: list,
merge_sentences: bool=True,
with_erhua: bool=True,
robot: bool=False,
print_info: bool=False) -> List[List[str]]:
"""
Main function to do G2P with SSML support.
"""
all_phonemes = []
for word_pinyin_item in ssml_inputs:
phonemes = []
print("ssml inputs:", word_pinyin_item)
sentence, pinyin_spec = itemgetter(0, 1)(word_pinyin_item)
print('ssml g2p:', sentence, pinyin_spec)
# TN & Text Segmentation
sentences = self.text_normalizer.normalize(sentence)
if len(pinyin_spec) == 0:
# g2p word w/o specified <say-as>
phonemes = self._g2p(
sentences,
merge_sentences=merge_sentences,
with_erhua=with_erhua)
else:
# phonemes should be pinyin_spec
# word phonemes specified by <say-as>
phonemes = self._g2p_assign(
sentences, pinyin_spec, merge_sentences=merge_sentences)
@ -523,6 +638,9 @@ class Frontend():
return [sum(all_phonemes, [])]
def add_sp_if_no(self, phonemes):
"""
Add prosody mark #4 (sp4) at the sentence end if it is missing.
"""
if not phonemes[-1][-1].startswith('sp'):
phonemes[-1].append('sp4')
return phonemes
@ -542,8 +660,11 @@ class Frontend():
merge_sentences=merge_sentences,
print_info=print_info,
robot=robot)
# add #4 for sentence end.
if self.use_rhy:
phonemes = self.add_sp_if_no(phonemes)
result = {}
phones = []
tones = []
@ -551,28 +672,33 @@ class Frontend():
temp_tone_ids = []
for part_phonemes in phonemes:
phones, tones = self._get_phone_tone(
part_phonemes, get_tone_ids=get_tone_ids)
if add_blank:
phones = insert_after_character(phones, blank_token)
if tones:
tone_ids = self._t2id(tones)
if to_tensor:
tone_ids = paddle.to_tensor(tone_ids)
temp_tone_ids.append(tone_ids)
if phones:
phone_ids = self._p2id(phones)
# when using paddle.to_tensor() with onnxruntime, the first call is too slow
if to_tensor:
phone_ids = paddle.to_tensor(phone_ids)
temp_phone_ids.append(phone_ids)
if temp_tone_ids:
result["tone_ids"] = temp_tone_ids
if temp_phone_ids:
result["phone_ids"] = temp_phone_ids
return result
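A usage sketch of get_input_ids; phone_id_map.txt is a hypothetical phone-vocab file (e.g. the one shipped with a fastspeech2 checkpoint), needed so phones can be mapped to ids:

from paddlespeech.t2s.frontend.zh_frontend import Frontend  # assumed module path

frontend = Frontend(g2p_model='pypinyin',
                    phone_vocab_path='phone_id_map.txt')    # hypothetical path
outs = frontend.get_input_ids("你好,欢迎使用语音合成。", merge_sentences=True)
phone_ids = outs['phone_ids']   # list of paddle.Tensor, one per (merged) sentence
print(len(phone_ids), phone_ids[0].shape)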
# @an added for ssml
def get_input_ids_ssml(
self,
sentence: str,
@ -584,12 +710,15 @@ class Frontend():
blank_token: str="<pad>",
to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
# split sentence by SSML tag.
l_inputs = MixTextProcessor.get_pinyin_split(sentence)
phonemes = self.get_phonemes_ssml(
l_inputs,
merge_sentences=merge_sentences,
print_info=print_info,
robot=robot)
result = {}
phones = []
tones = []
@ -599,21 +728,26 @@ class Frontend():
for part_phonemes in phonemes:
phones, tones = self._get_phone_tone(
part_phonemes, get_tone_ids=get_tone_ids)
if add_blank:
phones = insert_after_character(phones, blank_token)
if tones:
tone_ids = self._t2id(tones)
if to_tensor:
tone_ids = paddle.to_tensor(tone_ids)
temp_tone_ids.append(tone_ids)
if phones:
phone_ids = self._p2id(phones)
# when using paddle.to_tensor() with onnxruntime, the first call is too slow
if to_tensor:
phone_ids = paddle.to_tensor(phone_ids)
temp_phone_ids.append(phone_ids)
if temp_tone_ids:
result["tone_ids"] = temp_tone_ids
if temp_phone_ids:
result["phone_ids"] = temp_phone_ids
return result
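And the SSML path: the <say-as> tag pins the pinyin of 好 while the rest of the sentence goes through normal g2p; the vocab file is the same hypothetical one as above:

from paddlespeech.t2s.frontend.zh_frontend import Frontend  # assumed module path

frontend = Frontend(g2p_model='pypinyin',
                    phone_vocab_path='phone_id_map.txt')    # hypothetical path
text = "<speak>你<say-as pinyin='hao3'>好</say-as></speak>"
outs = frontend.get_input_ids_ssml(text, merge_sentences=True)
print(outs['phone_ids'])   # a list with one tensor of phone ids for the sentence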
