Merge pull request #1361 from Jackwaterveg/setup

[Setup]refactor the version
pull/1364/head
Hui Zhang 3 years ago committed by GitHub
commit c8a5d1db78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,5 +13,3 @@
# limitations under the License. # limitations under the License.
from .backends import * from .backends import *
from .features import * from .features import *
__version__ = '0.1.0'

@ -11,5 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
__version__ = '0.1.1'

@ -17,7 +17,6 @@ import io
import os import os
import subprocess as sp import subprocess as sp
import sys import sys
import paddlespeech
from pathlib import Path from pathlib import Path
from setuptools import Command from setuptools import Command
@ -28,6 +27,8 @@ from setuptools.command.install import install
HERE = Path(os.path.abspath(os.path.dirname(__file__))) HERE = Path(os.path.abspath(os.path.dirname(__file__)))
VERSION = '0.1.1'
requirements = { requirements = {
"install": [ "install": [
"editdistance", "editdistance",
@ -83,6 +84,24 @@ requirements = {
} }
def write_version_py(filename='paddlespeech/__init__.py'):
import paddlespeech
if hasattr(paddlespeech,
"__version__") and paddlespeech.__version__ == VERSION:
return
with open(filename, "a") as f:
f.write(f"\n__version__ = '{VERSION}'\n")
def remove_version_py(filename='paddlespeech/__init__.py'):
with open(filename, "r") as f:
lines = f.readlines()
with open(filename, "w") as f:
for line in lines:
if "__version__" not in line:
f.write(line)
@contextlib.contextmanager @contextlib.contextmanager
def pushd(new_dir): def pushd(new_dir):
old_dir = os.getcwd() old_dir = os.getcwd()
@ -170,10 +189,12 @@ class UploadCommand(Command):
sys.exit() sys.exit()
write_version_py()
setup_info = dict( setup_info = dict(
# Metadata # Metadata
name='paddlespeech', name='paddlespeech',
version=paddlespeech.__version__, version=VERSION,
author='PaddlePaddle Speech and Language Team', author='PaddlePaddle Speech and Language Team',
author_email='paddlesl@baidu.com', author_email='paddlesl@baidu.com',
url='https://github.com/PaddlePaddle/PaddleSpeech', url='https://github.com/PaddlePaddle/PaddleSpeech',
@ -236,3 +257,5 @@ setup_info = dict(
}) })
setup(**setup_info) setup(**setup_info)
remove_version_py()

@ -13,14 +13,33 @@
# limitations under the License. # limitations under the License.
import setuptools import setuptools
import paddleaudio
# set the version here # set the version here
version = paddleaudio.__version__ VERSION = '0.1.0'
def write_version_py(filename='paddleaudio/__init__.py'):
import paddleaudio
if hasattr(paddleaudio,
"__version__") and paddleaudio.__version__ == VERSION:
return
with open(filename, "a") as f:
f.write(f"\n__version__ = '{VERSION}'\n")
def remove_version_py(filename='paddleaudio/__init__.py'):
with open(filename, "r") as f:
lines = f.readlines()
with open(filename, "w") as f:
for line in lines:
if "__version__" not in line:
f.write(line)
write_version_py()
setuptools.setup( setuptools.setup(
name="paddleaudio", name="paddleaudio",
version=version, version=VERSION,
author="", author="",
author_email="", author_email="",
description="PaddleAudio, in development", description="PaddleAudio, in development",
@ -41,3 +60,5 @@ setuptools.setup(
'soundfile >= 0.9.0', 'soundfile >= 0.9.0',
'colorlog', 'colorlog',
], ) ], )
remove_version_py()

@ -1,21 +1,19 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
''' '''
Merge training configs into a single inference config. Merge training configs into a single inference config.
The single inference config is for CLI, which only takes a single config to do inferencing. The single inference config is for CLI, which only takes a single config to do inferencing.
The trainig configs includes: model config, preprocess config, decode config, vocab file and cmvn file. The trainig configs includes: model config, preprocess config, decode config, vocab file and cmvn file.
''' '''
import yaml
import json
import os
import argparse import argparse
import json
import math import math
import os
from contextlib import redirect_stdout
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.s2t.frontend.utility import load_dict from paddlespeech.s2t.frontend.utility import load_dict
from contextlib import redirect_stdout
def save(save_path, config): def save(save_path, config):
@ -29,18 +27,21 @@ def load(save_path):
config.merge_from_file(save_path) config.merge_from_file(save_path)
return config return config
def load_json(json_path): def load_json(json_path):
with open(json_path) as f: with open(json_path) as f:
json_content = json.load(f) json_content = json.load(f)
return json_content return json_content
def remove_config_part(config, key_list): def remove_config_part(config, key_list):
if len(key_list) == 0: if len(key_list) == 0:
return return
for i in range(len(key_list) -1): for i in range(len(key_list) - 1):
config = config[key_list[i]] config = config[key_list[i]]
config.pop(key_list[-1]) config.pop(key_list[-1])
def load_cmvn_from_json(cmvn_stats): def load_cmvn_from_json(cmvn_stats):
means = cmvn_stats['mean_stat'] means = cmvn_stats['mean_stat']
variance = cmvn_stats['var_stat'] variance = cmvn_stats['var_stat']
@ -51,17 +52,17 @@ def load_cmvn_from_json(cmvn_stats):
if variance[i] < 1.0e-20: if variance[i] < 1.0e-20:
variance[i] = 1.0e-20 variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i]) variance[i] = 1.0 / math.sqrt(variance[i])
cmvn_stats = {"mean":means, "istd":variance} cmvn_stats = {"mean": means, "istd": variance}
return cmvn_stats return cmvn_stats
def merge_configs( def merge_configs(
conf_path = "conf/conformer.yaml", conf_path="conf/conformer.yaml",
preprocess_path = "conf/preprocess.yaml", preprocess_path="conf/preprocess.yaml",
decode_path = "conf/tuning/decode.yaml", decode_path="conf/tuning/decode.yaml",
vocab_path = "data/vocab.txt", vocab_path="data/vocab.txt",
cmvn_path = "data/mean_std.json", cmvn_path="data/mean_std.json",
save_path = "conf/conformer_infer.yaml", save_path="conf/conformer_infer.yaml", ):
):
# Load the configs # Load the configs
config = load(conf_path) config = load(conf_path)
@ -72,17 +73,16 @@ def merge_configs(
if cmvn_path.split(".")[-1] == 'json': if cmvn_path.split(".")[-1] == 'json':
cmvn_stats = load_json(cmvn_path) cmvn_stats = load_json(cmvn_path)
if os.path.exists(preprocess_path): if os.path.exists(preprocess_path):
preprocess_config = load(preprocess_path) preprocess_config = load(preprocess_path)
for idx, process in enumerate(preprocess_config["process"]): for idx, process in enumerate(preprocess_config["process"]):
if process['type'] == "cmvn_json": if process['type'] == "cmvn_json":
preprocess_config["process"][idx][ preprocess_config["process"][idx]["cmvn_path"] = cmvn_stats
"cmvn_path"] = cmvn_stats
break break
config.preprocess_config = preprocess_config config.preprocess_config = preprocess_config
else: else:
cmvn_stats = load_cmvn_from_json(cmvn_stats) cmvn_stats = load_cmvn_from_json(cmvn_stats)
config.mean_std_filepath = [{"cmvn_stats":cmvn_stats}] config.mean_std_filepath = [{"cmvn_stats": cmvn_stats}]
config.augmentation_config = '' config.augmentation_config = ''
# the cmvn file is end with .ark # the cmvn file is end with .ark
else: else:
@ -95,7 +95,8 @@ def merge_configs(
# Remove some parts of the config # Remove some parts of the config
if os.path.exists(preprocess_path): if os.path.exists(preprocess_path):
remove_train_list = ["train_manifest", remove_train_list = [
"train_manifest",
"dev_manifest", "dev_manifest",
"test_manifest", "test_manifest",
"n_epoch", "n_epoch",
@ -124,9 +125,10 @@ def merge_configs(
"batch_size", "batch_size",
"maxlen_in", "maxlen_in",
"maxlen_out", "maxlen_out",
] ]
else: else:
remove_train_list = ["train_manifest", remove_train_list = [
"train_manifest",
"dev_manifest", "dev_manifest",
"test_manifest", "test_manifest",
"n_epoch", "n_epoch",
@ -141,43 +143,41 @@ def merge_configs(
"weight_decay", "weight_decay",
"sortagrad", "sortagrad",
"num_workers", "num_workers",
] ]
for item in remove_train_list: for item in remove_train_list:
try: try:
remove_config_part(config, [item]) remove_config_part(config, [item])
except: except:
print ( item + " " +"can not be removed") print(item + " " + "can not be removed")
# Save the config # Save the config
save(save_path, config) save(save_path, config)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(prog='Config merge', add_help=True)
prog='Config merge', add_help=True)
parser.add_argument( parser.add_argument(
'--cfg_pth', type=str, default = 'conf/transformer.yaml', help='origin config file') '--cfg_pth',
type=str,
default='conf/transformer.yaml',
help='origin config file')
parser.add_argument( parser.add_argument(
'--pre_pth', type=str, default= "conf/preprocess.yaml", help='') '--pre_pth', type=str, default="conf/preprocess.yaml", help='')
parser.add_argument( parser.add_argument(
'--dcd_pth', type=str, default= "conf/tuninig/decode.yaml", help='') '--dcd_pth', type=str, default="conf/tuninig/decode.yaml", help='')
parser.add_argument( parser.add_argument(
'--vb_pth', type=str, default= "data/lang_char/vocab.txt", help='') '--vb_pth', type=str, default="data/lang_char/vocab.txt", help='')
parser.add_argument( parser.add_argument(
'--cmvn_pth', type=str, default= "data/mean_std.json", help='') '--cmvn_pth', type=str, default="data/mean_std.json", help='')
parser.add_argument( parser.add_argument(
'--save_pth', type=str, default= "conf/transformer_infer.yaml", help='') '--save_pth', type=str, default="conf/transformer_infer.yaml", help='')
parser_args = parser.parse_args() parser_args = parser.parse_args()
merge_configs( merge_configs(
conf_path = parser_args.cfg_pth, conf_path=parser_args.cfg_pth,
decode_path = parser_args.dcd_pth, decode_path=parser_args.dcd_pth,
preprocess_path = parser_args.pre_pth, preprocess_path=parser_args.pre_pth,
vocab_path = parser_args.vb_pth, vocab_path=parser_args.vb_pth,
cmvn_path = parser_args.cmvn_pth, cmvn_path=parser_args.cmvn_pth,
save_path = parser_args.save_pth, save_path=parser_args.save_pth, )
)

Loading…
Cancel
Save