Merge pull request #2230 from BarryKCL/develop_g2pW

Add g2pW to Chinese frontend
TianYuan 2 years ago committed by GitHub
commit a75b2a5bab

@ -0,0 +1,2 @@
include paddlespeech/t2s/exps/*.txt
include paddlespeech/t2s/frontend/*.yaml

@ -699,6 +699,7 @@ You are warmly welcome to submit questions in [discussions](https://github.com/P
## Acknowledgement
- Many thanks to [BarryKCL](https://github.com/BarryKCL) for improving the TTS Chinese frontend based on [G2PW](https://github.com/GitYCC/g2pW).
- Many thanks to [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) for years of attention, constructive advice and great help.
- Many thanks to [mymagicpower](https://github.com/mymagicpower) for the Java implementation of ASR upon [short](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk) and [long](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk) audio files.
- Many thanks to [JiehangXie](https://github.com/JiehangXie)/[PaddleBoBo](https://github.com/JiehangXie/PaddleBoBo) for developing Virtual Uploader(VUP)/Virtual YouTuber(VTuber) with PaddleSpeech TTS function.

@ -833,6 +833,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
## Acknowledgement
- Many thanks to [BarryKCL](https://github.com/BarryKCL) for optimizing the TTS Chinese text frontend based on [G2PW](https://github.com/GitYCC/g2pW).
- Many thanks to [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) for years of attention and constructive suggestions, as well as help with many issues.
- Many thanks to [mymagicpower](https://github.com/mymagicpower) for the Java implementations of PaddleSpeech ASR on [short](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk) and [long](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk) audio files.
- Many thanks to [JiehangXie](https://github.com/JiehangXie)/[PaddleBoBo](https://github.com/JiehangXie/PaddleBoBo) for building Virtual Uploader (VUP)/Virtual YouTuber (VTuber) applications with the PaddleSpeech TTS module.

@ -19,6 +19,7 @@ loguru
matplotlib
nara_wpe
onnxruntime==1.10.0
opencc
pandas
paddlenlp
paddlespeech_feat

@ -40,3 +40,6 @@ We borrowed a lot of code from these repos to build `model` and `engine`, thanks
* [ThreadPool](https://github.com/progschj/ThreadPool/blob/master/COPYING)
- zlib License
- ThreadPool
* [g2pW](https://github.com/GitYCC/g2pW/blob/master/LICENCE)
- Apache-2.0 license

@ -1335,3 +1335,17 @@ kws_dynamic_pretrained_models = {
},
},
}
# ---------------------------------
# ------------- G2PW ---------------
# ---------------------------------
g2pw_onnx_models = {
'G2PWModel': {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/g2p/G2PWModel.tar',
'md5':
'63bc0894af15a5a591e58b2130a2bcac',
},
},
}
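For context, a minimal sketch of how this resource entry can be fetched; it mirrors the download_and_decompress(...) call made in onnx_api.py further down and assumes MODEL_HOME as the default cache directory.

# Sketch only: fetch and unpack the G2PW ONNX model described above.
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.resource.pretrained_models import g2pw_onnx_models
from paddlespeech.utils.env import MODEL_HOME

uncompress_path = download_and_decompress(g2pw_onnx_models['G2PWModel']['1.0'],
                                          MODEL_HOME)
print(uncompress_path)  # directory containing g2pW.onnx, config.py and the dictionaries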

@ -0,0 +1,2 @@
from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter
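A short usage sketch of the exported class; the constructor arguments match the ones used by zh_frontend.py below, and the printed output is illustrative only.

from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter

# style='pinyin' converts the model's bopomofo output to tone-numbered pinyin;
# the first call downloads G2PWModel into MODEL_HOME if it is not cached yet.
g2pw = G2PWOnnxConverter(style='pinyin', enable_non_tradional_chinese=True)
result = g2pw("我爱你")  # a str or a list of sentences is accepted
print(result)            # e.g. [['wo3', 'ai4', 'ni3']] (one per-character list per sentence)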

@ -0,0 +1,134 @@
"""
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
import numpy as np
from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map
ANCHOR_CHAR = '▁'
def prepare_onnx_input(tokenizer, labels, char2phonemes, chars, texts, query_ids,
                       phonemes=None, pos_tags=None, use_mask=False,
                       use_char_phoneme=False, use_pos=False, window_size=None,
                       max_len=512):
if window_size is not None:
truncated_texts, truncated_query_ids = _truncate_texts(window_size, texts, query_ids)
input_ids = []
token_type_ids = []
attention_masks = []
phoneme_masks = []
char_ids = []
position_ids = []
for idx in range(len(texts)):
text = (truncated_texts if window_size else texts)[idx].lower()
query_id = (truncated_query_ids if window_size else query_ids)[idx]
try:
tokens, text2token, token2text = tokenize_and_map(tokenizer, text)
except Exception:
print(f'warning: text "{text}" is invalid')
return {}
text, query_id, tokens, text2token, token2text = _truncate(max_len, text, query_id, tokens, text2token, token2text)
processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
query_char = text[query_id]
phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
if use_mask else [1] * len(labels)
char_id = chars.index(query_char)
        position_id = text2token[query_id] + 1  # +1 because the [CLS] token occupies the first position
input_ids.append(input_id)
token_type_ids.append(token_type_id)
attention_masks.append(attention_mask)
phoneme_masks.append(phoneme_mask)
char_ids.append(char_id)
position_ids.append(position_id)
outputs = {
'input_ids': np.array(input_ids),
'token_type_ids': np.array(token_type_ids),
'attention_masks': np.array(attention_masks),
'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
'char_ids': np.array(char_ids),
'position_ids': np.array(position_ids),
}
return outputs
def _truncate_texts(window_size, texts, query_ids):
truncated_texts = []
truncated_query_ids = []
for text, query_id in zip(texts, query_ids):
start = max(0, query_id - window_size // 2)
end = min(len(text), query_id + window_size // 2)
truncated_text = text[start:end]
truncated_texts.append(truncated_text)
truncated_query_id = query_id - start
truncated_query_ids.append(truncated_query_id)
return truncated_texts, truncated_query_ids
def _truncate(max_len, text, query_id, tokens, text2token, token2text):
truncate_len = max_len - 2
if len(tokens) <= truncate_len:
return (text, query_id, tokens, text2token, token2text)
token_position = text2token[query_id]
token_start = token_position - truncate_len // 2
token_end = token_start + truncate_len
    front_exceed_dist = -token_start
    back_exceed_dist = token_end - len(tokens)
    if front_exceed_dist > 0:
        token_start += front_exceed_dist
        token_end += front_exceed_dist
    elif back_exceed_dist > 0:
        token_start -= back_exceed_dist
        token_end -= back_exceed_dist
start = token2text[token_start][0]
end = token2text[token_end - 1][1]
return (
text[start:end],
query_id - start,
tokens[token_start:token_end],
[i - token_start if i is not None else None for i in text2token[start:end]],
[(s - start, e - start) for s, e in token2text[token_start:token_end]]
)
def prepare_data(sent_path, lb_path=None):
raw_texts = open(sent_path).read().rstrip().split('\n')
query_ids = [raw.index(ANCHOR_CHAR) for raw in raw_texts]
texts = [raw.replace(ANCHOR_CHAR, '') for raw in raw_texts]
if lb_path is None:
return texts, query_ids
else:
phonemes = open(lb_path).read().rstrip().split('\n')
return texts, query_ids, phonemes
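# Illustrative example of the input format expected by prepare_data: each line of
# sent_path contains one sentence with ANCHOR_CHAR ('▁') placed directly before the
# character to query, e.g. the line "机构看好股市前▁景" yields
#     texts     == ['机构看好股市前景']
#     query_ids == [7]
# and, when lb_path is given, the matching line holds that character's phoneme label.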
def get_phoneme_labels(polyphonic_chars):
labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
if char not in char2phonemes:
char2phonemes[char] = []
char2phonemes[char].append(labels.index(phoneme))
return labels, char2phonemes
def get_char_phoneme_labels(polyphonic_chars):
labels = sorted(list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
if char not in char2phonemes:
char2phonemes[char] = []
char2phonemes[char].append(labels.index(f'{char} {phoneme}'))
return labels, char2phonemes
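To make the returned structures concrete, a toy example with a made-up polyphonic table; the real table ships as POLYPHONIC_CHARS.txt inside the downloaded G2PWModel directory.

from paddlespeech.t2s.frontend.g2pw.dataset import (get_char_phoneme_labels,
                                                    get_phoneme_labels)

# Toy data for illustration only; the shipped table is much larger.
polyphonic_chars = [('行', 'xing2'), ('行', 'hang2'), ('长', 'chang2'), ('长', 'zhang3')]

labels, char2phonemes = get_phoneme_labels(polyphonic_chars)
print(labels)          # ['chang2', 'hang2', 'xing2', 'zhang3']  (sorted unique phonemes)
print(char2phonemes)   # {'行': [2, 1], '长': [0, 3]}            (indices into labels)

labels_cp, _ = get_char_phoneme_labels(polyphonic_chars)
print(labels_cp)       # ['行 hang2', '行 xing2', '长 chang2', '长 zhang3']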

@ -0,0 +1,143 @@
"""
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
import os
import json
import onnxruntime
import numpy as np
from opencc import OpenCC
from pypinyin import pinyin, lazy_pinyin, Style
from paddlenlp.transformers import BertTokenizer
from paddlespeech.utils.env import MODEL_HOME
from paddlespeech.t2s.frontend.g2pw.dataset import get_char_phoneme_labels
from paddlespeech.t2s.frontend.g2pw.dataset import get_phoneme_labels
from paddlespeech.t2s.frontend.g2pw.dataset import prepare_data
from paddlespeech.t2s.frontend.g2pw.dataset import prepare_onnx_input
from paddlespeech.t2s.frontend.g2pw.utils import load_config
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.resource.pretrained_models import g2pw_onnx_models
def predict(session, onnx_input, labels):
all_preds = []
all_confidences = []
    probs = session.run([], {
        "input_ids": onnx_input['input_ids'],
        "token_type_ids": onnx_input['token_type_ids'],
        "attention_mask": onnx_input['attention_masks'],
        "phoneme_mask": onnx_input['phoneme_masks'],
        "char_ids": onnx_input['char_ids'],
        "position_ids": onnx_input['position_ids'],
    })[0]
    preds = np.argmax(probs, axis=1).tolist()
    max_probs = []
    for index, arr in zip(preds, probs.tolist()):
        max_probs.append(arr[index])
    all_preds += [labels[pred] for pred in preds]
    all_confidences += max_probs
return all_preds, all_confidences
class G2PWOnnxConverter:
    def __init__(self,
                 model_dir=MODEL_HOME,
                 style='bopomofo',
                 model_source=None,
                 enable_non_tradional_chinese=False):
if not os.path.exists(os.path.join(model_dir, 'G2PWModel/g2pW.onnx')):
uncompress_path = download_and_decompress(g2pw_onnx_models['G2PWModel']['1.0'],model_dir)
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
sess_options.intra_op_num_threads = 2
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(model_dir, 'G2PWModel/g2pW.onnx'),sess_options=sess_options)
self.config = load_config(os.path.join(model_dir, 'G2PWModel/config.py'), use_default=True)
self.model_source = model_source if model_source else self.config.model_source
self.enable_opencc = enable_non_tradional_chinese
self.tokenizer = BertTokenizer.from_pretrained(self.config.model_source)
polyphonic_chars_path = os.path.join(model_dir, 'G2PWModel/POLYPHONIC_CHARS.txt')
monophonic_chars_path = os.path.join(model_dir, 'G2PWModel/MONOPHONIC_CHARS.txt')
self.polyphonic_chars = [line.split('\t') for line in open(polyphonic_chars_path,encoding='utf-8').read().strip().split('\n')]
self.monophonic_chars = [line.split('\t') for line in open(monophonic_chars_path,encoding='utf-8').read().strip().split('\n')]
self.labels, self.char2phonemes = get_char_phoneme_labels(self.polyphonic_chars) if self.config.use_char_phoneme else get_phoneme_labels(self.polyphonic_chars)
self.chars = sorted(list(self.char2phonemes.keys()))
self.pos_tags = ['UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI']
with open(os.path.join(model_dir,'G2PWModel/bopomofo_to_pinyin_wo_tune_dict.json'), 'r',encoding='utf-8') as fr:
self.bopomofo_convert_dict = json.load(fr)
self.style_convert_func = {
'bopomofo': lambda x: x,
'pinyin': self._convert_bopomofo_to_pinyin,
}[style]
with open(os.path.join(model_dir,'G2PWModel/char_bopomofo_dict.json'), 'r',encoding='utf-8') as fr:
self.char_bopomofo_dict = json.load(fr)
if self.enable_opencc:
self.cc = OpenCC('s2tw')
def _convert_bopomofo_to_pinyin(self, bopomofo):
tone = bopomofo[-1]
assert tone in '12345'
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
if component:
return component + tone
else:
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
return None
def __call__(self, sentences):
if isinstance(sentences, str):
sentences = [sentences]
if self.enable_opencc:
translated_sentences = []
for sent in sentences:
translated_sent = self.cc.convert(sent)
assert len(translated_sent) == len(sent)
translated_sentences.append(translated_sent)
sentences = translated_sentences
texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences)
if len(texts) == 0:
            # no polyphonic characters in the input; return the partial results directly
return partial_results
onnx_input = prepare_onnx_input(self.tokenizer, self.labels, self.char2phonemes, self.chars, texts, query_ids,
use_mask=self.config.use_mask, use_char_phoneme=self.config.use_char_phoneme,
window_size=None)
preds, confidences = predict(self.session_g2pW, onnx_input, self.labels)
if self.config.use_char_phoneme:
preds = [pred.split(' ')[1] for pred in preds]
results = partial_results
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
results[sent_id][query_id] = self.style_convert_func(pred)
return results
def _prepare_data(self, sentences):
polyphonic_chars = set(self.chars)
monophonic_chars_dict = {
char: phoneme for char, phoneme in self.monophonic_chars
}
texts, query_ids, sent_ids, partial_results = [], [], [], []
for sent_id, sent in enumerate(sentences):
pypinyin_result = pinyin(sent,style=Style.TONE3)
partial_result = [None] * len(sent)
for i, char in enumerate(sent):
if char in polyphonic_chars:
texts.append(sent)
query_ids.append(i)
sent_ids.append(sent_id)
elif char in monophonic_chars_dict:
partial_result[i] = self.style_convert_func(monophonic_chars_dict[char])
elif char in self.char_bopomofo_dict:
partial_result[i] = pypinyin_result[i][0]
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
partial_results.append(partial_result)
return texts, query_ids, sent_ids, partial_results
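A brief sketch of how the pieces above fit together at call time; the sentences and printed values are illustrative, not actual model output.

from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter

# Characters found in the polyphonic table are resolved by the ONNX model,
# monophonic characters come from MONOPHONIC_CHARS.txt, characters only present
# in char_bopomofo_dict fall back to pypinyin, and anything else (e.g. punctuation)
# is left as None in the per-character result list.
converter = G2PWOnnxConverter(style='pinyin', enable_non_tradional_chinese=True)
sentences = ["我们银行不行了。", "长城很长"]
for sent, pinyins in zip(sentences, converter(sentences)):
    for char, py in zip(sent, pinyins):
        print(char, py)  # e.g. ('行', 'hang2') ... ('。', None)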

@ -0,0 +1,133 @@
"""
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
import re
import sys
def wordize_and_map(text):
words = []
index_map_from_text_to_word = []
index_map_from_word_to_text = []
while len(text) > 0:
match_space = re.match(r'^ +', text)
if match_space:
space_str = match_space.group(0)
index_map_from_text_to_word += [None] * len(space_str)
text = text[len(space_str):]
continue
match_en = re.match(r'^[a-zA-Z0-9]+', text)
if match_en:
en_word = match_en.group(0)
word_start_pos = len(index_map_from_text_to_word)
word_end_pos = word_start_pos + len(en_word)
index_map_from_word_to_text.append((word_start_pos, word_end_pos))
index_map_from_text_to_word += [len(words)] * len(en_word)
words.append(en_word)
text = text[len(en_word):]
else:
word_start_pos = len(index_map_from_text_to_word)
word_end_pos = word_start_pos + 1
index_map_from_word_to_text.append((word_start_pos, word_end_pos))
index_map_from_text_to_word += [len(words)]
words.append(text[0])
text = text[1:]
return words, index_map_from_text_to_word, index_map_from_word_to_text
def tokenize_and_map(tokenizer, text):
words, text2word, word2text = wordize_and_map(text)
tokens = []
index_map_from_token_to_text = []
for word, (word_start, word_end) in zip(words, word2text):
word_tokens = tokenizer.tokenize(word)
if len(word_tokens) == 0 or word_tokens == ['[UNK]']:
index_map_from_token_to_text.append((word_start, word_end))
tokens.append('[UNK]')
else:
current_word_start = word_start
for word_token in word_tokens:
word_token_len = len(re.sub(r'^##', '', word_token))
index_map_from_token_to_text.append(
(current_word_start, current_word_start + word_token_len))
current_word_start = current_word_start + word_token_len
tokens.append(word_token)
index_map_from_text_to_token = text2word
for i, (token_start, token_end) in enumerate(index_map_from_token_to_text):
for token_pos in range(token_start, token_end):
index_map_from_text_to_token[token_pos] = i
return tokens, index_map_from_text_to_token, index_map_from_token_to_text
def _load_config(config_path):
import importlib.util
spec = importlib.util.spec_from_file_location('__init__', config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)
return config
default_config_dict = {
'manual_seed': 1313,
'model_source': 'bert-base-chinese',
'window_size': 32,
'num_workers': 2,
'use_mask': True,
'use_char_phoneme': False,
'use_conditional': True,
'param_conditional': {
'affect_location': 'softmax',
'bias': True,
'char-linear': True,
'pos-linear': False,
'char+pos-second': True,
'char+pos-second_lowrank': False,
'lowrank_size': 0,
'char+pos-second_fm': False,
'fm_size': 0,
'fix_mode': None,
'count_json': 'train.count.json'
},
'lr': 5e-5,
'val_interval': 200,
'num_iter': 10000,
'use_focal': False,
'param_focal': {
'alpha': 0.0,
'gamma': 0.7
},
'use_pos': True,
    'param_pos': {
'weight': 0.1,
'pos_joint_training': True,
'train_pos_path': 'train.pos',
'valid_pos_path': 'dev.pos',
'test_pos_path': 'test.pos'
}
}
def load_config(config_path, use_default=False):
config = _load_config(config_path)
if use_default:
for attr, val in default_config_dict.items():
if not hasattr(config, attr):
setattr(config, attr, val)
elif isinstance(val, dict):
d = getattr(config, attr)
for dict_k, dict_v in val.items():
if dict_k not in d:
d[dict_k] = dict_v
return config
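For reference, a small sketch of how load_config is used together with these defaults; the path matches the one used in onnx_api.py above and assumes the model archive has already been downloaded to MODEL_HOME.

import os

from paddlespeech.t2s.frontend.g2pw.utils import load_config
from paddlespeech.utils.env import MODEL_HOME

# Load the shipped config.py and fill in any attribute it does not define
# from default_config_dict (use_default=True).
config = load_config(os.path.join(MODEL_HOME, 'G2PWModel/config.py'), use_default=True)
print(config.model_source, config.use_mask, config.use_char_phoneme)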

@ -0,0 +1,26 @@
polyphonic:
湖泊: ['hu2','po1']
地壳: ['di4','qiao4']
柏树: ['bai3','shu4']
曝光: ['bao4','guang1']
弹力: ['tan2','li4']
字帖: ['zi4','tie4']
口吃: ['kou3','chi1']
包扎: ['bao1','za1']
哪吒: ['ne2','zha1']
说服: ['shuo1','fu2']
识字: ['shi2','zi4']
骨头: ['gu3','tou5']
对称: ['dui4','chen4']
口供: ['kou3','gong4']
抹布: ['ma1','bu4']
露背: ['lu4','bei4']
圈养: ['juan4', 'yang3']
眼眶: ['yan3', 'kuang4']
品行: ['pin3','xing2']
颤抖: ['chan4','dou3']
差不多: ['cha4','bu5','duo1']
鸭绿江: ['ya1','lu4','jiang1']
撒切尔: ['sa4','qie4','er3']
比比皆是: ['bi3','bi3','jie1','shi4']
身无长物: ['shen1','wu2','chang2','wu4']
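For illustration, how this table is consulted by the Polyphonic class added to zh_frontend.py below; the "wrong" input pinyin is just an example.

from paddlespeech.t2s.frontend.zh_frontend import Polyphonic

# Look up a corrected pronunciation; words not listed in the table pass through unchanged.
corrector = Polyphonic()
print(corrector.correct_pronunciation('地壳', ['di4', 'ke2']))   # ['di4', 'qiao4']
print(corrector.correct_pronunciation('你好', ['ni3', 'hao3']))  # ['ni3', 'hao3'] (not listed)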

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import os
import yaml
from typing import Dict
from typing import List
@ -25,6 +27,7 @@ from pypinyin import load_single_dict
from pypinyin import Style
from pypinyin_dict.phrase_pinyin_data import large_pinyin
from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter
from paddlespeech.t2s.frontend.generate_lexicon import generate_lexicon
from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi
from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer
@ -53,20 +56,42 @@ def insert_after_character(lst, item):
return result
class Polyphonic():
def __init__(self):
        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)),
                               'polyphonic.yaml'), 'r', encoding='utf-8') as polyphonic_file:
            # parse the yaml file
            polyphonic_dict = yaml.load(polyphonic_file, Loader=yaml.FullLoader)
        self.polyphonic_words = polyphonic_dict["polyphonic"]

    def correct_pronunciation(self, word, pinyin):
        # if the word is listed in the dictionary, return the corrected pronunciation
        if word in self.polyphonic_words:
            pinyin = self.polyphonic_words[word]
        # otherwise return the original pronunciation
        return pinyin
class Frontend():
def __init__(self,
g2p_model="pypinyin",
g2p_model="g2pW",
phone_vocab_path=None,
tone_vocab_path=None):
self.tone_modifier = ToneSandhi()
self.text_normalizer = TextNormalizer()
self.punc = ":,;。?!“”‘’':,;.?!"
# g2p_model can be pypinyin and g2pM
        # g2p_model can be pypinyin, g2pM or g2pW
self.g2p_model = g2p_model
if self.g2p_model == "g2pM":
self.g2pM_model = G2pM()
self.pinyin2phone = generate_lexicon(
with_tone=True, with_erhua=False)
elif self.g2p_model == "g2pW":
self.corrector = Polyphonic()
self.g2pM_model = G2pM()
self.g2pW_model = G2PWOnnxConverter(style='pinyin', enable_non_tradional_chinese=True)
self.pinyin2phone = generate_lexicon(
with_tone=True, with_erhua=False)
else:
self.__init__pypinyin()
self.must_erhua = {"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿"}
@ -156,18 +181,63 @@ class Frontend():
initials = []
finals = []
seg_cut = self.tone_modifier.pre_merge_for_modify(seg_cut)
for word, pos in seg_cut:
if pos == 'eng':
continue
sub_initials, sub_finals = self._get_initials_finals(word)
sub_finals = self.tone_modifier.modified_tone(word, pos,
sub_finals)
if with_erhua:
sub_initials, sub_finals = self._merge_erhua(
sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
# assert len(sub_initials) == len(sub_finals) == len(word)
            # to get better results for polyphonic characters, predict over the whole sentence here
            if self.g2p_model == "g2pW":
                try:
                    pinyins = self.g2pW_model(seg)[0]
                except Exception:
                    # the g2pW model expects traditional-Chinese input; fall back to
                    # g2pM for simplified words it cannot cover
                    print("[%s] not in g2pW dict, use g2pM" % seg)
                    pinyins = self.g2pM_model(seg, tone=True, char_split=False)
pre_word_length = 0
for word, pos in seg_cut:
sub_initials = []
sub_finals = []
now_word_length = pre_word_length + len(word)
if pos == 'eng':
pre_word_length = now_word_length
continue
word_pinyins = pinyins[pre_word_length:now_word_length]
                    # correct the pronunciation with the polyphonic table
                    word_pinyins = self.corrector.correct_pronunciation(word, word_pinyins)
                    for pinyin, char in zip(word_pinyins, word):
                        if pinyin is None:
                            pinyin = char
pinyin = pinyin.replace("u:", "v")
if pinyin in self.pinyin2phone:
initial_final_list = self.pinyin2phone[pinyin].split(" ")
if len(initial_final_list) == 2:
sub_initials.append(initial_final_list[0])
sub_finals.append(initial_final_list[1])
elif len(initial_final_list) == 1:
sub_initials.append('')
sub_finals.append(initial_final_list[1])
else:
# If it's not pinyin (possibly punctuation) or no conversion is required
sub_initials.append(pinyin)
sub_finals.append(pinyin)
pre_word_length = now_word_length
sub_finals = self.tone_modifier.modified_tone(word, pos,
sub_finals)
if with_erhua:
sub_initials, sub_finals = self._merge_erhua(
sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
# assert len(sub_initials) == len(sub_finals) == len(word)
else:
for word, pos in seg_cut:
if pos == 'eng':
continue
sub_initials, sub_finals = self._get_initials_finals(word)
sub_finals = self.tone_modifier.modified_tone(word, pos,
sub_finals)
if with_erhua:
sub_initials, sub_finals = self._merge_erhua(
sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
# assert len(sub_initials) == len(sub_finals) == len(word)
initials = sum(initials, [])
finals = sum(finals, [])
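Finally, a sketch of selecting the g2p backend with the patched constructor; only the arguments shown in this diff are used, and g2pW is now the default while the earlier backends remain selectable.

from paddlespeech.t2s.frontend.zh_frontend import Frontend

frontend_default = Frontend()                       # g2p_model="g2pW" is now the default
frontend_g2pm = Frontend(g2p_model="g2pM")          # previous model-based backend
frontend_pypinyin = Frontend(g2p_model="pypinyin")  # rule/dictionary backend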

@ -45,6 +45,7 @@ base = [
"matplotlib",
"nara_wpe",
"onnxruntime==1.10.0",
"opencc",
"pandas",
"paddlenlp",
"paddlespeech_feat",
@ -329,4 +330,4 @@ setup_info = dict(
})
with version_info():
setup(**setup_info)
setup(**setup_info, include_package_data=True)
