|
|
@ -17,6 +17,10 @@ Credits
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
import json
|
|
|
|
import json
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
from typing import Dict
|
|
|
|
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
from typing import Tuple
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
import onnxruntime
|
|
|
|
import onnxruntime
|
|
|
@ -37,7 +41,8 @@ from paddlespeech.utils.env import MODEL_HOME
|
|
|
|
model_version = '1.1'
|
|
|
|
model_version = '1.1'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(session, onnx_input, labels):
|
|
|
|
def predict(session, onnx_input: Dict[str, Any],
|
|
|
|
|
|
|
|
labels: List[str]) -> Tuple[List[str], List[float]]:
|
|
|
|
all_preds = []
|
|
|
|
all_preds = []
|
|
|
|
all_confidences = []
|
|
|
|
all_confidences = []
|
|
|
|
probs = session.run([], {
|
|
|
|
probs = session.run([], {
|
|
|
@ -61,10 +66,10 @@ def predict(session, onnx_input, labels):
|
|
|
|
|
|
|
|
|
|
|
|
class G2PWOnnxConverter:
|
|
|
|
class G2PWOnnxConverter:
|
|
|
|
def __init__(self,
|
|
|
|
def __init__(self,
|
|
|
|
model_dir=MODEL_HOME,
|
|
|
|
model_dir: os.PathLike=MODEL_HOME,
|
|
|
|
style='bopomofo',
|
|
|
|
style: str='bopomofo',
|
|
|
|
model_source=None,
|
|
|
|
model_source: str=None,
|
|
|
|
enable_non_tradional_chinese=False):
|
|
|
|
enable_non_tradional_chinese: bool=False):
|
|
|
|
uncompress_path = download_and_decompress(
|
|
|
|
uncompress_path = download_and_decompress(
|
|
|
|
g2pw_onnx_models['G2PWModel'][model_version], model_dir)
|
|
|
|
g2pw_onnx_models['G2PWModel'][model_version], model_dir)
|
|
|
|
|
|
|
|
|
|
|
@ -76,7 +81,8 @@ class G2PWOnnxConverter:
|
|
|
|
os.path.join(uncompress_path, 'g2pW.onnx'),
|
|
|
|
os.path.join(uncompress_path, 'g2pW.onnx'),
|
|
|
|
sess_options=sess_options)
|
|
|
|
sess_options=sess_options)
|
|
|
|
self.config = load_config(
|
|
|
|
self.config = load_config(
|
|
|
|
os.path.join(uncompress_path, 'config.py'), use_default=True)
|
|
|
|
config_path=os.path.join(uncompress_path, 'config.py'),
|
|
|
|
|
|
|
|
use_default=True)
|
|
|
|
|
|
|
|
|
|
|
|
self.model_source = model_source if model_source else self.config.model_source
|
|
|
|
self.model_source = model_source if model_source else self.config.model_source
|
|
|
|
self.enable_opencc = enable_non_tradional_chinese
|
|
|
|
self.enable_opencc = enable_non_tradional_chinese
|
|
|
@ -103,9 +109,9 @@ class G2PWOnnxConverter:
|
|
|
|
.strip().split('\n')
|
|
|
|
.strip().split('\n')
|
|
|
|
]
|
|
|
|
]
|
|
|
|
self.labels, self.char2phonemes = get_char_phoneme_labels(
|
|
|
|
self.labels, self.char2phonemes = get_char_phoneme_labels(
|
|
|
|
self.polyphonic_chars
|
|
|
|
polyphonic_chars=self.polyphonic_chars
|
|
|
|
) if self.config.use_char_phoneme else get_phoneme_labels(
|
|
|
|
) if self.config.use_char_phoneme else get_phoneme_labels(
|
|
|
|
self.polyphonic_chars)
|
|
|
|
polyphonic_chars=self.polyphonic_chars)
|
|
|
|
|
|
|
|
|
|
|
|
self.chars = sorted(list(self.char2phonemes.keys()))
|
|
|
|
self.chars = sorted(list(self.char2phonemes.keys()))
|
|
|
|
|
|
|
|
|
|
|
@ -146,7 +152,7 @@ class G2PWOnnxConverter:
|
|
|
|
if self.enable_opencc:
|
|
|
|
if self.enable_opencc:
|
|
|
|
self.cc = OpenCC('s2tw')
|
|
|
|
self.cc = OpenCC('s2tw')
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_bopomofo_to_pinyin(self, bopomofo):
|
|
|
|
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
|
|
|
|
tone = bopomofo[-1]
|
|
|
|
tone = bopomofo[-1]
|
|
|
|
assert tone in '12345'
|
|
|
|
assert tone in '12345'
|
|
|
|
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
|
|
|
|
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
|
|
|
@ -156,7 +162,7 @@ class G2PWOnnxConverter:
|
|
|
|
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
|
|
|
|
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, sentences):
|
|
|
|
def __call__(self, sentences: List[str]) -> List[List[str]]:
|
|
|
|
if isinstance(sentences, str):
|
|
|
|
if isinstance(sentences, str):
|
|
|
|
sentences = [sentences]
|
|
|
|
sentences = [sentences]
|
|
|
|
|
|
|
|
|
|
|
@ -169,23 +175,25 @@ class G2PWOnnxConverter:
|
|
|
|
sentences = translated_sentences
|
|
|
|
sentences = translated_sentences
|
|
|
|
|
|
|
|
|
|
|
|
texts, query_ids, sent_ids, partial_results = self._prepare_data(
|
|
|
|
texts, query_ids, sent_ids, partial_results = self._prepare_data(
|
|
|
|
sentences)
|
|
|
|
sentences=sentences)
|
|
|
|
if len(texts) == 0:
|
|
|
|
if len(texts) == 0:
|
|
|
|
# sentences no polyphonic words
|
|
|
|
# sentences no polyphonic words
|
|
|
|
return partial_results
|
|
|
|
return partial_results
|
|
|
|
|
|
|
|
|
|
|
|
onnx_input = prepare_onnx_input(
|
|
|
|
onnx_input = prepare_onnx_input(
|
|
|
|
self.tokenizer,
|
|
|
|
tokenizer=self.tokenizer,
|
|
|
|
self.labels,
|
|
|
|
labels=self.labels,
|
|
|
|
self.char2phonemes,
|
|
|
|
char2phonemes=self.char2phonemes,
|
|
|
|
self.chars,
|
|
|
|
chars=self.chars,
|
|
|
|
texts,
|
|
|
|
texts=texts,
|
|
|
|
query_ids,
|
|
|
|
query_ids=query_ids,
|
|
|
|
use_mask=self.config.use_mask,
|
|
|
|
use_mask=self.config.use_mask,
|
|
|
|
use_char_phoneme=self.config.use_char_phoneme,
|
|
|
|
|
|
|
|
window_size=None)
|
|
|
|
window_size=None)
|
|
|
|
|
|
|
|
|
|
|
|
preds, confidences = predict(self.session_g2pW, onnx_input, self.labels)
|
|
|
|
preds, confidences = predict(
|
|
|
|
|
|
|
|
session=self.session_g2pW,
|
|
|
|
|
|
|
|
onnx_input=onnx_input,
|
|
|
|
|
|
|
|
labels=self.labels)
|
|
|
|
if self.config.use_char_phoneme:
|
|
|
|
if self.config.use_char_phoneme:
|
|
|
|
preds = [pred.split(' ')[1] for pred in preds]
|
|
|
|
preds = [pred.split(' ')[1] for pred in preds]
|
|
|
|
|
|
|
|
|
|
|
@ -195,7 +203,9 @@ class G2PWOnnxConverter:
|
|
|
|
|
|
|
|
|
|
|
|
return results
|
|
|
|
return results
|
|
|
|
|
|
|
|
|
|
|
|
def _prepare_data(self, sentences):
|
|
|
|
def _prepare_data(
|
|
|
|
|
|
|
|
self, sentences: List[str]
|
|
|
|
|
|
|
|
) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
|
|
|
|
texts, query_ids, sent_ids, partial_results = [], [], [], []
|
|
|
|
texts, query_ids, sent_ids, partial_results = [], [], [], []
|
|
|
|
for sent_id, sent in enumerate(sentences):
|
|
|
|
for sent_id, sent in enumerate(sentences):
|
|
|
|
# pypinyin works well for Simplified Chinese than Traditional Chinese
|
|
|
|
# pypinyin works well for Simplified Chinese than Traditional Chinese
|
|
|
|