@@ -41,6 +41,33 @@ from paddlespeech.utils.env import MODEL_HOME

model_version = '1.1'
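# Version key used to select the G2PWModel entry in g2pw_onnx_models below.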


def get_g2pw_model_path(model_dir: os.PathLike, model_version: str) -> str:
    """Resolve the G2PW ONNX model directory path.

    Checks whether the model file 'g2pW.onnx' already exists in the expected
    location. If not, downloads and decompresses the model archive into
    model_dir.

    Args:
        model_dir (os.PathLike): Base directory to store models (e.g., ~/.paddlespeech).
        model_version (str): Model version string (e.g., '1.1').

    Returns:
        str: Path to the model directory containing 'g2pW.onnx'.
    """
    archive_info = g2pw_onnx_models['G2PWModel'][model_version]
    archive_fname = os.path.basename(
        archive_info['url'])  # e.g., "G2PWModel_1.1.zip"
    expected_extract_name = os.path.splitext(archive_fname)[
        0]  # e.g., "G2PWModel_1.1"
    expected_model_dir = os.path.join(model_dir, expected_extract_name)
    uncompress_path = expected_model_dir
    onnx_file_path = os.path.join(expected_model_dir, 'g2pW.onnx')
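    # Download and extract the archive only when the ONNX file is not already
    # cached in the expected directory.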
    if not os.path.isfile(onnx_file_path):
        uncompress_path = download_and_decompress(archive_info, model_dir)
    return uncompress_path
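
# A minimal usage sketch (illustrative only; assumes the caller passes the
# cache root, e.g. MODEL_HOME imported above, as model_dir):
#
#     g2pw_dir = get_g2pw_model_path(MODEL_HOME, model_version)
#     onnx_path = os.path.join(g2pw_dir, 'g2pW.onnx')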


def predict(session, onnx_input: Dict[str, Any],
            labels: List[str]) -> Tuple[List[str], List[float]]:
    all_preds = []


@@ -70,8 +97,9 @@ class G2PWOnnxConverter:
                 style: str='bopomofo',
                 model_source: str=None,
                 enable_non_tradional_chinese: bool=False):
        uncompress_path = download_and_decompress(
            g2pw_onnx_models['G2PWModel'][model_version], model_dir)
        # uncompress_path = download_and_decompress(
        #     g2pw_onnx_models['G2PWModel'][model_version], model_dir)
        uncompress_path = get_g2pw_model_path(model_dir, model_version)
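        # The helper above resolves the extracted model directory and calls
        # download_and_decompress only when 'g2pW.onnx' is not already on
        # disk, instead of invoking it unconditionally on every construction.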

        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
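        # ORT_ENABLE_ALL enables every ONNX Runtime graph optimization level
        # (basic, extended, and layout) for the session built from the model.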