pull/4152/head
zxcd 6 days ago
parent 8097a56be8
commit 1d3ae79afb

@ -41,6 +41,33 @@ from paddlespeech.utils.env import MODEL_HOME
model_version = '1.1'
def get_g2pw_model_path(model_dir: os.PathLike, model_version: str) -> str:
"""Resolve the G2PW ONNX model directory path.
Checks if the model file 'g2pW.onnx' exists in the expected location.
If not, downloads and decompresses the model archive
Args:
model_dir (os.PathLike): Base directory to store models (e.g., ~/.paddlespeech).
model_version (str): Model version string (e.g., '1.1').
Returns:
str: Path to the model directory containing 'g2pW.onnx'.
"""
archive_info = g2pw_onnx_models['G2PWModel'][model_version]
archive_fname = os.path.basename(
archive_info['url']) # e.g., "G2PWModel_1.1.zip"
expected_extract_name = os.path.splitext(archive_fname)[
0] # e.g., "G2PWModel_1.1"
expected_model_dir = os.path.join(model_dir, expected_extract_name)
uncompress_path = expected_model_dir
onnx_file_path = os.path.join(expected_model_dir, 'g2pW.onnx')
if not os.path.isfile(onnx_file_path):
uncompress_path = download_and_decompress(archive_info, model_dir)
return uncompress_path
def predict(session, onnx_input: Dict[str, Any],
labels: List[str]) -> Tuple[List[str], List[float]]:
all_preds = []
@ -70,8 +97,9 @@ class G2PWOnnxConverter:
style: str='bopomofo',
model_source: str=None,
enable_non_tradional_chinese: bool=False):
uncompress_path = download_and_decompress(
g2pw_onnx_models['G2PWModel'][model_version], model_dir)
# uncompress_path = download_and_decompress(
# g2pw_onnx_models['G2PWModel'][model_version], model_dir)
uncompress_path = get_g2pw_model_path(model_dir, model_version)
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

Loading…
Cancel
Save