|
|
|
@ -1,38 +1,54 @@
|
|
|
|
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
"""
|
|
|
|
|
Credits
|
|
|
|
|
This code is modified from https://github.com/GitYCC/g2pW
|
|
|
|
|
This code is modified from https://github.com/GitYCC/g2pW
|
|
|
|
|
"""
|
|
|
|
|
import os
|
|
|
|
|
import json
|
|
|
|
|
import onnxruntime
|
|
|
|
|
import numpy as np
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import onnxruntime
|
|
|
|
|
from opencc import OpenCC
|
|
|
|
|
from pypinyin import pinyin, lazy_pinyin, Style
|
|
|
|
|
from paddlenlp.transformers import BertTokenizer
|
|
|
|
|
from paddlespeech.utils.env import MODEL_HOME
|
|
|
|
|
from paddlespeech.t2s.frontend.g2pw.dataset import prepare_data,\
|
|
|
|
|
prepare_onnx_input,\
|
|
|
|
|
get_phoneme_labels,\
|
|
|
|
|
get_char_phoneme_labels
|
|
|
|
|
from paddlespeech.t2s.frontend.g2pw.utils import load_config
|
|
|
|
|
from pypinyin import pinyin
|
|
|
|
|
from pypinyin import Style
|
|
|
|
|
|
|
|
|
|
from paddlespeech.cli.utils import download_and_decompress
|
|
|
|
|
from paddlespeech.resource.pretrained_models import g2pw_onnx_models
|
|
|
|
|
from paddlespeech.t2s.frontend.g2pw.dataset import get_char_phoneme_labels
|
|
|
|
|
from paddlespeech.t2s.frontend.g2pw.dataset import get_phoneme_labels
|
|
|
|
|
from paddlespeech.t2s.frontend.g2pw.dataset import prepare_onnx_input
|
|
|
|
|
from paddlespeech.t2s.frontend.g2pw.utils import load_config
|
|
|
|
|
from paddlespeech.utils.env import MODEL_HOME
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(session, onnx_input, labels):
|
|
|
|
|
all_preds = []
|
|
|
|
|
all_confidences = []
|
|
|
|
|
probs = session.run([],{"input_ids": onnx_input['input_ids'],
|
|
|
|
|
"token_type_ids":onnx_input['token_type_ids'],
|
|
|
|
|
"attention_mask":onnx_input['attention_masks'],
|
|
|
|
|
"phoneme_mask":onnx_input['phoneme_masks'],
|
|
|
|
|
"char_ids":onnx_input['char_ids'],
|
|
|
|
|
"position_ids":onnx_input['position_ids']})[0]
|
|
|
|
|
|
|
|
|
|
preds = np.argmax(probs,axis=1).tolist()
|
|
|
|
|
probs = session.run([], {
|
|
|
|
|
"input_ids": onnx_input['input_ids'],
|
|
|
|
|
"token_type_ids": onnx_input['token_type_ids'],
|
|
|
|
|
"attention_mask": onnx_input['attention_masks'],
|
|
|
|
|
"phoneme_mask": onnx_input['phoneme_masks'],
|
|
|
|
|
"char_ids": onnx_input['char_ids'],
|
|
|
|
|
"position_ids": onnx_input['position_ids']
|
|
|
|
|
})[0]
|
|
|
|
|
|
|
|
|
|
preds = np.argmax(probs, axis=1).tolist()
|
|
|
|
|
max_probs = []
|
|
|
|
|
for index,arr in zip(preds,probs.tolist()):
|
|
|
|
|
for index, arr in zip(preds, probs.tolist()):
|
|
|
|
|
max_probs.append(arr[index])
|
|
|
|
|
all_preds += [labels[pred] for pred in preds]
|
|
|
|
|
all_confidences += max_probs
|
|
|
|
@ -41,39 +57,69 @@ def predict(session, onnx_input, labels):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class G2PWOnnxConverter:
|
|
|
|
|
def __init__(self, model_dir = MODEL_HOME, style='bopomofo', model_source=None, enable_non_tradional_chinese=False):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
model_dir=MODEL_HOME,
|
|
|
|
|
style='bopomofo',
|
|
|
|
|
model_source=None,
|
|
|
|
|
enable_non_tradional_chinese=False):
|
|
|
|
|
if not os.path.exists(os.path.join(model_dir, 'G2PWModel/g2pW.onnx')):
|
|
|
|
|
uncompress_path = download_and_decompress(g2pw_onnx_models['G2PWModel']['1.0'],model_dir)
|
|
|
|
|
uncompress_path = download_and_decompress(
|
|
|
|
|
g2pw_onnx_models['G2PWModel']['1.0'], model_dir)
|
|
|
|
|
|
|
|
|
|
sess_options = onnxruntime.SessionOptions()
|
|
|
|
|
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
|
|
|
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
|
|
|
|
sess_options.intra_op_num_threads = 2
|
|
|
|
|
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(model_dir, 'G2PWModel/g2pW.onnx'),sess_options=sess_options)
|
|
|
|
|
self.config = load_config(os.path.join(model_dir, 'G2PWModel/config.py'), use_default=True)
|
|
|
|
|
self.session_g2pW = onnxruntime.InferenceSession(
|
|
|
|
|
os.path.join(model_dir, 'G2PWModel/g2pW.onnx'),
|
|
|
|
|
sess_options=sess_options)
|
|
|
|
|
self.config = load_config(
|
|
|
|
|
os.path.join(model_dir, 'G2PWModel/config.py'), use_default=True)
|
|
|
|
|
|
|
|
|
|
self.model_source = model_source if model_source else self.config.model_source
|
|
|
|
|
self.enable_opencc = enable_non_tradional_chinese
|
|
|
|
|
|
|
|
|
|
self.tokenizer = BertTokenizer.from_pretrained(self.config.model_source)
|
|
|
|
|
|
|
|
|
|
polyphonic_chars_path = os.path.join(model_dir, 'G2PWModel/POLYPHONIC_CHARS.txt')
|
|
|
|
|
monophonic_chars_path = os.path.join(model_dir, 'G2PWModel/MONOPHONIC_CHARS.txt')
|
|
|
|
|
self.polyphonic_chars = [line.split('\t') for line in open(polyphonic_chars_path,encoding='utf-8').read().strip().split('\n')]
|
|
|
|
|
self.monophonic_chars = [line.split('\t') for line in open(monophonic_chars_path,encoding='utf-8').read().strip().split('\n')]
|
|
|
|
|
self.labels, self.char2phonemes = get_char_phoneme_labels(self.polyphonic_chars) if self.config.use_char_phoneme else get_phoneme_labels(self.polyphonic_chars)
|
|
|
|
|
polyphonic_chars_path = os.path.join(model_dir,
|
|
|
|
|
'G2PWModel/POLYPHONIC_CHARS.txt')
|
|
|
|
|
monophonic_chars_path = os.path.join(model_dir,
|
|
|
|
|
'G2PWModel/MONOPHONIC_CHARS.txt')
|
|
|
|
|
self.polyphonic_chars = [
|
|
|
|
|
line.split('\t')
|
|
|
|
|
for line in open(polyphonic_chars_path, encoding='utf-8').read()
|
|
|
|
|
.strip().split('\n')
|
|
|
|
|
]
|
|
|
|
|
self.monophonic_chars = [
|
|
|
|
|
line.split('\t')
|
|
|
|
|
for line in open(monophonic_chars_path, encoding='utf-8').read()
|
|
|
|
|
.strip().split('\n')
|
|
|
|
|
]
|
|
|
|
|
self.labels, self.char2phonemes = get_char_phoneme_labels(
|
|
|
|
|
self.polyphonic_chars
|
|
|
|
|
) if self.config.use_char_phoneme else get_phoneme_labels(
|
|
|
|
|
self.polyphonic_chars)
|
|
|
|
|
|
|
|
|
|
self.chars = sorted(list(self.char2phonemes.keys()))
|
|
|
|
|
self.pos_tags = ['UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI']
|
|
|
|
|
|
|
|
|
|
with open(os.path.join(model_dir,'G2PWModel/bopomofo_to_pinyin_wo_tune_dict.json'), 'r',encoding='utf-8') as fr:
|
|
|
|
|
self.pos_tags = [
|
|
|
|
|
'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
with open(
|
|
|
|
|
os.path.join(model_dir,
|
|
|
|
|
'G2PWModel/bopomofo_to_pinyin_wo_tune_dict.json'),
|
|
|
|
|
'r',
|
|
|
|
|
encoding='utf-8') as fr:
|
|
|
|
|
self.bopomofo_convert_dict = json.load(fr)
|
|
|
|
|
self.style_convert_func = {
|
|
|
|
|
'bopomofo': lambda x: x,
|
|
|
|
|
'pinyin': self._convert_bopomofo_to_pinyin,
|
|
|
|
|
}[style]
|
|
|
|
|
|
|
|
|
|
with open(os.path.join(model_dir,'G2PWModel/char_bopomofo_dict.json'), 'r',encoding='utf-8') as fr:
|
|
|
|
|
with open(
|
|
|
|
|
os.path.join(model_dir, 'G2PWModel/char_bopomofo_dict.json'),
|
|
|
|
|
'r',
|
|
|
|
|
encoding='utf-8') as fr:
|
|
|
|
|
self.char_bopomofo_dict = json.load(fr)
|
|
|
|
|
|
|
|
|
|
if self.enable_opencc:
|
|
|
|
@ -100,15 +146,23 @@ class G2PWOnnxConverter:
|
|
|
|
|
assert len(translated_sent) == len(sent)
|
|
|
|
|
translated_sentences.append(translated_sent)
|
|
|
|
|
sentences = translated_sentences
|
|
|
|
|
|
|
|
|
|
texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences)
|
|
|
|
|
|
|
|
|
|
texts, query_ids, sent_ids, partial_results = self._prepare_data(
|
|
|
|
|
sentences)
|
|
|
|
|
if len(texts) == 0:
|
|
|
|
|
# sentences no polyphonic words
|
|
|
|
|
return partial_results
|
|
|
|
|
|
|
|
|
|
onnx_input = prepare_onnx_input(self.tokenizer, self.labels, self.char2phonemes, self.chars, texts, query_ids,
|
|
|
|
|
use_mask=self.config.use_mask, use_char_phoneme=self.config.use_char_phoneme,
|
|
|
|
|
window_size=None)
|
|
|
|
|
onnx_input = prepare_onnx_input(
|
|
|
|
|
self.tokenizer,
|
|
|
|
|
self.labels,
|
|
|
|
|
self.char2phonemes,
|
|
|
|
|
self.chars,
|
|
|
|
|
texts,
|
|
|
|
|
query_ids,
|
|
|
|
|
use_mask=self.config.use_mask,
|
|
|
|
|
use_char_phoneme=self.config.use_char_phoneme,
|
|
|
|
|
window_size=None)
|
|
|
|
|
|
|
|
|
|
preds, confidences = predict(self.session_g2pW, onnx_input, self.labels)
|
|
|
|
|
if self.config.use_char_phoneme:
|
|
|
|
@ -123,11 +177,12 @@ class G2PWOnnxConverter:
|
|
|
|
|
def _prepare_data(self, sentences):
|
|
|
|
|
polyphonic_chars = set(self.chars)
|
|
|
|
|
monophonic_chars_dict = {
|
|
|
|
|
char: phoneme for char, phoneme in self.monophonic_chars
|
|
|
|
|
char: phoneme
|
|
|
|
|
for char, phoneme in self.monophonic_chars
|
|
|
|
|
}
|
|
|
|
|
texts, query_ids, sent_ids, partial_results = [], [], [], []
|
|
|
|
|
for sent_id, sent in enumerate(sentences):
|
|
|
|
|
pypinyin_result = pinyin(sent,style=Style.TONE3)
|
|
|
|
|
pypinyin_result = pinyin(sent, style=Style.TONE3)
|
|
|
|
|
partial_result = [None] * len(sent)
|
|
|
|
|
for i, char in enumerate(sent):
|
|
|
|
|
if char in polyphonic_chars:
|
|
|
|
@ -135,9 +190,10 @@ class G2PWOnnxConverter:
|
|
|
|
|
query_ids.append(i)
|
|
|
|
|
sent_ids.append(sent_id)
|
|
|
|
|
elif char in monophonic_chars_dict:
|
|
|
|
|
partial_result[i] = self.style_convert_func(monophonic_chars_dict[char])
|
|
|
|
|
partial_result[i] = self.style_convert_func(
|
|
|
|
|
monophonic_chars_dict[char])
|
|
|
|
|
elif char in self.char_bopomofo_dict:
|
|
|
|
|
partial_result[i] = pypinyin_result[i][0]
|
|
|
|
|
partial_result[i] = pypinyin_result[i][0]
|
|
|
|
|
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
|
|
|
|
|
partial_results.append(partial_result)
|
|
|
|
|
return texts, query_ids, sent_ids, partial_results
|
|
|
|
|