more comment on tts frontend

pull/3316/head
Hui Zhang 1 year ago
parent 8aa9790c75
commit 42f2186d71

@@ -99,14 +99,23 @@ def norm(data, mean, std):
     return (data - mean) / std
-def get_chunks(data, block_size: int, pad_size: int):
-    data_len = data.shape[1]
+def get_chunks(mel, chunk_size: int, pad_size: int):
+    """
+    Split mel spectrogram into chunks, with left and right context.
+    Args:
+        mel (paddle.Tensor): mel spectrogram, shape (B, T, D)
+        chunk_size (int): chunk size
+        pad_size (int): size of the left and right context
+    """
+    T = mel.shape[1]
+    n = math.ceil(T / chunk_size)
     chunks = []
-    n = math.ceil(data_len / block_size)
     for i in range(n):
-        start = max(0, i * block_size - pad_size)
-        end = min((i + 1) * block_size + pad_size, data_len)
-        chunks.append(data[:, start:end, :])
+        start = max(0, i * chunk_size - pad_size)
+        end = min((i + 1) * chunk_size + pad_size, T)
+        chunks.append(mel[:, start:end, :])
     return chunks
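For context, these chunks are meant to be synthesized one by one (e.g. by a streaming vocoder) and the overlap added by `pad_size` trimmed off before concatenation. A small self-contained sketch of that round trip (the trimming code below is illustrative, not part of this commit):

    import math
    import paddle

    def get_chunks(mel, chunk_size: int, pad_size: int):
        # same slicing as get_chunks above
        T = mel.shape[1]
        n = math.ceil(T / chunk_size)
        return [
            mel[:, max(0, i * chunk_size - pad_size):min((i + 1) * chunk_size + pad_size, T), :]
            for i in range(n)
        ]

    mel = paddle.randn([1, 250, 80])       # (B, T, D); values are arbitrary
    chunk_size, pad_size = 100, 12
    chunks = get_chunks(mel, chunk_size, pad_size)
    T = mel.shape[1]
    trimmed = []
    for i, chunk in enumerate(chunks):
        # drop the overlapping context so the cores concatenate back to T frames
        core_start = i * chunk_size - max(0, i * chunk_size - pad_size)
        core_len = min((i + 1) * chunk_size, T) - i * chunk_size
        trimmed.append(chunk[:, core_start:core_start + core_len, :])
    assert paddle.concat(trimmed, axis=1).shape[1] == T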
@@ -117,14 +126,10 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
     with open(text_file, 'rt', encoding='utf-8') as f:
         for line in f:
             if line.strip() != "":
-                items = re.split(r"\s+", line.strip(), 1)
-                assert len(items) == 2
+                items = re.split(r"\s+", line.strip(), maxsplit=1)
                 utt_id = items[0]
-                if lang in {'zh', 'canton'}:
-                    sentence = "".join(items[1:])
-                elif lang == 'en':
-                    sentence = " ".join(items[1:])
-                elif lang == 'mix':
-                    sentence = " ".join(items[1:])
+                sentence = items[1]
                 sentences.append((utt_id, sentence))
     return sentences
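For reference, each line of the text file is an utterance id, whitespace, then the sentence, so the simplified parsing above is language-independent. A tiny illustration (the line content is a made-up example):

    import re

    line = "001 你好,欢迎使用语音合成。"   # hypothetical line from --text
    utt_id, sentence = re.split(r"\s+", line.strip(), maxsplit=1)
    # utt_id == "001", sentence == "你好,欢迎使用语音合成。"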
@@ -319,6 +324,7 @@ def run_frontend(
     input_ids = {}
     if text.strip() != "" and re.match(r".*?<speak>.*?</speak>.*", text,
                                        re.DOTALL):
+        # using ssml
         input_ids = frontend.get_input_ids_ssml(
             text,
             merge_sentences=merge_sentences,
@@ -359,6 +365,7 @@ def run_frontend(
         outs.update({'is_slurs': is_slurs})
     else:
         print("lang should in {'zh', 'en', 'mix', 'canton', 'sing'}!")
+
     outs.update({'phone_ids': phone_ids})
     return outs
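A short illustration of the `<speak>` detection above; the SSML string is a made-up example, and the exact tags the frontend accepts are defined by MixTextProcessor, not here:

    import re

    text = "<speak>今天天气<say-as pinyin='hao3'>好</say-as></speak>"   # hypothetical SSML input
    is_ssml = text.strip() != "" and re.match(r".*?<speak>.*?</speak>.*", text, re.DOTALL)
    print(bool(is_ssml))   # True, so run_frontend would call frontend.get_input_ids_ssml()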

@@ -13,6 +13,7 @@
 # limitations under the License.
 import argparse
 from pathlib import Path
+from pprint import pprint
 import paddle
 import soundfile as sf
@@ -78,6 +79,7 @@ def evaluate(args):
     # whether dygraph to static
     if args.inference_dir:
+        print("convert am and voc to static model.")
         # acoustic model
         am_inference = am_to_static(
             am_inference=am_inference,
@@ -92,6 +94,7 @@ def evaluate(args):
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
+
     merge_sentences = False
     # Avoid not stopping at the end of a sub sentence when tacotron2_ljspeech dygraph to static graph
     # but still not stopping in the end (NOTE by yuantian01 Feb 9 2022)
@@ -102,12 +105,18 @@ def evaluate(args):
     if am_name == 'speedyspeech':
         get_tone_ids = True
+    # wav samples
     N = 0
+    # inference time cost
     T = 0
+    # [(uid, text), ]
     if am_name == 'diffsinger':
         sentences = get_sentences_svs(text_file=args.text)
     else:
         sentences = get_sentences(text_file=args.text, lang=args.lang)
+    pprint(f"inputs: {sentences}")
     for utt_id, sentence in sentences:
         with timer() as t:
             if am_name == "diffsinger":
@@ -116,6 +125,8 @@ def evaluate(args):
             else:
                 text = sentence
                 svs_input = None
+            # frontend
             frontend_dict = run_frontend(
                 frontend=frontend,
                 text=text,
@@ -124,25 +135,33 @@ def evaluate(args):
                 lang=args.lang,
                 svs_input=svs_input)
             phone_ids = frontend_dict['phone_ids']
+            # pprint(f"process: {utt_id} {phone_ids}")
             with paddle.no_grad():
                 flags = 0
                 for i in range(len(phone_ids)):
+                    # sub-sentence phones, split by `sp` or punctuation.
                     part_phone_ids = phone_ids[i]
                     # acoustic model
                     if am_name == 'fastspeech2':
                         # multi speaker
                         if am_dataset in {"aishell3", "vctk", "mix", "canton"}:
+                            # multi-speaker
                             spk_id = paddle.to_tensor(args.spk_id)
                             mel = am_inference(part_phone_ids, spk_id)
                         else:
+                            # single-speaker
                             mel = am_inference(part_phone_ids)
                     elif am_name == 'speedyspeech':
                         part_tone_ids = frontend_dict['tone_ids'][i]
                         if am_dataset in {"aishell3", "vctk", "mix"}:
+                            # multi-speaker
                             spk_id = paddle.to_tensor(args.spk_id)
                             mel = am_inference(part_phone_ids, part_tone_ids,
                                                spk_id)
                         else:
+                            # single-speaker
                             mel = am_inference(part_phone_ids, part_tone_ids)
                     elif am_name == 'tacotron2':
                         mel = am_inference(part_phone_ids)
@@ -155,6 +174,7 @@ def evaluate(args):
                             note=part_note_ids,
                             note_dur=part_note_durs,
                             is_slur=part_is_slurs, )
+
                     # vocoder
                     wav = voc_inference(mel)
                     if flags == 0:
@@ -162,17 +182,23 @@ def evaluate(args):
                         flags = 1
                     else:
                         wav_all = paddle.concat([wav_all, wav])
         wav = wav_all.numpy()
         N += wav.size
         T += t.elapse
+        # samples per second
         speed = wav.size / t.elapse
+        # RTF: seconds needed to generate one second of wav
         rtf = am_config.fs / speed
         print(
             f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
         )
         sf.write(
             str(output_dir / (utt_id + ".wav")), wav, samplerate=am_config.fs)
         print(f"{utt_id} done!")
     print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }")
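The two quantities printed above are samples generated per wall-clock second (Hz) and the real-time factor (seconds of compute needed per second of audio). A tiny worked example with made-up numbers:

    fs = 24000            # assumed vocoder output sample rate
    wav_size = 120000     # 5 s of audio
    elapse = 0.6          # seconds spent in am + voc inference
    speed = wav_size / elapse   # 200000 samples per second
    rtf = fs / speed            # 0.12: each second of audio costs 0.12 s to generate
    print(f"Hz: {speed}, RTF: {rtf}")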

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from paddlespeech.t2s.frontend.phonectic import Phonetics
 """
 A phonology system with ARPABET symbols and limited punctuations. The G2P
 conversion is done by g2p_en.
@@ -19,13 +18,23 @@ conversion is done by g2p_en.
 Note that g2p_en does not handle words with hypen well. So make sure the input
 sentence is first normalized.
 """
-from paddlespeech.t2s.frontend.vocab import Vocab
 from g2p_en import G2p
+
+from paddlespeech.t2s.frontend.phonectic import Phonetics
+from paddlespeech.t2s.frontend.vocab import Vocab


 class ARPABET(Phonetics):
-    """A phonology for English that uses ARPABET as the phoneme vocabulary.
-    See http://www.speech.cs.cmu.edu/cgi-bin/cmudict for more details.
+    """A phonology for English that uses ARPABET without stress as the phoneme vocabulary.
+
+    47 symbols = 39 phones + 4 punctuations + 4 special tokens (<pad> <unk> <s> </s>).
+
+    The underlying CMUdict phoneme set contains 39 phonemes; vowels carry a lexical
+    stress marker, which this class removes:
+        0    No stress
+        1    Primary stress
+        2    Secondary stress
+
+    Phoneme Set:
     Phoneme Example Translation
     ------- ------- -----------
     AA odd AA D
@@ -67,7 +76,10 @@ class ARPABET(Phonetics):
     Y yield Y IY L D
     Z zee Z IY
     ZH seizure S IY ZH ER
+
+    See http://www.speech.cs.cmu.edu/cgi-bin/cmudict for more details.
     """
+    # 39 phonemes
     phonemes = [
         'AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER',
         'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW',
@@ -76,6 +88,8 @@ class ARPABET(Phonetics):
     ]
     punctuations = [',', '.', '?', '!']
     symbols = phonemes + punctuations
+    # vowels carry a lexical stress marker:
+    # 0: unstressed, 1: primary stress, 2: secondary stress
     _stress_to_no_stress_ = {
         'AA0': 'AA',
         'AA1': 'AA',
@@ -124,7 +138,12 @@ class ARPABET(Phonetics):
         'UW2': 'UW'
     }

+    def __repr__(self):
+        fmt = "ARPABETWithoutStress(phonemes: {}, punctuations: {})"
+        return fmt.format(len(self.phonemes), self.punctuations)
+
     def __init__(self):
+        # https://github.com/Kyubyong/g2p/blob/master/g2p_en/g2p.py
         self.backend = G2p()
         self.vocab = Vocab(self.phonemes + self.punctuations)
@@ -139,6 +158,7 @@ class ARPABET(Phonetics):
         Returns:
             List[str]: The list of pronunciation sequence.
         """
+        # g2p and remove vowel stress
         phonemes = [
             self._remove_vowels(item) for item in self.backend(sentence)
         ]
@@ -158,6 +178,7 @@ class ARPABET(Phonetics):
         Returns:
             List[int]: The list of pronunciation id sequence.
         """
+        # phonemes to ids
         ids = [self.vocab.lookup(item) for item in phonemes]
         return ids
@@ -189,11 +210,16 @@ class ARPABET(Phonetics):
     def vocab_size(self):
         """ Vocab size.
         """
-        # 47 = 39 phones + 4 punctuations + 4 special tokens
+        # 47 = 39 phones + 4 punctuations + 4 special tokens (<pad> <unk> <s> </s>)
         return len(self.vocab)


 class ARPABETWithStress(Phonetics):
+    """
+    A phonology for English that uses ARPABET with stress as the phoneme vocabulary.
+
+    77 symbols = 69 phones + 4 punctuations + 4 special tokens.
+    """
     phonemes = [
         'AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2', 'AO0',
         'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2', 'B', 'CH', 'D',
@@ -206,6 +232,10 @@ class ARPABETWithStress(Phonetics):
     punctuations = [',', '.', '?', '!']
     symbols = phonemes + punctuations

+    def __repr__(self):
+        fmt = "ARPABETWithStress(phonemes: {}, punctuations: {})"
+        return fmt.format(len(self.phonemes), self.punctuations)
+
     def __init__(self):
         self.backend = G2p()
         self.vocab = Vocab(self.phonemes + self.punctuations)
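A minimal sketch of what the stress-stripping path does, assuming the g2p_en package is installed (the phone sequence in the comment is an example of typical g2p_en output, not a guaranteed result):

    from g2p_en import G2p

    g2p = G2p()
    raw = g2p("How are you?")          # e.g. ['HH', 'AW1', ' ', 'AA1', 'R', ' ', 'Y', 'UW1', ' ', '?']
    no_stress = [p[:-1] if p and p[-1] in "012" else p for p in raw]
    print(no_stress)                   # stress digits removed: 'AW1' -> 'AW', 'UW1' -> 'UW', ...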

@@ -48,3 +48,4 @@ polyphonic:
   : ['ai4']
   扎实: ['zha1','shi2']
   干将: ['gan4','jiang4']
+  陈威行: ['chen2', 'wei1', 'hang2']

@@ -97,6 +97,7 @@ class MixTextProcessor():
             ctlist.append(mixstr)
         return ctlist

+
 class DomXml():
     def __init__(self, xmlstr):
         self.tdom = parseString(xmlstr)  # Document

@@ -20,6 +20,9 @@ from pypinyin import Style


 class ToneSandhi():
+    def __repr__(self):
+        return "MandarinToneSandhi"
+
     def __init__(self):
         self.must_neural_tone_words = {
             '麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝',
@@ -69,6 +72,19 @@ class ToneSandhi():
         }
         self.punc = ":,;。?!“”‘’':,;.?!"

+    def _split_word(self, word: str) -> List[str]:
+        word_list = jieba.cut_for_search(word)
+        word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
+        first_subword = word_list[0]
+        first_begin_idx = word.find(first_subword)
+        if first_begin_idx == 0:
+            second_subword = word[len(first_subword):]
+            new_word_list = [first_subword, second_subword]
+        else:
+            second_subword = word[:-len(first_subword)]
+            new_word_list = [second_subword, first_subword]
+        return new_word_list
+
     # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
     # e.g.
     # word: "家里"
@@ -154,18 +170,8 @@ class ToneSandhi():
                 finals[i] = finals[i][:-1] + "4"
         return finals

-    def _split_word(self, word: str) -> List[str]:
-        word_list = jieba.cut_for_search(word)
-        word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
-        first_subword = word_list[0]
-        first_begin_idx = word.find(first_subword)
-        if first_begin_idx == 0:
-            second_subword = word[len(first_subword):]
-            new_word_list = [first_subword, second_subword]
-        else:
-            second_subword = word[:-len(first_subword)]
-            new_word_list = [second_subword, first_subword]
-        return new_word_list
+    def _all_tone_three(self, finals: List[str]) -> bool:
+        return all(x[-1] == "3" for x in finals)

     def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
@@ -207,9 +213,6 @@ class ToneSandhi():
         return finals

-    def _all_tone_three(self, finals: List[str]) -> bool:
-        return all(x[-1] == "3" for x in finals)
-
     # merge "不" and the word behind it
     # if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error
     def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
@@ -336,6 +339,9 @@ class ToneSandhi():
     def pre_merge_for_modify(
             self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
+        """
+        seg: [(word, pos), ...]
+        """
         seg = self._merge_bu(seg)
         seg = self._merge_yi(seg)
         seg = self._merge_reduplication(seg)
@@ -346,7 +352,11 @@ class ToneSandhi():
     def modified_tone(self, word: str, pos: str,
                       finals: List[str]) -> List[str]:
+        """
+        word: the word produced by word segmentation
+        pos: part-of-speech tag of the word
+        finals: toned finals, [final1, ..., finaln]
+        """
         finals = self._bu_sandhi(word, finals)
         finals = self._yi_sandhi(word, finals)
         finals = self._neural_sandhi(word, pos, finals)
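The relocated `_split_word` helper does a two-way split of a word using jieba's search-mode segmentation, keeping the shorter sub-word on the side where jieba found it; the parts are later checked separately for third-tone sandhi. A standalone sketch of the same logic, assuming the jieba package is installed (the example word and its split depend on jieba's dictionary):

    import jieba

    def split_word(word: str):
        # mirror of ToneSandhi._split_word
        word_list = sorted(jieba.cut_for_search(word), key=len)
        first_subword = word_list[0]
        if word.find(first_subword) == 0:
            return [first_subword, word[len(first_subword):]]
        return [word[:-len(first_subword)], first_subword]

    print(split_word("蒙古包"))   # e.g. ['蒙古', '包']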

@@ -31,9 +31,9 @@ from pypinyin_dict.phrase_pinyin_data import large_pinyin
 from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter
 from paddlespeech.t2s.frontend.generate_lexicon import generate_lexicon
 from paddlespeech.t2s.frontend.rhy_prediction.rhy_predictor import RhyPredictor
+from paddlespeech.t2s.frontend.ssml.xml_processor import MixTextProcessor
 from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi
 from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer
-from paddlespeech.t2s.ssml.xml_processor import MixTextProcessor

 INITIALS = [
     'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'zh', 'ch', 'sh',
@@ -49,13 +49,18 @@ def intersperse(lst, item):
 def insert_after_character(lst, item):
+    """
+    Insert `item` after finals.
+    """
     result = [item]
     for phone in lst:
         result.append(phone)
         if phone not in INITIALS:
             # finals has tones
             # assert phone[-1] in "12345"
             result.append(item)
     return result
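For intuition, a quick sketch of what insert_after_character produces when a blank token is requested (INITIALS is shortened here; the real module lists all Mandarin initials):

    INITIALS = ['n', 'h']                      # abbreviated for the example

    def insert_after_character(lst, item):
        result = [item]
        for phone in lst:
            result.append(phone)
            if phone not in INITIALS:          # finals (and 'sp') get the blank after them
                result.append(item)
        return result

    print(insert_after_character(['n', 'i3', 'h', 'ao3'], '<pad>'))
    # ['<pad>', 'n', 'i3', '<pad>', 'h', 'ao3', '<pad>']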
@@ -85,9 +90,7 @@ class Frontend():
                  phone_vocab_path=None,
                  tone_vocab_path=None,
                  use_rhy=False):
-        self.mix_ssml_processor = MixTextProcessor()
-        self.tone_modifier = ToneSandhi()
-        self.text_normalizer = TextNormalizer()
         self.punc = ":,;。?!“”‘’':,;.?!"
         self.rhy_phns = ['sp1', 'sp2', 'sp3', 'sp4']
         self.phrases_dict = {
@@ -108,28 +111,7 @@ class Frontend():
             '': [['lei5']],
             '掺和': [['chan1'], ['huo5']]
         }
-        self.use_rhy = use_rhy
-        if use_rhy:
-            self.rhy_predictor = RhyPredictor()
-            print("Rhythm predictor loaded.")
-        # g2p_model can be pypinyin and g2pM and g2pW
-        self.g2p_model = g2p_model
-        if self.g2p_model == "g2pM":
-            self.g2pM_model = G2pM()
-            self.pinyin2phone = generate_lexicon(
-                with_tone=True, with_erhua=False)
-        elif self.g2p_model == "g2pW":
-            # use pypinyin as backup for non polyphonic characters in g2pW
-            self._init_pypinyin()
-            self.corrector = Polyphonic()
-            self.g2pM_model = G2pM()
-            self.g2pW_model = G2PWOnnxConverter(
-                style='pinyin', enable_non_tradional_chinese=True)
-            self.pinyin2phone = generate_lexicon(
-                with_tone=True, with_erhua=False)
-        else:
-            self._init_pypinyin()
         self.must_erhua = {
             "小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"
         }
@@ -154,13 +136,51 @@ class Frontend():
             for tone, id in tone_id:
                 self.vocab_tones[tone] = int(id)

+        # SSML
+        self.mix_ssml_processor = MixTextProcessor()
+        # tone sandhi
+        self.tone_modifier = ToneSandhi()
+        # TN
+        self.text_normalizer = TextNormalizer()
+
+        # prosody
+        self.use_rhy = use_rhy
+        if use_rhy:
+            self.rhy_predictor = RhyPredictor()
+            print("Rhythm predictor loaded.")
+
+        # g2p
+        assert g2p_model in ('pypinyin', 'g2pM', 'g2pW')
+        self.g2p_model = g2p_model
+        if self.g2p_model == "g2pM":
+            self.g2pM_model = G2pM()
+            self.pinyin2phone = generate_lexicon(
+                with_tone=True, with_erhua=False)
+        elif self.g2p_model == "g2pW":
+            # use pypinyin as backup for non polyphonic characters in g2pW
+            self._init_pypinyin()
+            self.corrector = Polyphonic()
+            self.g2pM_model = G2pM()
+            self.g2pW_model = G2PWOnnxConverter(
+                style='pinyin', enable_non_tradional_chinese=True)
+            self.pinyin2phone = generate_lexicon(
+                with_tone=True, with_erhua=False)
+        else:
+            self._init_pypinyin()
+
     def _init_pypinyin(self):
+        """
+        Load the pypinyin G2P module.
+        """
         large_pinyin.load()
         load_phrases_dict(self.phrases_dict)
         # 调整字的拼音顺序
         load_single_dict({ord(u''): u'de,di4'})

     def _get_initials_finals(self, word: str) -> List[List[str]]:
+        """
+        Get the initials and finals of a word by pypinyin or g2pM.
+        """
         initials = []
         finals = []
         if self.g2p_model == "pypinyin":
@@ -171,11 +191,14 @@ class Frontend():
             for c, v in zip(orig_initials, orig_finals):
                 if re.match(r'i\d', v):
                     if c in ['z', 'c', 's']:
+                        # zi, ci, si
                         v = re.sub('i', 'ii', v)
                     elif c in ['zh', 'ch', 'sh', 'r']:
+                        # zhi, chi, shi
                         v = re.sub('i', 'iii', v)
                 initials.append(c)
                 finals.append(v)
         elif self.g2p_model == "g2pM":
             pinyins = self.g2pM_model(word, tone=True, char_split=False)
             for pinyin in pinyins:
@@ -192,58 +215,123 @@ class Frontend():
                     # If it's not pinyin (possibly punctuation) or no conversion is required
                     initials.append(pinyin)
                     finals.append(pinyin)
         return initials, finals

+    def _merge_erhua(self,
+                     initials: List[str],
+                     finals: List[str],
+                     word: str,
+                     pos: str) -> List[List[str]]:
+        """
+        Do erhua (儿化).
+        """
+        # fix er1
+        for i, phn in enumerate(finals):
+            if i == len(finals) - 1 and word[i] == "儿" and phn == 'er1':
+                finals[i] = 'er2'
+        # 儿 is pronounced as its own syllable: no erhua merge
+        if word not in self.must_erhua and (word in self.not_erhua or
+                                            pos in {"a", "j", "nr"}):
+            return initials, finals
+        # return directly for cases like "……"
+        if len(finals) != len(word):
+            return initials, finals
+        assert len(finals) == len(word)
+        # 儿 is not pronounced separately: merge "r" into the preceding final
+        new_initials = []
+        new_finals = []
+        for i, phn in enumerate(finals):
+            if i == len(finals) - 1 and word[i] == "儿" and phn in {
+                    "er2", "er5"
+            } and word[-2:] not in self.not_erhua and new_finals:
+                new_finals[-1] = new_finals[-1][:-1] + "r" + new_finals[-1][-1]
+            else:
+                new_initials.append(initials[i])
+                new_finals.append(phn)
+        return new_initials, new_finals
+
     # if merge_sentences, merge all sentences into one phone sequence
     def _g2p(self,
              sentences: List[str],
             merge_sentences: bool=True,
             with_erhua: bool=True) -> List[List[str]]:
+        """
+        Return: list of list of phonemes,
+            e.g. [['w', 'o3', 'm', 'en2', 'sp'], ...]
+        """
         segments = sentences
         phones_list = []
+        # split by punctuation
         for seg in segments:
             if self.use_rhy:
                 seg = self.rhy_predictor._clean_text(seg)
-            phones = []
-            # Replace all English words in the sentence
+            # remove all English words in the sentence
             seg = re.sub('[a-zA-Z]+', '', seg)
+            # add prosody mark
             if self.use_rhy:
                 seg = self.rhy_predictor.get_prediction(seg)
+            # [(word, pos), ...]
             seg_cut = psg.lcut(seg)
-            initials = []
-            finals = []
+            # fix bad word-segmentation cases for tone sandhi
             seg_cut = self.tone_modifier.pre_merge_for_modify(seg_cut)
             # 为了多音词获得更好的效果,这里采用整句预测
+            phones = []
+            initials = []
+            finals = []
             if self.g2p_model == "g2pW":
                 try:
+                    # undo prosody
                     if self.use_rhy:
                         seg = self.rhy_predictor._clean_text(seg)
+                    # g2p
                     pinyins = self.g2pW_model(seg)[0]
                 except Exception:
-                    # g2pW采用模型采用繁体输入如果有cover不了的简体词采用g2pM预测
+                    # g2pW takes traditional-Chinese input; fall back to g2pM for simplified words it cannot cover
                     print("[%s] not in g2pW dict,use g2pM" % seg)
                     pinyins = self.g2pM_model(seg, tone=True, char_split=False)
+                # do prosody
                 if self.use_rhy:
                     rhy_text = self.rhy_predictor.get_prediction(seg)
                     final_py = self.rhy_predictor.pinyin_align(pinyins,
                                                                rhy_text)
                     pinyins = final_py
                 pre_word_length = 0
                 for word, pos in seg_cut:
                     sub_initials = []
                     sub_finals = []
                     now_word_length = pre_word_length + len(word)
+                    # skip english word
                     if pos == 'eng':
                         pre_word_length = now_word_length
                         continue
                     word_pinyins = pinyins[pre_word_length:now_word_length]
-                    # 矫正发音
+                    # polyphone disambiguation
                     word_pinyins = self.corrector.correct_pronunciation(
                         word, word_pinyins)
                     for pinyin, char in zip(word_pinyins, word):
                         if pinyin is None:
                             pinyin = char
                         pinyin = pinyin.replace("u:", "v")
                         if pinyin in self.pinyin2phone:
                             initial_final_list = self.pinyin2phone[
                                 pinyin].split(" ")
@@ -257,28 +345,41 @@ class Frontend():
                             # If it's not pinyin (possibly punctuation) or no conversion is required
                             sub_initials.append(pinyin)
                             sub_finals.append(pinyin)
                     pre_word_length = now_word_length
+                    # tone sandhi
                     sub_finals = self.tone_modifier.modified_tone(word, pos,
                                                                   sub_finals)
+                    # er hua
                     if with_erhua:
                         sub_initials, sub_finals = self._merge_erhua(
                             sub_initials, sub_finals, word, pos)
                     initials.append(sub_initials)
                     finals.append(sub_finals)
                     # assert len(sub_initials) == len(sub_finals) == len(word)
             else:
+                # pypinyin, g2pM
                 for word, pos in seg_cut:
                     if pos == 'eng':
+                        # skip english word
                         continue
+                    # g2p
                     sub_initials, sub_finals = self._get_initials_finals(word)
+                    # tone sandhi
                     sub_finals = self.tone_modifier.modified_tone(word, pos,
                                                                   sub_finals)
+                    # er hua
                     if with_erhua:
                         sub_initials, sub_finals = self._merge_erhua(
                             sub_initials, sub_finals, word, pos)
                     initials.append(sub_initials)
                     finals.append(sub_finals)
                     # assert len(sub_initials) == len(sub_finals) == len(word)
+            # sum(iterable[, start])
             initials = sum(initials, [])
             finals = sum(finals, [])
@@ -287,111 +388,34 @@ class Frontend():
                 # we discriminate i, ii and iii
                 if c and c not in self.punc:
                     phones.append(c)
+                # replace punctuation by `sp`
                 if c and c in self.punc:
                     phones.append('sp')
-                if v and v not in self.punc and v not in self.rhy_phns:
-                    phones.append(v)
-            phones_list.append(phones)
-            if merge_sentences:
-                merge_list = sum(phones_list, [])
-                # rm the last 'sp' to avoid the noise at the end
-                # cause in the training data, no 'sp' in the end
-                if merge_list[-1] == 'sp':
-                    merge_list = merge_list[:-1]
-                phones_list = []
-                phones_list.append(merge_list)
-        return phones_list
-
-    def _split_word_to_char(self, words):
-        res = []
-        for x in words:
-            res.append(x)
-        return res
-
-    # if using ssml, have pingyin specified, assign pinyin to words
-    def _g2p_assign(self,
-                    words: List[str],
-                    pinyin_spec: List[str],
-                    merge_sentences: bool=True) -> List[List[str]]:
-        phones_list = []
-        initials = []
-        finals = []
-        words = self._split_word_to_char(words[0])
-        for pinyin, char in zip(pinyin_spec, words):
-            sub_initials = []
-            sub_finals = []
-            pinyin = pinyin.replace("u:", "v")
-            #self.pinyin2phone: is a dict with all pinyin mapped with sheng_mu yun_mu
-            if pinyin in self.pinyin2phone:
-                initial_final_list = self.pinyin2phone[pinyin].split(" ")
-                if len(initial_final_list) == 2:
-                    sub_initials.append(initial_final_list[0])
-                    sub_finals.append(initial_final_list[1])
-                elif len(initial_final_list) == 1:
-                    sub_initials.append('')
-                    sub_finals.append(initial_final_list[1])
-            else:
-                # If it's not pinyin (possibly punctuation) or no conversion is required
-                sub_initials.append(pinyin)
-                sub_finals.append(pinyin)
-            initials.append(sub_initials)
-            finals.append(sub_finals)
-        initials = sum(initials, [])
-        finals = sum(finals, [])
-        phones = []
-        for c, v in zip(initials, finals):
-            # NOTE: post process for pypinyin outputs
-            # we discriminate i, ii and iii
-            if c and c not in self.punc:
-                phones.append(c)
-            if c and c in self.punc:
-                phones.append('sp')
                 if v and v not in self.punc and v not in self.rhy_phns:
                     phones.append(v)
             phones_list.append(phones)
+        # merge split sub sentence into one sentence.
         if merge_sentences:
+            # sub sentence phonemes
             merge_list = sum(phones_list, [])
             # rm the last 'sp' to avoid the noise at the end
             # cause in the training data, no 'sp' in the end
             if merge_list[-1] == 'sp':
                 merge_list = merge_list[:-1]
+            # sentence phonemes
             phones_list = []
             phones_list.append(merge_list)
-        return phones_list
-
-    def _merge_erhua(self,
-                     initials: List[str],
-                     finals: List[str],
-                     word: str,
-                     pos: str) -> List[List[str]]:
-        # fix er1
-        for i, phn in enumerate(finals):
-            if i == len(finals) - 1 and word[i] == "儿" and phn == 'er1':
-                finals[i] = 'er2'
-        if word not in self.must_erhua and (word in self.not_erhua or
-                                            pos in {"a", "j", "nr"}):
-            return initials, finals
-        # "……" 等情况直接返回
-        if len(finals) != len(word):
-            return initials, finals
-        assert len(finals) == len(word)
-        new_initials = []
-        new_finals = []
-        for i, phn in enumerate(finals):
-            if i == len(finals) - 1 and word[i] == "儿" and phn in {
-                    "er2", "er5"
-            } and word[-2:] not in self.not_erhua and new_finals:
-                new_finals[-1] = new_finals[-1][:-1] + "r" + new_finals[-1][-1]
-            else:
-                new_finals.append(phn)
-                new_initials.append(initials[i])
-        return new_initials, new_finals
+        return phones_list

     def _p2id(self, phonemes: List[str]) -> np.ndarray:
+        """
+        Phoneme to index.
+        """
         # replace unk phone with sp
         phonemes = [
             phn if phn in self.vocab_phones else "sp" for phn in phonemes
@@ -400,6 +424,9 @@ class Frontend():
         return np.array(phone_ids, np.int64)

     def _t2id(self, tones: List[str]) -> np.ndarray:
+        """
+        Tone to index.
+        """
         # replace unk phone with sp
         tones = [tone if tone in self.vocab_tones else "0" for tone in tones]
         tone_ids = [self.vocab_tones[item] for item in tones]
@@ -407,6 +434,9 @@ class Frontend():
     def _get_phone_tone(self, phonemes: List[str],
                         get_tone_ids: bool=False) -> List[List[str]]:
+        """
+        Get tones from toned phonemes.
+        """
         phones = []
         tones = []
         if get_tone_ids and self.vocab_tones:
@@ -423,13 +453,14 @@ class Frontend():
                         -1] == 'r' and phone not in self.vocab_phones and phone[:
                             -1] in self.vocab_phones:
                     phones.append(phone[:-1])
-                    phones.append("er")
                     tones.append(tone)
+                    phones.append("er")
                     tones.append("2")
                 else:
                     phones.append(phone)
                     tones.append(tone)
             else:
+                # initials with 0 tone.
                 phones.append(full_phone)
                 tones.append('0')
         else:
@@ -443,6 +474,7 @@ class Frontend():
                     phones.append("er2")
                 else:
                     phones.append(phone)
+
         return phones, tones

     def get_phonemes(self,
@@ -451,10 +483,16 @@ class Frontend():
                      with_erhua: bool=True,
                      robot: bool=False,
                      print_info: bool=False) -> List[List[str]]:
+        """
+        Main function to do G2P.
+        """
+        # TN & Text Segmentation
         sentences = self.text_normalizer.normalize(sentence)
+        # Prosody & WS & g2p & tone sandhi
         phonemes = self._g2p(
             sentences, merge_sentences=merge_sentences, with_erhua=with_erhua)
-        # change all tones to `1`
+        # simulate robot pronunciation, change all tones to `1`
         if robot:
             new_phonemes = []
             for sentence in phonemes:
@@ -466,6 +504,7 @@ class Frontend():
                     new_sentence.append(item)
                 new_phonemes.append(new_sentence)
             phonemes = new_phonemes
+
         if print_info:
             print("----------------------------")
             print("text norm results:")
@@ -476,25 +515,101 @@ class Frontend():
             print("----------------------------")
         return phonemes
-    #@an added for ssml pinyin
+    def _split_word_to_char(self, words):
+        res = []
+        for x in words:
+            res.append(x)
+        return res
+
+    # if using ssml, have pinyin specified, assign pinyin to words
+    def _g2p_assign(self,
+                    words: List[str],
+                    pinyin_spec: List[str],
+                    merge_sentences: bool=True) -> List[List[str]]:
+        """
+        Replace phonemes by the pinyin specified in SSML.
+        """
+        phones_list = []
+        initials = []
+        finals = []
+        # to character list
+        words = self._split_word_to_char(words[0])
+        for pinyin, char in zip(pinyin_spec, words):
+            sub_initials = []
+            sub_finals = []
+            pinyin = pinyin.replace("u:", "v")
+            # self.pinyin2phone: is a dict with all pinyin mapped with sheng_mu yun_mu
+            if pinyin in self.pinyin2phone:
+                initial_final_list = self.pinyin2phone[pinyin].split(" ")
+                if len(initial_final_list) == 2:
+                    sub_initials.append(initial_final_list[0])
+                    sub_finals.append(initial_final_list[1])
+                elif len(initial_final_list) == 1:
+                    sub_initials.append('')
+                    sub_finals.append(initial_final_list[1])
+            else:
+                # If it's not pinyin (possibly punctuation) or no conversion is required
+                sub_initials.append(pinyin)
+                sub_finals.append(pinyin)
+            initials.append(sub_initials)
+            finals.append(sub_finals)
+        initials = sum(initials, [])
+        finals = sum(finals, [])
+        phones = []
+        for c, v in zip(initials, finals):
+            # NOTE: post process for pypinyin outputs
+            # we discriminate i, ii and iii
+            if c and c not in self.punc:
+                phones.append(c)
+            # replace punctuation by `sp`
+            if c and c in self.punc:
+                phones.append('sp')
+            if v and v not in self.punc and v not in self.rhy_phns:
+                phones.append(v)
+        phones_list.append(phones)
+        if merge_sentences:
+            merge_list = sum(phones_list, [])
+            # rm the last 'sp' to avoid the noise at the end
+            # cause in the training data, no 'sp' in the end
+            if merge_list[-1] == 'sp':
+                merge_list = merge_list[:-1]
+            phones_list = []
+            phones_list.append(merge_list)
+        return phones_list

     def get_phonemes_ssml(self,
                           ssml_inputs: list,
                           merge_sentences: bool=True,
                           with_erhua: bool=True,
                           robot: bool=False,
                           print_info: bool=False) -> List[List[str]]:
+        """
+        Main function to do G2P with SSML support.
+        """
         all_phonemes = []
         for word_pinyin_item in ssml_inputs:
             phonemes = []
-            print("ssml inputs:", word_pinyin_item)
             sentence, pinyin_spec = itemgetter(0, 1)(word_pinyin_item)
+            print('ssml g2p:', sentence, pinyin_spec)
+            # TN & Text Segmentation
             sentences = self.text_normalizer.normalize(sentence)
             if len(pinyin_spec) == 0:
+                # g2p for words without a specified <say-as>
                 phonemes = self._g2p(
                     sentences,
                     merge_sentences=merge_sentences,
                     with_erhua=with_erhua)
             else:
-                # phonemes should be pinyin_spec
+                # word phonemes specified by <say-as>
                 phonemes = self._g2p_assign(
                     sentences, pinyin_spec, merge_sentences=merge_sentences)
@@ -523,6 +638,9 @@ class Frontend():
         return [sum(all_phonemes, [])]

     def add_sp_if_no(self, phonemes):
+        """
+        Prosody mark #4 (sp4) is appended at the sentence end if it is missing.
+        """
         if not phonemes[-1][-1].startswith('sp'):
             phonemes[-1].append('sp4')
         return phonemes
@@ -542,8 +660,11 @@ class Frontend():
             merge_sentences=merge_sentences,
             print_info=print_info,
             robot=robot)
+        # add #4 for sentence end.
         if self.use_rhy:
             phonemes = self.add_sp_if_no(phonemes)
         result = {}
         phones = []
         tones = []
@@ -551,28 +672,33 @@ class Frontend():
         temp_tone_ids = []
         for part_phonemes in phonemes:
             phones, tones = self._get_phone_tone(
                 part_phonemes, get_tone_ids=get_tone_ids)
             if add_blank:
                 phones = insert_after_character(phones, blank_token)
             if tones:
                 tone_ids = self._t2id(tones)
                 if to_tensor:
                     tone_ids = paddle.to_tensor(tone_ids)
                 temp_tone_ids.append(tone_ids)
             if phones:
                 phone_ids = self._p2id(phones)
                 # if use paddle.to_tensor() in onnxruntime, the first time will be too low
                 if to_tensor:
                     phone_ids = paddle.to_tensor(phone_ids)
                 temp_phone_ids.append(phone_ids)
         if temp_tone_ids:
             result["tone_ids"] = temp_tone_ids
         if temp_phone_ids:
             result["phone_ids"] = temp_phone_ids
         return result

+    # @an added for ssml
     def get_input_ids_ssml(
             self,
             sentence: str,
@@ -584,12 +710,15 @@ class Frontend():
             blank_token: str="<pad>",
             to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
+        # split sentence by SSML tag.
         l_inputs = MixTextProcessor.get_pinyin_split(sentence)
         phonemes = self.get_phonemes_ssml(
             l_inputs,
             merge_sentences=merge_sentences,
             print_info=print_info,
             robot=robot)
         result = {}
         phones = []
         tones = []
@@ -599,21 +728,26 @@ class Frontend():
         temp_tone_ids = []
         for part_phonemes in phonemes:
             phones, tones = self._get_phone_tone(
                 part_phonemes, get_tone_ids=get_tone_ids)
+
             if add_blank:
                 phones = insert_after_character(phones, blank_token)
+
             if tones:
                 tone_ids = self._t2id(tones)
                 if to_tensor:
                     tone_ids = paddle.to_tensor(tone_ids)
                 temp_tone_ids.append(tone_ids)
+
             if phones:
                 phone_ids = self._p2id(phones)
                 # if use paddle.to_tensor() in onnxruntime, the first time will be too low
                 if to_tensor:
                     phone_ids = paddle.to_tensor(phone_ids)
                 temp_phone_ids.append(phone_ids)
+
         if temp_tone_ids:
             result["tone_ids"] = temp_tone_ids
         if temp_phone_ids:
             result["phone_ids"] = temp_phone_ids
+
         return result
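Downstream, get_input_ids() (and get_input_ids_ssml() for text containing <speak>) is what the synthesize scripts call on this frontend. A minimal usage sketch, assuming a PaddleSpeech installation; get_phonemes() is shown because it does not need the acoustic model's phone vocabulary, whereas get_input_ids() additionally requires phone_vocab_path (and optionally tone_vocab_path) to map phonemes to ids:

    from paddlespeech.t2s.frontend.zh_frontend import Frontend

    frontend = Frontend(g2p_model="pypinyin")   # pypinyin backend avoids downloading g2pW models
    phonemes = frontend.get_phonemes(
        "你好,欢迎使用语音合成。", merge_sentences=True, print_info=False)
    print(phonemes)   # e.g. [['n', 'i2', 'h', 'ao3', 'sp', ...]] -- one merged phoneme sequence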
