Merge pull request #2250 from yt605155624/format_g2pw

[TTS]format g2pw
TianYuan 2 years ago committed by GitHub
commit 83e10fadd0

paddlespeech/t2s/frontend/g2pw/dataset.py
@@ -1,65 +1,95 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Credits
    This code is modified from https://github.com/GitYCC/g2pW
"""
import numpy as np
from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map
ANCHOR_CHAR = '▁'
def prepare_onnx_input(tokenizer, labels, char2phonemes, chars, texts, query_ids, phonemes=None, pos_tags=None,
                       use_mask=False, use_char_phoneme=False, use_pos=False, window_size=None, max_len=512):
    if window_size is not None:
        truncated_texts, truncated_query_ids = _truncate_texts(window_size, texts, query_ids)

    input_ids = []
    token_type_ids = []
    attention_masks = []
    phoneme_masks = []
    char_ids = []
    position_ids = []

    for idx in range(len(texts)):
        text = (truncated_texts if window_size else texts)[idx].lower()
        query_id = (truncated_query_ids if window_size else query_ids)[idx]

        try:
            tokens, text2token, token2text = tokenize_and_map(tokenizer, text)
        except Exception:
            print(f'warning: text "{text}" is invalid')
            return {}

        text, query_id, tokens, text2token, token2text = _truncate(max_len, text, query_id, tokens, text2token, token2text)

        processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
        input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
        token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
        attention_mask = list(np.ones((len(processed_tokens),), dtype=int))

        query_char = text[query_id]
        phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
            if use_mask else [1] * len(labels)
        char_id = chars.index(query_char)
        position_id = text2token[query_id] + 1  # the [CLS] token occupies the first position

        input_ids.append(input_id)
        token_type_ids.append(token_type_id)
        attention_masks.append(attention_mask)
        phoneme_masks.append(phoneme_mask)
        char_ids.append(char_id)
        position_ids.append(position_id)

    outputs = {
        'input_ids': np.array(input_ids),
        'token_type_ids': np.array(token_type_ids),
        'attention_masks': np.array(attention_masks),
        'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
        'char_ids': np.array(char_ids),
        'position_ids': np.array(position_ids),
    }
    return outputs
def prepare_onnx_input(tokenizer,
                       labels,
                       char2phonemes,
                       chars,
                       texts,
                       query_ids,
                       phonemes=None,
                       pos_tags=None,
                       use_mask=False,
                       use_char_phoneme=False,
                       use_pos=False,
                       window_size=None,
                       max_len=512):
    if window_size is not None:
        truncated_texts, truncated_query_ids = _truncate_texts(window_size,
                                                               texts, query_ids)

    input_ids = []
    token_type_ids = []
    attention_masks = []
    phoneme_masks = []
    char_ids = []
    position_ids = []

    for idx in range(len(texts)):
        text = (truncated_texts if window_size else texts)[idx].lower()
        query_id = (truncated_query_ids if window_size else query_ids)[idx]

        try:
            tokens, text2token, token2text = tokenize_and_map(tokenizer, text)
        except Exception:
            print(f'warning: text "{text}" is invalid')
            return {}

        text, query_id, tokens, text2token, token2text = _truncate(
            max_len, text, query_id, tokens, text2token, token2text)

        processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
        input_id = list(
            np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
        token_type_id = list(np.zeros((len(processed_tokens), ), dtype=int))
        attention_mask = list(np.ones((len(processed_tokens), ), dtype=int))

        query_char = text[query_id]
        phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
            if use_mask else [1] * len(labels)
        char_id = chars.index(query_char)
        position_id = text2token[
            query_id] + 1  # the [CLS] token occupies the first position

        input_ids.append(input_id)
        token_type_ids.append(token_type_id)
        attention_masks.append(attention_mask)
        phoneme_masks.append(phoneme_mask)
        char_ids.append(char_id)
        position_ids.append(position_id)

    outputs = {
        'input_ids': np.array(input_ids),
        'token_type_ids': np.array(token_type_ids),
        'attention_masks': np.array(attention_masks),
        'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
        'char_ids': np.array(char_ids),
        'position_ids': np.array(position_ids),
    }
    return outputs
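
For orientation, a minimal sketch of how prepare_onnx_input is typically driven. The tokenizer name, toy polyphonic table, sentence, and query index below are illustrative assumptions, not taken from this diff; labels and char2phonemes come from get_phoneme_labels as in G2PWOnnxConverter further down.

    # Hedged usage sketch (hypothetical inputs).
    from paddlenlp.transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    polyphonic_chars = [['长', 'chang2'], ['长', 'zhang3']]  # toy table
    labels, char2phonemes = get_phoneme_labels(polyphonic_chars)
    chars = sorted(list(char2phonemes.keys()))

    onnx_input = prepare_onnx_input(
        tokenizer, labels, char2phonemes, chars,
        texts=['这条路很长'],  # one polyphonic character: '长'
        query_ids=[4],         # its index within the sentence
        use_mask=True)
    # onnx_input['input_ids'] has one row per (text, query_id) pair.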
def _truncate_texts(window_size, texts, query_ids):
    truncated_texts = []
@@ -74,6 +104,7 @@ def _truncate_texts(window_size, texts, query_ids):
        truncated_query_ids.append(truncated_query_id)
    return truncated_texts, truncated_query_ids
def _truncate(max_len, text, query_id, tokens, text2token, token2text):
    truncate_len = max_len - 2
    if len(tokens) <= truncate_len:
@@ -95,13 +126,11 @@ def _truncate(max_len, text, query_id, tokens, text2token, token2text):
    start = token2text[token_start][0]
    end = token2text[token_end - 1][1]

    return (
        text[start:end],
        query_id - start,
        tokens[token_start:token_end],
        [i - token_start if i is not None else None for i in text2token[start:end]],
        [(s - start, e - start) for s, e in token2text[token_start:token_end]]
    )
    return (text[start:end], query_id - start, tokens[token_start:token_end], [
        i - token_start if i is not None else None
        for i in text2token[start:end]
    ], [(s - start, e - start) for s, e in token2text[token_start:token_end]])
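
To make the windowing arithmetic concrete, here is a toy trace. It assumes a tokenizer where every character maps to exactly one token (so text2token is the identity) and relies on the centering logic in the elided middle of _truncate behaving as in upstream g2pW.

    # Toy illustration: max_len=6 keeps max_len - 2 = 4 tokens, centered
    # on the query token; all offsets are then shifted so the returned
    # query_id indexes into the truncated text.
    text = 'abcdefgh'
    tokens = list(text)                   # 1:1 char-to-token assumed
    text2token = list(range(len(text)))   # identity mapping
    token2text = [(i, i + 1) for i in range(len(text))]

    _truncate(6, text, 5, tokens, text2token, token2text)
    # -> ('defg', 2, ['d', 'e', 'f', 'g'], [0, 1, 2, 3],
    #     [(0, 1), (1, 2), (2, 3), (3, 4)])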
def prepare_data(sent_path, lb_path=None):
    raw_texts = open(sent_path).read().rstrip().split('\n')
@@ -125,7 +154,8 @@ def get_phoneme_labels(polyphonic_chars):
def get_char_phoneme_labels(polyphonic_chars):
    labels = sorted(list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
    labels = sorted(
        list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
    char2phonemes = {}
    for char, phoneme in polyphonic_chars:
        if char not in char2phonemes:

paddlespeech/t2s/frontend/g2pw/onnx_api.py
@@ -1,38 +1,54 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Credits
    This code is modified from https://github.com/GitYCC/g2pW
"""
import os
import json
import onnxruntime
import numpy as np
import os
import numpy as np
import onnxruntime
from opencc import OpenCC
from pypinyin import pinyin, lazy_pinyin, Style
from paddlenlp.transformers import BertTokenizer
from paddlespeech.utils.env import MODEL_HOME
from paddlespeech.t2s.frontend.g2pw.dataset import prepare_data,\
                                                   prepare_onnx_input,\
                                                   get_phoneme_labels,\
                                                   get_char_phoneme_labels
from paddlespeech.t2s.frontend.g2pw.utils import load_config
from pypinyin import pinyin
from pypinyin import Style
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.resource.pretrained_models import g2pw_onnx_models
from paddlespeech.t2s.frontend.g2pw.dataset import get_char_phoneme_labels
from paddlespeech.t2s.frontend.g2pw.dataset import get_phoneme_labels
from paddlespeech.t2s.frontend.g2pw.dataset import prepare_onnx_input
from paddlespeech.t2s.frontend.g2pw.utils import load_config
from paddlespeech.utils.env import MODEL_HOME
def predict(session, onnx_input, labels):
    all_preds = []
    all_confidences = []

    probs = session.run([],{"input_ids": onnx_input['input_ids'],
                            "token_type_ids":onnx_input['token_type_ids'],
                            "attention_mask":onnx_input['attention_masks'],
                            "phoneme_mask":onnx_input['phoneme_masks'],
                            "char_ids":onnx_input['char_ids'],
                            "position_ids":onnx_input['position_ids']})[0]
    preds = np.argmax(probs,axis=1).tolist()
    probs = session.run([], {
        "input_ids": onnx_input['input_ids'],
        "token_type_ids": onnx_input['token_type_ids'],
        "attention_mask": onnx_input['attention_masks'],
        "phoneme_mask": onnx_input['phoneme_masks'],
        "char_ids": onnx_input['char_ids'],
        "position_ids": onnx_input['position_ids']
    })[0]
    preds = np.argmax(probs, axis=1).tolist()
    max_probs = []
    for index,arr in zip(preds,probs.tolist()):
    for index, arr in zip(preds, probs.tolist()):
        max_probs.append(arr[index])
    all_preds += [labels[pred] for pred in preds]
    all_confidences += max_probs
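
A hedged sketch of driving predict directly; the model path is a placeholder, onnx_input comes from prepare_onnx_input above, and the return statement of predict sits outside the hunk shown here (upstream it returns all_preds and all_confidences):

    # Assumed wiring; 'G2PWModel/g2pW.onnx' is wherever the model landed.
    import onnxruntime
    session = onnxruntime.InferenceSession('G2PWModel/g2pW.onnx')
    preds, confidences = predict(session, onnx_input, labels)
    # preds       -> best phoneme label string per query character
    # confidences -> the probability the model assigned to that label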
@@ -41,39 +57,69 @@ def predict(session, onnx_input, labels):
class G2PWOnnxConverter:
    def __init__(self, model_dir = MODEL_HOME, style='bopomofo', model_source=None, enable_non_tradional_chinese=False):
    def __init__(self,
                 model_dir=MODEL_HOME,
                 style='bopomofo',
                 model_source=None,
                 enable_non_tradional_chinese=False):
        if not os.path.exists(os.path.join(model_dir, 'G2PWModel/g2pW.onnx')):
            uncompress_path = download_and_decompress(g2pw_onnx_models['G2PWModel']['1.0'],model_dir)
            uncompress_path = download_and_decompress(
                g2pw_onnx_models['G2PWModel']['1.0'], model_dir)

        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
        sess_options.intra_op_num_threads = 2
        self.session_g2pW = onnxruntime.InferenceSession(os.path.join(model_dir, 'G2PWModel/g2pW.onnx'),sess_options=sess_options)
        self.config = load_config(os.path.join(model_dir, 'G2PWModel/config.py'), use_default=True)
        self.session_g2pW = onnxruntime.InferenceSession(
            os.path.join(model_dir, 'G2PWModel/g2pW.onnx'),
            sess_options=sess_options)
        self.config = load_config(
            os.path.join(model_dir, 'G2PWModel/config.py'), use_default=True)

        self.model_source = model_source if model_source else self.config.model_source
        self.enable_opencc = enable_non_tradional_chinese

        self.tokenizer = BertTokenizer.from_pretrained(self.config.model_source)

        polyphonic_chars_path = os.path.join(model_dir, 'G2PWModel/POLYPHONIC_CHARS.txt')
        monophonic_chars_path = os.path.join(model_dir, 'G2PWModel/MONOPHONIC_CHARS.txt')
        self.polyphonic_chars = [line.split('\t') for line in open(polyphonic_chars_path,encoding='utf-8').read().strip().split('\n')]
        self.monophonic_chars = [line.split('\t') for line in open(monophonic_chars_path,encoding='utf-8').read().strip().split('\n')]
        self.labels, self.char2phonemes = get_char_phoneme_labels(self.polyphonic_chars) if self.config.use_char_phoneme else get_phoneme_labels(self.polyphonic_chars)
        polyphonic_chars_path = os.path.join(model_dir,
                                             'G2PWModel/POLYPHONIC_CHARS.txt')
        monophonic_chars_path = os.path.join(model_dir,
                                             'G2PWModel/MONOPHONIC_CHARS.txt')
        self.polyphonic_chars = [
            line.split('\t')
            for line in open(polyphonic_chars_path, encoding='utf-8').read()
            .strip().split('\n')
        ]
        self.monophonic_chars = [
            line.split('\t')
            for line in open(monophonic_chars_path, encoding='utf-8').read()
            .strip().split('\n')
        ]
        self.labels, self.char2phonemes = get_char_phoneme_labels(
            self.polyphonic_chars
        ) if self.config.use_char_phoneme else get_phoneme_labels(
            self.polyphonic_chars)

        self.chars = sorted(list(self.char2phonemes.keys()))
        self.pos_tags = ['UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI']
        with open(os.path.join(model_dir,'G2PWModel/bopomofo_to_pinyin_wo_tune_dict.json'), 'r',encoding='utf-8') as fr:
        self.pos_tags = [
            'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
        ]

        with open(
                os.path.join(model_dir,
                             'G2PWModel/bopomofo_to_pinyin_wo_tune_dict.json'),
                'r',
                encoding='utf-8') as fr:
            self.bopomofo_convert_dict = json.load(fr)
        self.style_convert_func = {
            'bopomofo': lambda x: x,
            'pinyin': self._convert_bopomofo_to_pinyin,
        }[style]

        with open(os.path.join(model_dir,'G2PWModel/char_bopomofo_dict.json'), 'r',encoding='utf-8') as fr:
        with open(
                os.path.join(model_dir, 'G2PWModel/char_bopomofo_dict.json'),
                'r',
                encoding='utf-8') as fr:
            self.char_bopomofo_dict = json.load(fr)

        if self.enable_opencc:
@@ -100,15 +146,23 @@ class G2PWOnnxConverter:
            assert len(translated_sent) == len(sent)

            translated_sentences.append(translated_sent)
        sentences = translated_sentences

        texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences)
        texts, query_ids, sent_ids, partial_results = self._prepare_data(
            sentences)
        if len(texts) == 0:
            # the sentences contain no polyphonic characters
            return partial_results

        onnx_input = prepare_onnx_input(self.tokenizer, self.labels, self.char2phonemes, self.chars, texts, query_ids,
                                        use_mask=self.config.use_mask, use_char_phoneme=self.config.use_char_phoneme,
                                        window_size=None)
        onnx_input = prepare_onnx_input(
            self.tokenizer,
            self.labels,
            self.char2phonemes,
            self.chars,
            texts,
            query_ids,
            use_mask=self.config.use_mask,
            use_char_phoneme=self.config.use_char_phoneme,
            window_size=None)

        preds, confidences = predict(self.session_g2pW, onnx_input, self.labels)
        if self.config.use_char_phoneme:
@@ -123,11 +177,12 @@ class G2PWOnnxConverter:
    def _prepare_data(self, sentences):
        polyphonic_chars = set(self.chars)
        monophonic_chars_dict = {
            char: phoneme for char, phoneme in self.monophonic_chars
            char: phoneme
            for char, phoneme in self.monophonic_chars
        }
        texts, query_ids, sent_ids, partial_results = [], [], [], []
        for sent_id, sent in enumerate(sentences):
            pypinyin_result = pinyin(sent,style=Style.TONE3)
            pypinyin_result = pinyin(sent, style=Style.TONE3)
            partial_result = [None] * len(sent)
            for i, char in enumerate(sent):
                if char in polyphonic_chars:
@@ -135,9 +190,10 @@ class G2PWOnnxConverter:
                    query_ids.append(i)
                    sent_ids.append(sent_id)
                elif char in monophonic_chars_dict:
                    partial_result[i] = self.style_convert_func(monophonic_chars_dict[char])
                    partial_result[i] = self.style_convert_func(
                        monophonic_chars_dict[char])
                elif char in self.char_bopomofo_dict:
                    partial_result[i] = pypinyin_result[i][0]
                    partial_result[i] = pypinyin_result[i][0]
                    # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
            partial_results.append(partial_result)
        return texts, query_ids, sent_ids, partial_results
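
Putting the converter together, a minimal end-to-end sketch. The instantiation mirrors how zh_frontend.py uses it below; the sentence is illustrative, and model files are fetched on first use via download_and_decompress:

    # Hedged usage sketch of the ONNX converter.
    conv = G2PWOnnxConverter(style='pinyin', enable_non_tradional_chinese=True)
    pinyins = conv('先生长大后长高了')[0]
    # One pinyin string per input character; entries the model leaves
    # unresolved stay None, as _prepare_data pre-fills them.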

paddlespeech/t2s/frontend/g2pw/utils.py
@@ -1,10 +1,22 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Credits
    This code is modified from https://github.com/GitYCC/g2pW
"""
import re
import sys
def wordize_and_map(text):
    words = []
@@ -92,7 +104,6 @@ default_config_dict = {
    'char-linear': True,
    'pos-linear': False,
    'char+pos-second': True,
    'char+pos-second_lowrank': False,
    'lowrank_size': 0,
    'char+pos-second_fm': False,
@@ -130,4 +141,4 @@ def load_config(config_path, use_default=False):
                for dict_k, dict_v in val.items():
                    if dict_k not in d:
                        d[dict_k] = dict_v
    return config
    return config
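
The merge loop above only fills keys the user config is missing, so user-provided values always win. A small sketch of the intended behavior (the config.py contents here are hypothetical):

    # Suppose G2PWModel/config.py defines only:
    #     model_source = 'bert-base-chinese'
    config = load_config('G2PWModel/config.py', use_default=True)
    # Missing attributes are copied from default_config_dict; for
    # dict-valued defaults only the absent inner keys are filled in.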

paddlespeech/t2s/frontend/zh_frontend.py
@@ -11,15 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import os
import yaml
import re
from typing import Dict
from typing import List
import jieba.posseg as psg
import numpy as np
import paddle
import yaml
from g2pM import G2pM
from pypinyin import lazy_pinyin
from pypinyin import load_phrases_dict
@@ -58,19 +58,24 @@ def insert_after_character(lst, item):
class Polyphonic():
    def __init__(self):
        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)),
                  'polyphonic.yaml'), 'r',encoding='utf-8') as polyphonic_file:
        with open(
                os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    'polyphonic.yaml'),
                'r',
                encoding='utf-8') as polyphonic_file:
            # parse the yaml file
            polyphonic_dict = yaml.load(polyphonic_file, Loader=yaml.FullLoader)
        self.polyphonic_words = polyphonic_dict["polyphonic"]

    def correct_pronunciation(self,word,pinyin):
    def correct_pronunciation(self, word, pinyin):
        # if the word is in the dictionary, return the corrected pronunciation
        if word in self.polyphonic_words.keys():
            pinyin = self.polyphonic_words[word]
        # otherwise return the original pronunciation
        return pinyin
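
A short sketch of this dictionary-based correction in use. The yaml snippet shows the expected shape of polyphonic.yaml (word mapped to its pinyin list); the example entry is illustrative, not copied from the shipped file:

    # polyphonic.yaml is expected to look roughly like:
    #     polyphonic:
    #         电子: ['dian4', 'zi3']
    corrector = Polyphonic()
    fixed = corrector.correct_pronunciation('电子', ['dian4', 'zi5'])
    # -> the dictionary reading if '电子' is listed; the input otherwise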
class Frontend():
    def __init__(self,
                 g2p_model="g2pW",
@@ -88,7 +93,8 @@ class Frontend():
        elif self.g2p_model == "g2pW":
            self.corrector = Polyphonic()
            self.g2pM_model = G2pM()
            self.g2pW_model = G2PWOnnxConverter(style='pinyin', enable_non_tradional_chinese=True)
            self.g2pW_model = G2PWOnnxConverter(
                style='pinyin', enable_non_tradional_chinese=True)
            self.pinyin2phone = generate_lexicon(
                with_tone=True, with_erhua=False)
@@ -187,7 +193,7 @@ class Frontend():
                    pinyins = self.g2pW_model(seg)[0]
                except Exception:
                    # g2pW expects Traditional Chinese input; fall back to g2pM for Simplified words it cannot cover
                    print("[%s] not in g2pW dict,use g2pM"%seg)
                    print("[%s] not in g2pW dict,use g2pM" % seg)
                    pinyins = self.g2pM_model(seg, tone=True, char_split=False)
                pre_word_length = 0
                for word, pos in seg_cut:
@@ -199,13 +205,15 @@ class Frontend():
                        continue
                    word_pinyins = pinyins[pre_word_length:now_word_length]
                    # correct the pronunciation
                    word_pinyins = self.corrector.correct_pronunciation(word,word_pinyins)
                    for pinyin,char in zip(word_pinyins,word):
                        if pinyin == None:
                    word_pinyins = self.corrector.correct_pronunciation(
                        word, word_pinyins)
                    for pinyin, char in zip(word_pinyins, word):
                        if pinyin is None:
                            pinyin = char
                        pinyin = pinyin.replace("u:", "v")
                        if pinyin in self.pinyin2phone:
                            initial_final_list = self.pinyin2phone[pinyin].split(" ")
                            initial_final_list = self.pinyin2phone[
                                pinyin].split(" ")
                            if len(initial_final_list) == 2:
                                sub_initials.append(initial_final_list[0])
                                sub_finals.append(initial_final_list[1])
@@ -218,7 +226,7 @@ class Frontend():
                                sub_finals.append(pinyin)
                    pre_word_length = now_word_length
                    sub_finals = self.tone_modifier.modified_tone(word, pos,
                                                                  sub_finals)
                                                                  sub_finals)
                    if with_erhua:
                        sub_initials, sub_finals = self._merge_erhua(
                            sub_initials, sub_finals, word, pos)
@@ -231,7 +239,7 @@ class Frontend():
                    continue
                sub_initials, sub_finals = self._get_initials_finals(word)
                sub_finals = self.tone_modifier.modified_tone(word, pos,
                                                              sub_finals)
                                                              sub_finals)
                if with_erhua:
                    sub_initials, sub_finals = self._merge_erhua(
                        sub_initials, sub_finals, word, pos)
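
End to end, the frontend path above can be exercised as follows; a hedged sketch, since get_phonemes itself is outside the hunks shown (it is part of this Frontend class upstream):

    # Assumes the g2pW model assets are available (downloaded on first run).
    frontend = Frontend(g2p_model="g2pW")
    phones = frontend.get_phonemes("先生长大后长高了")
    # g2pW resolves each polyphonic '长' in context; if a segment were
    # outside g2pW's coverage, the except branch above routes it to g2pM.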
