diff --git a/paddlespeech/t2s/frontend/g2pw/__init__.py b/paddlespeech/t2s/frontend/g2pw/__init__.py index 0eaeee5d..89b3af3c 100644 --- a/paddlespeech/t2s/frontend/g2pw/__init__.py +++ b/paddlespeech/t2s/frontend/g2pw/__init__.py @@ -1 +1 @@ -from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter +from .onnx_api import G2PWOnnxConverter diff --git a/paddlespeech/t2s/frontend/g2pw/dataset.py b/paddlespeech/t2s/frontend/g2pw/dataset.py index 98af5f46..8a1c2e0b 100644 --- a/paddlespeech/t2s/frontend/g2pw/dataset.py +++ b/paddlespeech/t2s/frontend/g2pw/dataset.py @@ -15,6 +15,10 @@ Credits This code is modified from https://github.com/GitYCC/g2pW """ +from typing import Dict +from typing import List +from typing import Tuple + import numpy as np from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map @@ -23,22 +27,17 @@ ANCHOR_CHAR = '▁' def prepare_onnx_input(tokenizer, - labels, - char2phonemes, - chars, - texts, - query_ids, - phonemes=None, - pos_tags=None, - use_mask=False, - use_char_phoneme=False, - use_pos=False, - window_size=None, - max_len=512): + labels: List[str], + char2phonemes: Dict[str, List[int]], + chars: List[str], + texts: List[str], + query_ids: List[int], + use_mask: bool=False, + window_size: int=None, + max_len: int=512) -> Dict[str, np.array]: if window_size is not None: - truncated_texts, truncated_query_ids = _truncate_texts(window_size, - texts, query_ids) - + truncated_texts, truncated_query_ids = _truncate_texts( + window_size=window_size, texts=texts, query_ids=query_ids) input_ids = [] token_type_ids = [] attention_masks = [] @@ -51,13 +50,19 @@ def prepare_onnx_input(tokenizer, query_id = (truncated_query_ids if window_size else query_ids)[idx] try: - tokens, text2token, token2text = tokenize_and_map(tokenizer, text) + tokens, text2token, token2text = tokenize_and_map( + tokenizer=tokenizer, text=text) except Exception: print(f'warning: text "{text}" is invalid') return {} text, query_id, tokens, text2token, token2text = _truncate( - max_len, text, query_id, tokens, text2token, token2text) + max_len=max_len, + text=text, + query_id=query_id, + tokens=tokens, + text2token=text2token, + token2text=token2text) processed_tokens = ['[CLS]'] + tokens + ['[SEP]'] @@ -91,7 +96,8 @@ def prepare_onnx_input(tokenizer, return outputs -def _truncate_texts(window_size, texts, query_ids): +def _truncate_texts(window_size: int, texts: List[str], + query_ids: List[int]) -> Tuple[List[str], List[int]]: truncated_texts = [] truncated_query_ids = [] for text, query_id in zip(texts, query_ids): @@ -105,7 +111,12 @@ def _truncate_texts(window_size, texts, query_ids): return truncated_texts, truncated_query_ids -def _truncate(max_len, text, query_id, tokens, text2token, token2text): +def _truncate(max_len: int, + text: str, + query_id: int, + tokens: List[str], + text2token: List[int], + token2text: List[Tuple[int]]): truncate_len = max_len - 2 if len(tokens) <= truncate_len: return (text, query_id, tokens, text2token, token2text) @@ -132,18 +143,8 @@ def _truncate(max_len, text, query_id, tokens, text2token, token2text): ], [(s - start, e - start) for s, e in token2text[token_start:token_end]]) -def prepare_data(sent_path, lb_path=None): - raw_texts = open(sent_path).read().rstrip().split('\n') - query_ids = [raw.index(ANCHOR_CHAR) for raw in raw_texts] - texts = [raw.replace(ANCHOR_CHAR, '') for raw in raw_texts] - if lb_path is None: - return texts, query_ids - else: - phonemes = open(lb_path).read().rstrip().split('\n') - return texts, query_ids, phonemes - - -def get_phoneme_labels(polyphonic_chars): +def get_phoneme_labels(polyphonic_chars: List[List[str]] + ) -> Tuple[List[str], Dict[str, List[int]]]: labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars]))) char2phonemes = {} for char, phoneme in polyphonic_chars: @@ -153,7 +154,8 @@ def get_phoneme_labels(polyphonic_chars): return labels, char2phonemes -def get_char_phoneme_labels(polyphonic_chars): +def get_char_phoneme_labels(polyphonic_chars: List[List[str]] + ) -> Tuple[List[str], Dict[str, List[int]]]: labels = sorted( list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars]))) char2phonemes = {} diff --git a/paddlespeech/t2s/frontend/g2pw/onnx_api.py b/paddlespeech/t2s/frontend/g2pw/onnx_api.py index 180e8ae1..ad32c405 100644 --- a/paddlespeech/t2s/frontend/g2pw/onnx_api.py +++ b/paddlespeech/t2s/frontend/g2pw/onnx_api.py @@ -17,6 +17,10 @@ Credits """ import json import os +from typing import Any +from typing import Dict +from typing import List +from typing import Tuple import numpy as np import onnxruntime @@ -37,7 +41,8 @@ from paddlespeech.utils.env import MODEL_HOME model_version = '1.1' -def predict(session, onnx_input, labels): +def predict(session, onnx_input: Dict[str, Any], + labels: List[str]) -> Tuple[List[str], List[float]]: all_preds = [] all_confidences = [] probs = session.run([], { @@ -61,10 +66,10 @@ def predict(session, onnx_input, labels): class G2PWOnnxConverter: def __init__(self, - model_dir=MODEL_HOME, - style='bopomofo', - model_source=None, - enable_non_tradional_chinese=False): + model_dir: os.PathLike=MODEL_HOME, + style: str='bopomofo', + model_source: str=None, + enable_non_tradional_chinese: bool=False): uncompress_path = download_and_decompress( g2pw_onnx_models['G2PWModel'][model_version], model_dir) @@ -76,7 +81,8 @@ class G2PWOnnxConverter: os.path.join(uncompress_path, 'g2pW.onnx'), sess_options=sess_options) self.config = load_config( - os.path.join(uncompress_path, 'config.py'), use_default=True) + config_path=os.path.join(uncompress_path, 'config.py'), + use_default=True) self.model_source = model_source if model_source else self.config.model_source self.enable_opencc = enable_non_tradional_chinese @@ -103,9 +109,9 @@ class G2PWOnnxConverter: .strip().split('\n') ] self.labels, self.char2phonemes = get_char_phoneme_labels( - self.polyphonic_chars + polyphonic_chars=self.polyphonic_chars ) if self.config.use_char_phoneme else get_phoneme_labels( - self.polyphonic_chars) + polyphonic_chars=self.polyphonic_chars) self.chars = sorted(list(self.char2phonemes.keys())) @@ -146,7 +152,7 @@ class G2PWOnnxConverter: if self.enable_opencc: self.cc = OpenCC('s2tw') - def _convert_bopomofo_to_pinyin(self, bopomofo): + def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str: tone = bopomofo[-1] assert tone in '12345' component = self.bopomofo_convert_dict.get(bopomofo[:-1]) @@ -156,7 +162,7 @@ class G2PWOnnxConverter: print(f'Warning: "{bopomofo}" cannot convert to pinyin') return None - def __call__(self, sentences): + def __call__(self, sentences: List[str]) -> List[List[str]]: if isinstance(sentences, str): sentences = [sentences] @@ -169,23 +175,25 @@ class G2PWOnnxConverter: sentences = translated_sentences texts, query_ids, sent_ids, partial_results = self._prepare_data( - sentences) + sentences=sentences) if len(texts) == 0: # sentences no polyphonic words return partial_results onnx_input = prepare_onnx_input( - self.tokenizer, - self.labels, - self.char2phonemes, - self.chars, - texts, - query_ids, + tokenizer=self.tokenizer, + labels=self.labels, + char2phonemes=self.char2phonemes, + chars=self.chars, + texts=texts, + query_ids=query_ids, use_mask=self.config.use_mask, - use_char_phoneme=self.config.use_char_phoneme, window_size=None) - preds, confidences = predict(self.session_g2pW, onnx_input, self.labels) + preds, confidences = predict( + session=self.session_g2pW, + onnx_input=onnx_input, + labels=self.labels) if self.config.use_char_phoneme: preds = [pred.split(' ')[1] for pred in preds] @@ -195,7 +203,9 @@ class G2PWOnnxConverter: return results - def _prepare_data(self, sentences): + def _prepare_data( + self, sentences: List[str] + ) -> Tuple[List[str], List[int], List[int], List[List[str]]]: texts, query_ids, sent_ids, partial_results = [], [], [], [] for sent_id, sent in enumerate(sentences): # pypinyin works well for Simplified Chinese than Traditional Chinese diff --git a/paddlespeech/t2s/frontend/g2pw/utils.py b/paddlespeech/t2s/frontend/g2pw/utils.py index ad02c4c1..ba9ce51b 100644 --- a/paddlespeech/t2s/frontend/g2pw/utils.py +++ b/paddlespeech/t2s/frontend/g2pw/utils.py @@ -15,10 +15,11 @@ Credits This code is modified from https://github.com/GitYCC/g2pW """ +import os import re -def wordize_and_map(text): +def wordize_and_map(text: str): words = [] index_map_from_text_to_word = [] index_map_from_word_to_text = [] @@ -54,8 +55,8 @@ def wordize_and_map(text): return words, index_map_from_text_to_word, index_map_from_word_to_text -def tokenize_and_map(tokenizer, text): - words, text2word, word2text = wordize_and_map(text) +def tokenize_and_map(tokenizer, text: str): + words, text2word, word2text = wordize_and_map(text=text) tokens = [] index_map_from_token_to_text = [] @@ -82,7 +83,7 @@ def tokenize_and_map(tokenizer, text): return tokens, index_map_from_text_to_token, index_map_from_token_to_text -def _load_config(config_path): +def _load_config(config_path: os.PathLike): import importlib.util spec = importlib.util.spec_from_file_location('__init__', config_path) config = importlib.util.module_from_spec(spec) @@ -130,7 +131,7 @@ default_config_dict = { } -def load_config(config_path, use_default=False): +def load_config(config_path: os.PathLike, use_default: bool=False): config = _load_config(config_path) if use_default: for attr, val in default_config_dict.items():