Merge pull request #2250 from yt605155624/format_g2pw

[TTS]format g2pw
TianYuan 2 years ago committed by GitHub
commit 83e10fadd0

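All four diffs below are formatting-only passes over the g2pW ONNX frontend: reflowed line wrapping, one-import-per-line sorting, and Apache license headers, with no behavior change. For orientation, a minimal usage sketch of the class being reformatted; the module path and the `__call__` interface are inferred from the code in this diff, not stated by it:

    # Hedged sketch: the import path and call signature below are
    # assumptions based on the imports and method bodies in this diff.
    from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter

    converter = G2PWOnnxConverter(
        style='pinyin',  # default style is 'bopomofo'
        enable_non_tradional_chinese=True)  # OpenCC pass before prediction
    # Hypothetical call: one list of pinyin strings per input sentence.
    results = converter(['披着羊皮的狼'])
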
@@ -1,17 +1,43 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 Credits
     This code is modified from https://github.com/GitYCC/g2pW
 """
 import numpy as np

 from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map

 ANCHOR_CHAR = '▁'

+
-def prepare_onnx_input(tokenizer, labels, char2phonemes, chars, texts, query_ids, phonemes=None, pos_tags=None,
-                       use_mask=False, use_char_phoneme=False, use_pos=False, window_size=None, max_len=512):
+def prepare_onnx_input(tokenizer,
+                       labels,
+                       char2phonemes,
+                       chars,
+                       texts,
+                       query_ids,
+                       phonemes=None,
+                       pos_tags=None,
+                       use_mask=False,
+                       use_char_phoneme=False,
+                       use_pos=False,
+                       window_size=None,
+                       max_len=512):
     if window_size is not None:
-        truncated_texts, truncated_query_ids = _truncate_texts(window_size, texts, query_ids)
+        truncated_texts, truncated_query_ids = _truncate_texts(window_size,
+                                                               texts, query_ids)

     input_ids = []
     token_type_ids = []
@@ -30,11 +56,13 @@ def prepare_onnx_input(tokenizer, labels, char2phonemes, chars, texts, query_ids
            print(f'warning: text "{text}" is invalid')
            return {}

-        text, query_id, tokens, text2token, token2text = _truncate(max_len, text, query_id, tokens, text2token, token2text)
+        text, query_id, tokens, text2token, token2text = _truncate(
+            max_len, text, query_id, tokens, text2token, token2text)

         processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
-        input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
+        input_id = list(
+            np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
         token_type_id = list(np.zeros((len(processed_tokens), ), dtype=int))
         attention_mask = list(np.ones((len(processed_tokens), ), dtype=int))
@@ -42,7 +70,8 @@ def prepare_onnx_input(tokenizer, labels, char2phonemes, chars, texts, query_ids
         phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
             if use_mask else [1] * len(labels)
         char_id = chars.index(query_char)
-        position_id = text2token[query_id] + 1  # [CLS] token locate at first place
+        position_id = text2token[
+            query_id] + 1  # [CLS] token locate at first place

         input_ids.append(input_id)
         token_type_ids.append(token_type_id)
@@ -61,6 +90,7 @@ def prepare_onnx_input(tokenizer, labels, char2phonemes, chars, texts, query_ids
     }
     return outputs

+
 def _truncate_texts(window_size, texts, query_ids):
     truncated_texts = []
     truncated_query_ids = []
@@ -74,6 +104,7 @@ def _truncate_texts(window_size, texts, query_ids):
         truncated_query_ids.append(truncated_query_id)
     return truncated_texts, truncated_query_ids

+
 def _truncate(max_len, text, query_id, tokens, text2token, token2text):
     truncate_len = max_len - 2
     if len(tokens) <= truncate_len:
@@ -95,13 +126,11 @@ def _truncate(max_len, text, query_id, tokens, text2token, token2text):
     start = token2text[token_start][0]
     end = token2text[token_end - 1][1]

-    return (
-        text[start:end],
-        query_id - start,
-        tokens[token_start:token_end],
-        [i - token_start if i is not None else None for i in text2token[start:end]],
-        [(s - start, e - start) for s, e in token2text[token_start:token_end]]
-    )
+    return (text[start:end], query_id - start, tokens[token_start:token_end], [
+        i - token_start if i is not None else None
+        for i in text2token[start:end]
+    ], [(s - start, e - start) for s, e in token2text[token_start:token_end]])

+
 def prepare_data(sent_path, lb_path=None):
     raw_texts = open(sent_path).read().rstrip().split('\n')
@@ -125,7 +154,8 @@ def get_phoneme_labels(polyphonic_chars):

 def get_char_phoneme_labels(polyphonic_chars):
-    labels = sorted(list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
+    labels = sorted(
+        list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
     char2phonemes = {}
     for char, phoneme in polyphonic_chars:
         if char not in char2phonemes:

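In `prepare_onnx_input` above, `_truncate_texts` clips each text to a `window_size` window around the query character so long inputs still fit `max_len`. A standalone sketch of that windowing idea (my own illustration; the real helper also re-maps `query_ids` per text):

    # Illustration of the windowing idea behind _truncate_texts: keep a
    # fixed-size slice centered on the polyphonic character and shift the
    # query index into the slice. Not the library's exact algorithm.
    def truncate_around(text: str, query_id: int, window_size: int):
        half = window_size // 2
        start = max(0, query_id - half)
        end = min(len(text), start + window_size)
        start = max(0, end - window_size)  # re-anchor near the right edge
        return text[start:end], query_id - start

    assert truncate_around('abcdefghij', 8, 4) == ('ghij', 2)
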
@@ -1,34 +1,50 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 Credits
     This code is modified from https://github.com/GitYCC/g2pW
 """
-import os
 import json
-import onnxruntime
-import numpy as np
+import os
+
+import numpy as np
+import onnxruntime
 from opencc import OpenCC
-from pypinyin import pinyin, lazy_pinyin, Style
 from paddlenlp.transformers import BertTokenizer
-from paddlespeech.utils.env import MODEL_HOME
-from paddlespeech.t2s.frontend.g2pw.dataset import prepare_data,\
-                                                   prepare_onnx_input,\
-                                                   get_phoneme_labels,\
-                                                   get_char_phoneme_labels
-from paddlespeech.t2s.frontend.g2pw.utils import load_config
+from pypinyin import pinyin
+from pypinyin import Style
+
 from paddlespeech.cli.utils import download_and_decompress
 from paddlespeech.resource.pretrained_models import g2pw_onnx_models
+from paddlespeech.t2s.frontend.g2pw.dataset import get_char_phoneme_labels
+from paddlespeech.t2s.frontend.g2pw.dataset import get_phoneme_labels
+from paddlespeech.t2s.frontend.g2pw.dataset import prepare_onnx_input
+from paddlespeech.t2s.frontend.g2pw.utils import load_config
+from paddlespeech.utils.env import MODEL_HOME

+
 def predict(session, onnx_input, labels):
     all_preds = []
     all_confidences = []
-    probs = session.run([],{"input_ids": onnx_input['input_ids'],
-                            "token_type_ids": onnx_input['token_type_ids'],
-                            "attention_mask": onnx_input['attention_masks'],
-                            "phoneme_mask": onnx_input['phoneme_masks'],
-                            "char_ids": onnx_input['char_ids'],
-                            "position_ids":onnx_input['position_ids']})[0]
+    probs = session.run([], {
+        "input_ids": onnx_input['input_ids'],
+        "token_type_ids": onnx_input['token_type_ids'],
+        "attention_mask": onnx_input['attention_masks'],
+        "phoneme_mask": onnx_input['phoneme_masks'],
+        "char_ids": onnx_input['char_ids'],
+        "position_ids": onnx_input['position_ids']
+    })[0]

     preds = np.argmax(probs, axis=1).tolist()
     max_probs = []
@@ -41,39 +57,69 @@ def predict(session, onnx_input, labels):

 class G2PWOnnxConverter:
-    def __init__(self, model_dir = MODEL_HOME, style='bopomofo', model_source=None, enable_non_tradional_chinese=False):
+    def __init__(self,
+                 model_dir=MODEL_HOME,
+                 style='bopomofo',
+                 model_source=None,
+                 enable_non_tradional_chinese=False):
         if not os.path.exists(os.path.join(model_dir, 'G2PWModel/g2pW.onnx')):
-            uncompress_path = download_and_decompress(g2pw_onnx_models['G2PWModel']['1.0'],model_dir)
+            uncompress_path = download_and_decompress(
+                g2pw_onnx_models['G2PWModel']['1.0'], model_dir)

         sess_options = onnxruntime.SessionOptions()
         sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
         sess_options.intra_op_num_threads = 2
-        self.session_g2pW = onnxruntime.InferenceSession(os.path.join(model_dir, 'G2PWModel/g2pW.onnx'),sess_options=sess_options)
-        self.config = load_config(os.path.join(model_dir, 'G2PWModel/config.py'), use_default=True)
+        self.session_g2pW = onnxruntime.InferenceSession(
+            os.path.join(model_dir, 'G2PWModel/g2pW.onnx'),
+            sess_options=sess_options)
+        self.config = load_config(
+            os.path.join(model_dir, 'G2PWModel/config.py'), use_default=True)

         self.model_source = model_source if model_source else self.config.model_source
         self.enable_opencc = enable_non_tradional_chinese

         self.tokenizer = BertTokenizer.from_pretrained(self.config.model_source)

-        polyphonic_chars_path = os.path.join(model_dir, 'G2PWModel/POLYPHONIC_CHARS.txt')
-        monophonic_chars_path = os.path.join(model_dir, 'G2PWModel/MONOPHONIC_CHARS.txt')
-        self.polyphonic_chars = [line.split('\t') for line in open(polyphonic_chars_path,encoding='utf-8').read().strip().split('\n')]
-        self.monophonic_chars = [line.split('\t') for line in open(monophonic_chars_path,encoding='utf-8').read().strip().split('\n')]
-        self.labels, self.char2phonemes = get_char_phoneme_labels(self.polyphonic_chars) if self.config.use_char_phoneme else get_phoneme_labels(self.polyphonic_chars)
+        polyphonic_chars_path = os.path.join(model_dir,
+                                             'G2PWModel/POLYPHONIC_CHARS.txt')
+        monophonic_chars_path = os.path.join(model_dir,
+                                             'G2PWModel/MONOPHONIC_CHARS.txt')
+        self.polyphonic_chars = [
+            line.split('\t')
+            for line in open(polyphonic_chars_path, encoding='utf-8').read()
+            .strip().split('\n')
+        ]
+        self.monophonic_chars = [
+            line.split('\t')
+            for line in open(monophonic_chars_path, encoding='utf-8').read()
+            .strip().split('\n')
+        ]
+        self.labels, self.char2phonemes = get_char_phoneme_labels(
+            self.polyphonic_chars
+        ) if self.config.use_char_phoneme else get_phoneme_labels(
+            self.polyphonic_chars)

         self.chars = sorted(list(self.char2phonemes.keys()))
-        self.pos_tags = ['UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI']
+        self.pos_tags = [
+            'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
+        ]

-        with open(os.path.join(model_dir,'G2PWModel/bopomofo_to_pinyin_wo_tune_dict.json'), 'r',encoding='utf-8') as fr:
+        with open(
+                os.path.join(model_dir,
+                             'G2PWModel/bopomofo_to_pinyin_wo_tune_dict.json'),
+                'r',
+                encoding='utf-8') as fr:
             self.bopomofo_convert_dict = json.load(fr)
         self.style_convert_func = {
             'bopomofo': lambda x: x,
             'pinyin': self._convert_bopomofo_to_pinyin,
         }[style]

-        with open(os.path.join(model_dir,'G2PWModel/char_bopomofo_dict.json'), 'r',encoding='utf-8') as fr:
+        with open(
+                os.path.join(model_dir, 'G2PWModel/char_bopomofo_dict.json'),
+                'r',
+                encoding='utf-8') as fr:
             self.char_bopomofo_dict = json.load(fr)

         if self.enable_opencc:
@@ -101,13 +147,21 @@ class G2PWOnnxConverter:
             translated_sentences.append(translated_sent)
         sentences = translated_sentences

-        texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences)
+        texts, query_ids, sent_ids, partial_results = self._prepare_data(
+            sentences)
         if len(texts) == 0:
             # sentences no polyphonic words
             return partial_results

-        onnx_input = prepare_onnx_input(self.tokenizer, self.labels, self.char2phonemes, self.chars, texts, query_ids,
-                                        use_mask=self.config.use_mask, use_char_phoneme=self.config.use_char_phoneme,
-                                        window_size=None)
+        onnx_input = prepare_onnx_input(
+            self.tokenizer,
+            self.labels,
+            self.char2phonemes,
+            self.chars,
+            texts,
+            query_ids,
+            use_mask=self.config.use_mask,
+            use_char_phoneme=self.config.use_char_phoneme,
+            window_size=None)

         preds, confidences = predict(self.session_g2pW, onnx_input, self.labels)
@@ -123,7 +177,8 @@ class G2PWOnnxConverter:
     def _prepare_data(self, sentences):
         polyphonic_chars = set(self.chars)
         monophonic_chars_dict = {
-            char: phoneme for char, phoneme in self.monophonic_chars
+            char: phoneme
+            for char, phoneme in self.monophonic_chars
         }
         texts, query_ids, sent_ids, partial_results = [], [], [], []
         for sent_id, sent in enumerate(sentences):
@@ -135,7 +190,8 @@ class G2PWOnnxConverter:
                     query_ids.append(i)
                     sent_ids.append(sent_id)
                 elif char in monophonic_chars_dict:
-                    partial_result[i] = self.style_convert_func(monophonic_chars_dict[char])
+                    partial_result[i] = self.style_convert_func(
+                        monophonic_chars_dict[char])
                 elif char in self.char_bopomofo_dict:
                     partial_result[i] = pypinyin_result[i][0]
                     # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])

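`predict` above feeds the prepared batch to ONNX Runtime and decodes the first output with `np.argmax(probs, axis=1)`. A sketch of that decoding step; the `[num_queries, num_labels]` shape of `probs` and the mapping through `labels` are assumptions consistent with the code, not spelled out in the diff:

    import numpy as np

    # Sketch of the decoding in predict(): arg-max over candidate phoneme
    # labels, with the winning probability kept as a confidence score.
    # The phoneme_mask is an input to the model itself, not applied here.
    def decode(probs, labels):
        preds = np.argmax(probs, axis=1)
        confidences = probs.max(axis=1)
        return [labels[i] for i in preds], confidences.tolist()
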
@@ -1,10 +1,22 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 Credits
     This code is modified from https://github.com/GitYCC/g2pW
 """
 import re
-import sys


 def wordize_and_map(text):
     words = []
@@ -92,7 +104,6 @@ default_config_dict = {
     'char-linear': True,
     'pos-linear': False,
     'char+pos-second': True,
-    'char+pos-second_lowrank': False,
     'lowrank_size': 0,
     'char+pos-second_fm': False,

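The second hunk above drops one `'char+pos-second_lowrank': False,` entry from `default_config_dict`. If the old dict spelled that key twice (which the removal suggests, though this diff alone cannot confirm it), the extra entry was dead code, since Python silently keeps only the last occurrence of a repeated literal dict key:

    # A repeated key in a dict literal is collapsed to a single entry.
    d = {'char+pos-second_lowrank': False, 'char+pos-second_lowrank': False}
    assert len(d) == 1
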
@@ -11,15 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import re
 import os
-import yaml
+import re
 from typing import Dict
 from typing import List

 import jieba.posseg as psg
 import numpy as np
 import paddle
+import yaml
 from g2pM import G2pM
 from pypinyin import lazy_pinyin
 from pypinyin import load_phrases_dict
@@ -58,8 +58,12 @@ def insert_after_character(lst, item):

 class Polyphonic():
     def __init__(self):
-        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)),
-                  'polyphonic.yaml'), 'r',encoding='utf-8') as polyphonic_file:
+        with open(
+                os.path.join(
+                    os.path.dirname(os.path.abspath(__file__)),
+                    'polyphonic.yaml'),
+                'r',
+                encoding='utf-8') as polyphonic_file:
             # parse the yaml file
             polyphonic_dict = yaml.load(polyphonic_file, Loader=yaml.FullLoader)
         self.polyphonic_words = polyphonic_dict["polyphonic"]
@@ -71,6 +75,7 @@ class Polyphonic():
             # otherwise return the original pronunciation
             return pinyin

+
 class Frontend():
     def __init__(self,
                  g2p_model="g2pW",
@@ -88,7 +93,8 @@ class Frontend():
         elif self.g2p_model == "g2pW":
             self.corrector = Polyphonic()
             self.g2pM_model = G2pM()
-            self.g2pW_model = G2PWOnnxConverter(style='pinyin', enable_non_tradional_chinese=True)
+            self.g2pW_model = G2PWOnnxConverter(
+                style='pinyin', enable_non_tradional_chinese=True)
             self.pinyin2phone = generate_lexicon(
                 with_tone=True, with_erhua=False)
@@ -199,13 +205,15 @@ class Frontend():
                         continue
                     word_pinyins = pinyins[pre_word_length:now_word_length]
                     # correct the pronunciation
-                    word_pinyins = self.corrector.correct_pronunciation(word,word_pinyins)
+                    word_pinyins = self.corrector.correct_pronunciation(
+                        word, word_pinyins)
                     for pinyin, char in zip(word_pinyins, word):
-                        if pinyin == None:
+                        if pinyin is None:
                             pinyin = char
                         pinyin = pinyin.replace("u:", "v")
                         if pinyin in self.pinyin2phone:
-                            initial_final_list = self.pinyin2phone[pinyin].split(" ")
+                            initial_final_list = self.pinyin2phone[
+                                pinyin].split(" ")
                             if len(initial_final_list) == 2:
                                 sub_initials.append(initial_final_list[0])
                                 sub_finals.append(initial_final_list[1])

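In the last hunk, each corrected pinyin is normalized (`u:` becomes `v`, the pypinyin spelling of ü) and split into an initial and a final through the `pinyin2phone` lexicon. A compact sketch with hypothetical lexicon entries; the real table comes from `generate_lexicon(with_tone=True, with_erhua=False)`:

    # The two lexicon entries are hypothetical stand-ins; the real
    # pinyin2phone maps tone-marked pinyin to "initial final" strings.
    pinyin2phone = {'zhong1': 'zh ong1', 'nv3': 'n v3'}

    def split_pinyin(pinyin: str):
        pinyin = pinyin.replace('u:', 'v')  # pypinyin writes ü as "u:"
        parts = pinyin2phone.get(pinyin, pinyin).split(' ')
        if len(parts) == 2:  # initial + final, the case the hunk handles
            return parts[0], parts[1]
        return None, parts[0]

    assert split_pinyin('nu:3') == ('n', 'v3')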