add typehint for g2pw (#2390)

pull/2406/head
TianYuan 2 years ago committed by GitHub
parent 68c2ec7563
commit eac362057c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1 +1 @@
from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter
from .onnx_api import G2PWOnnxConverter

@ -15,6 +15,10 @@
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
from typing import Dict
from typing import List
from typing import Tuple
import numpy as np
from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map
@ -23,22 +27,17 @@ ANCHOR_CHAR = '▁'
def prepare_onnx_input(tokenizer,
labels,
char2phonemes,
chars,
texts,
query_ids,
phonemes=None,
pos_tags=None,
use_mask=False,
use_char_phoneme=False,
use_pos=False,
window_size=None,
max_len=512):
labels: List[str],
char2phonemes: Dict[str, List[int]],
chars: List[str],
texts: List[str],
query_ids: List[int],
use_mask: bool=False,
window_size: int=None,
max_len: int=512) -> Dict[str, np.array]:
if window_size is not None:
truncated_texts, truncated_query_ids = _truncate_texts(window_size,
texts, query_ids)
truncated_texts, truncated_query_ids = _truncate_texts(
window_size=window_size, texts=texts, query_ids=query_ids)
input_ids = []
token_type_ids = []
attention_masks = []
@ -51,13 +50,19 @@ def prepare_onnx_input(tokenizer,
query_id = (truncated_query_ids if window_size else query_ids)[idx]
try:
tokens, text2token, token2text = tokenize_and_map(tokenizer, text)
tokens, text2token, token2text = tokenize_and_map(
tokenizer=tokenizer, text=text)
except Exception:
print(f'warning: text "{text}" is invalid')
return {}
text, query_id, tokens, text2token, token2text = _truncate(
max_len, text, query_id, tokens, text2token, token2text)
max_len=max_len,
text=text,
query_id=query_id,
tokens=tokens,
text2token=text2token,
token2text=token2text)
processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
@ -91,7 +96,8 @@ def prepare_onnx_input(tokenizer,
return outputs
def _truncate_texts(window_size, texts, query_ids):
def _truncate_texts(window_size: int, texts: List[str],
query_ids: List[int]) -> Tuple[List[str], List[int]]:
truncated_texts = []
truncated_query_ids = []
for text, query_id in zip(texts, query_ids):
@ -105,7 +111,12 @@ def _truncate_texts(window_size, texts, query_ids):
return truncated_texts, truncated_query_ids
def _truncate(max_len, text, query_id, tokens, text2token, token2text):
def _truncate(max_len: int,
text: str,
query_id: int,
tokens: List[str],
text2token: List[int],
token2text: List[Tuple[int]]):
truncate_len = max_len - 2
if len(tokens) <= truncate_len:
return (text, query_id, tokens, text2token, token2text)
@ -132,18 +143,8 @@ def _truncate(max_len, text, query_id, tokens, text2token, token2text):
], [(s - start, e - start) for s, e in token2text[token_start:token_end]])
def prepare_data(sent_path, lb_path=None):
raw_texts = open(sent_path).read().rstrip().split('\n')
query_ids = [raw.index(ANCHOR_CHAR) for raw in raw_texts]
texts = [raw.replace(ANCHOR_CHAR, '') for raw in raw_texts]
if lb_path is None:
return texts, query_ids
else:
phonemes = open(lb_path).read().rstrip().split('\n')
return texts, query_ids, phonemes
def get_phoneme_labels(polyphonic_chars):
def get_phoneme_labels(polyphonic_chars: List[List[str]]
) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
@ -153,7 +154,8 @@ def get_phoneme_labels(polyphonic_chars):
return labels, char2phonemes
def get_char_phoneme_labels(polyphonic_chars):
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(
list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
char2phonemes = {}

@ -17,6 +17,10 @@ Credits
"""
import json
import os
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
import numpy as np
import onnxruntime
@ -37,7 +41,8 @@ from paddlespeech.utils.env import MODEL_HOME
model_version = '1.1'
def predict(session, onnx_input, labels):
def predict(session, onnx_input: Dict[str, Any],
labels: List[str]) -> Tuple[List[str], List[float]]:
all_preds = []
all_confidences = []
probs = session.run([], {
@ -61,10 +66,10 @@ def predict(session, onnx_input, labels):
class G2PWOnnxConverter:
def __init__(self,
model_dir=MODEL_HOME,
style='bopomofo',
model_source=None,
enable_non_tradional_chinese=False):
model_dir: os.PathLike=MODEL_HOME,
style: str='bopomofo',
model_source: str=None,
enable_non_tradional_chinese: bool=False):
uncompress_path = download_and_decompress(
g2pw_onnx_models['G2PWModel'][model_version], model_dir)
@ -76,7 +81,8 @@ class G2PWOnnxConverter:
os.path.join(uncompress_path, 'g2pW.onnx'),
sess_options=sess_options)
self.config = load_config(
os.path.join(uncompress_path, 'config.py'), use_default=True)
config_path=os.path.join(uncompress_path, 'config.py'),
use_default=True)
self.model_source = model_source if model_source else self.config.model_source
self.enable_opencc = enable_non_tradional_chinese
@ -103,9 +109,9 @@ class G2PWOnnxConverter:
.strip().split('\n')
]
self.labels, self.char2phonemes = get_char_phoneme_labels(
self.polyphonic_chars
polyphonic_chars=self.polyphonic_chars
) if self.config.use_char_phoneme else get_phoneme_labels(
self.polyphonic_chars)
polyphonic_chars=self.polyphonic_chars)
self.chars = sorted(list(self.char2phonemes.keys()))
@ -146,7 +152,7 @@ class G2PWOnnxConverter:
if self.enable_opencc:
self.cc = OpenCC('s2tw')
def _convert_bopomofo_to_pinyin(self, bopomofo):
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
tone = bopomofo[-1]
assert tone in '12345'
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
@ -156,7 +162,7 @@ class G2PWOnnxConverter:
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
return None
def __call__(self, sentences):
def __call__(self, sentences: List[str]) -> List[List[str]]:
if isinstance(sentences, str):
sentences = [sentences]
@ -169,23 +175,25 @@ class G2PWOnnxConverter:
sentences = translated_sentences
texts, query_ids, sent_ids, partial_results = self._prepare_data(
sentences)
sentences=sentences)
if len(texts) == 0:
# sentences no polyphonic words
return partial_results
onnx_input = prepare_onnx_input(
self.tokenizer,
self.labels,
self.char2phonemes,
self.chars,
texts,
query_ids,
tokenizer=self.tokenizer,
labels=self.labels,
char2phonemes=self.char2phonemes,
chars=self.chars,
texts=texts,
query_ids=query_ids,
use_mask=self.config.use_mask,
use_char_phoneme=self.config.use_char_phoneme,
window_size=None)
preds, confidences = predict(self.session_g2pW, onnx_input, self.labels)
preds, confidences = predict(
session=self.session_g2pW,
onnx_input=onnx_input,
labels=self.labels)
if self.config.use_char_phoneme:
preds = [pred.split(' ')[1] for pred in preds]
@ -195,7 +203,9 @@ class G2PWOnnxConverter:
return results
def _prepare_data(self, sentences):
def _prepare_data(
self, sentences: List[str]
) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
texts, query_ids, sent_ids, partial_results = [], [], [], []
for sent_id, sent in enumerate(sentences):
# pypinyin works well for Simplified Chinese than Traditional Chinese

@ -15,10 +15,11 @@
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
import os
import re
def wordize_and_map(text):
def wordize_and_map(text: str):
words = []
index_map_from_text_to_word = []
index_map_from_word_to_text = []
@ -54,8 +55,8 @@ def wordize_and_map(text):
return words, index_map_from_text_to_word, index_map_from_word_to_text
def tokenize_and_map(tokenizer, text):
words, text2word, word2text = wordize_and_map(text)
def tokenize_and_map(tokenizer, text: str):
words, text2word, word2text = wordize_and_map(text=text)
tokens = []
index_map_from_token_to_text = []
@ -82,7 +83,7 @@ def tokenize_and_map(tokenizer, text):
return tokens, index_map_from_text_to_token, index_map_from_token_to_text
def _load_config(config_path):
def _load_config(config_path: os.PathLike):
import importlib.util
spec = importlib.util.spec_from_file_location('__init__', config_path)
config = importlib.util.module_from_spec(spec)
@ -130,7 +131,7 @@ default_config_dict = {
}
def load_config(config_path, use_default=False):
def load_config(config_path: os.PathLike, use_default: bool=False):
config = _load_config(config_path)
if use_default:
for attr, val in default_config_dict.items():

Loading…
Cancel
Save