From 225737d4e3318fb5a87bd86ae018aaa7e9e46975 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 25 Apr 2023 15:07:30 +0800 Subject: [PATCH] [s2t] fix cli args to config (#3194) * fix cli args to config * fix train cli --- paddlespeech/dataset/s2t/build_vocab.py | 3 ++ paddlespeech/dataset/s2t/format_data.py | 1 + paddlespeech/s2t/exps/u2/bin/alignment.py | 24 ++------- paddlespeech/s2t/exps/u2/bin/export.py | 20 ++------ paddlespeech/s2t/exps/u2/bin/quant.py | 30 ++---------- paddlespeech/s2t/exps/u2/bin/test.py | 23 ++------- paddlespeech/s2t/exps/u2/bin/test_wav.py | 25 +--------- paddlespeech/s2t/exps/u2/bin/train.py | 18 ++----- paddlespeech/s2t/training/cli.py | 59 ++++++++++++++++++++++- 9 files changed, 83 insertions(+), 120 deletions(-) diff --git a/paddlespeech/dataset/s2t/build_vocab.py b/paddlespeech/dataset/s2t/build_vocab.py index dd5f62081..081edf3d3 100755 --- a/paddlespeech/dataset/s2t/build_vocab.py +++ b/paddlespeech/dataset/s2t/build_vocab.py @@ -74,6 +74,9 @@ def build_vocab(manifest_paths="", spm_vocab_size=0, spm_model_prefix="", spm_character_coverage=0.9995): + manifest_paths = [manifest_paths] if isinstance(manifest_paths, + str) else manifest_paths + fout = open(vocab_path, 'w', encoding='utf-8') fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC fout.write(UNK + '\n') # must be 1 diff --git a/paddlespeech/dataset/s2t/format_data.py b/paddlespeech/dataset/s2t/format_data.py index dcff66eac..addd6fdc9 100755 --- a/paddlespeech/dataset/s2t/format_data.py +++ b/paddlespeech/dataset/s2t/format_data.py @@ -58,6 +58,7 @@ def format_data( unit_type="char", vocab_path="examples/librispeech/data/vocab.txt", spm_model_prefix=""): + manifest_paths = [manifest_paths] if isinstance(manifest_paths, str) else manifest_paths fout = open(output_path, 'w', encoding='utf-8') diff --git a/paddlespeech/s2t/exps/u2/bin/alignment.py b/paddlespeech/s2t/exps/u2/bin/alignment.py index cc2940388..64cafc484 100644 --- a/paddlespeech/s2t/exps/u2/bin/alignment.py +++ b/paddlespeech/s2t/exps/u2/bin/alignment.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Alignment for U2 model.""" -from yacs.config import CfgNode - from paddlespeech.s2t.exps.u2.model import U2Tester as Tester +from paddlespeech.s2t.training.cli import config_from_args from paddlespeech.s2t.training.cli import default_argument_parser +from paddlespeech.s2t.training.cli import maybe_dump_config from paddlespeech.utils.argparse import print_arguments @@ -32,26 +32,10 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() - # save asr result to - parser.add_argument( - "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() print_arguments(args, globals()) - # https://yaml.org/type/float.html - config = CfgNode(new_allowed=True) - if args.config: - config.merge_from_file(args.config) - if args.decode_cfg: - decode_confs = CfgNode(new_allowed=True) - decode_confs.merge_from_file(args.decode_cfg) - config.decode = decode_confs - if args.opts: - config.merge_from_list(args.opts) - config.freeze() + config = config_from_args(args) print(config) - if args.dump_config: - with open(args.dump_config, 'w') as f: - print(config, file=f) - + maybe_dump_config(args.dump_config, config) main(config, args) diff --git a/paddlespeech/s2t/exps/u2/bin/export.py b/paddlespeech/s2t/exps/u2/bin/export.py index 4725e5e13..de4a55a41 100644 --- a/paddlespeech/s2t/exps/u2/bin/export.py +++ b/paddlespeech/s2t/exps/u2/bin/export.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Export for U2 model.""" -from yacs.config import CfgNode - from paddlespeech.s2t.exps.u2.model import U2Tester as Tester +from paddlespeech.s2t.training.cli import config_from_args from paddlespeech.s2t.training.cli import default_argument_parser +from paddlespeech.s2t.training.cli import maybe_dump_config from paddlespeech.utils.argparse import print_arguments @@ -32,22 +32,10 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() - # save jit model to - parser.add_argument( - "--export_path", type=str, help="path of the jit model to save") args = parser.parse_args() print_arguments(args, globals()) - # https://yaml.org/type/float.html - config = CfgNode(new_allowed=True) - if args.config: - config.merge_from_file(args.config) - if args.opts: - config.merge_from_list(args.opts) - config.freeze() + config = config_from_args(args) print(config) - if args.dump_config: - with open(args.dump_config, 'w') as f: - print(config, file=f) - + maybe_dump_config(args.dump_config, config) main(config, args) diff --git a/paddlespeech/s2t/exps/u2/bin/quant.py b/paddlespeech/s2t/exps/u2/bin/quant.py index 6d361c5fd..73a9794fc 100755 --- a/paddlespeech/s2t/exps/u2/bin/quant.py +++ b/paddlespeech/s2t/exps/u2/bin/quant.py @@ -15,14 +15,15 @@ import paddle from kaldiio import ReadHelper from paddleslim import PTQ -from yacs.config import CfgNode from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.models.u2 import U2Model +from paddlespeech.s2t.training.cli import config_from_args from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.utility import UpdateConfig + logger = Log(__name__).getlog() @@ -173,32 +174,7 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() - # save asr result to - parser.add_argument( - "--result_file", type=str, help="path of save the asr result") - parser.add_argument( - "--audio_scp", type=str, help="path of the input audio file") - parser.add_argument( - "--num_utts", - type=int, - default=200, - help="num utts for quant calibrition.") - parser.add_argument( - "--export_path", - type=str, - default='export.jit.quant', - help="path of the input audio file") args = parser.parse_args() - config = CfgNode(new_allowed=True) - - if args.config: - config.merge_from_file(args.config) - if args.decode_cfg: - decode_confs = CfgNode(new_allowed=True) - decode_confs.merge_from_file(args.decode_cfg) - config.decode = decode_confs - if args.opts: - config.merge_from_list(args.opts) - config.freeze() + config = config_from_args(args) main(config, args) diff --git a/paddlespeech/s2t/exps/u2/bin/test.py b/paddlespeech/s2t/exps/u2/bin/test.py index 43eeff631..ea1878620 100644 --- a/paddlespeech/s2t/exps/u2/bin/test.py +++ b/paddlespeech/s2t/exps/u2/bin/test.py @@ -14,10 +14,10 @@ """Evaluation for U2 model.""" import cProfile -from yacs.config import CfgNode - from paddlespeech.s2t.exps.u2.model import U2Tester as Tester +from paddlespeech.s2t.training.cli import config_from_args from paddlespeech.s2t.training.cli import default_argument_parser +from paddlespeech.s2t.training.cli import maybe_dump_config from paddlespeech.utils.argparse import print_arguments @@ -34,27 +34,12 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() - # save asr result to - parser.add_argument( - "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() print_arguments(args, globals()) - # https://yaml.org/type/float.html - config = CfgNode(new_allowed=True) - if args.config: - config.merge_from_file(args.config) - if args.decode_cfg: - decode_confs = CfgNode(new_allowed=True) - decode_confs.merge_from_file(args.decode_cfg) - config.decode = decode_confs - if args.opts: - config.merge_from_list(args.opts) - config.freeze() + config = config_from_args(args) print(config) - if args.dump_config: - with open(args.dump_config, 'w') as f: - print(config, file=f) + maybe_dump_config(args.dump_config, config) # Setting for profiling pr = cProfile.Profile() diff --git a/paddlespeech/s2t/exps/u2/bin/test_wav.py b/paddlespeech/s2t/exps/u2/bin/test_wav.py index 0df443193..a6228a128 100644 --- a/paddlespeech/s2t/exps/u2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/u2/bin/test_wav.py @@ -16,15 +16,14 @@ import os import sys from pathlib import Path -import distutils import numpy as np import paddle import soundfile -from yacs.config import CfgNode from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.models.u2 import U2Model +from paddlespeech.s2t.training.cli import config_from_args from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.utility import UpdateConfig @@ -125,27 +124,7 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() - # save asr result to - parser.add_argument( - "--result_file", type=str, help="path of save the asr result") - parser.add_argument( - "--audio_file", type=str, help="path of the input audio file") - parser.add_argument( - "--debug", - type=distutils.util.strtobool, - default=False, - help="for debug.") args = parser.parse_args() - config = CfgNode(new_allowed=True) - - if args.config: - config.merge_from_file(args.config) - if args.decode_cfg: - decode_confs = CfgNode(new_allowed=True) - decode_confs.merge_from_file(args.decode_cfg) - config.decode = decode_confs - if args.opts: - config.merge_from_list(args.opts) - config.freeze() + config = config_from_args(args) main(config, args) diff --git a/paddlespeech/s2t/exps/u2/bin/train.py b/paddlespeech/s2t/exps/u2/bin/train.py index a0f503288..b52d5e90b 100644 --- a/paddlespeech/s2t/exps/u2/bin/train.py +++ b/paddlespeech/s2t/exps/u2/bin/train.py @@ -15,14 +15,12 @@ import cProfile import os -from yacs.config import CfgNode - from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer +from paddlespeech.s2t.training.cli import config_from_args from paddlespeech.s2t.training.cli import default_argument_parser +from paddlespeech.s2t.training.cli import maybe_dump_config from paddlespeech.utils.argparse import print_arguments -# from paddlespeech.s2t.exps.u2.trainer import U2Trainer as Trainer - def main_sp(config, args): exp = Trainer(config, args) @@ -39,17 +37,9 @@ if __name__ == "__main__": args = parser.parse_args() print_arguments(args, globals()) - # https://yaml.org/type/float.html - config = CfgNode(new_allowed=True) - if args.config: - config.merge_from_file(args.config) - if args.opts: - config.merge_from_list(args.opts) - config.freeze() + config = config_from_args(args) print(config) - if args.dump_config: - with open(args.dump_config, 'w') as f: - print(config, file=f) + maybe_dump_config(args.dump_path, config) # Setting for profiling pr = cProfile.Profile() diff --git a/paddlespeech/s2t/training/cli.py b/paddlespeech/s2t/training/cli.py index 1b6bec8a8..741b95dff 100644 --- a/paddlespeech/s2t/training/cli.py +++ b/paddlespeech/s2t/training/cli.py @@ -13,6 +13,9 @@ # limitations under the License. import argparse +import distutils +from yacs.config import CfgNode + class ExtendAction(argparse.Action): """ @@ -68,7 +71,15 @@ def default_argument_parser(parser=None): parser.register('action', 'extend', ExtendAction) parser.add_argument( '--conf', type=open, action=LoadFromFile, help="config file.") + parser.add_argument( + "--debug", + type=distutils.util.strtobool, + default=False, + help="logging with debug mode.") + parser.add_argument( + "--dump_path", type=str, default=None, help="path to dump config file.") + # train group train_group = parser.add_argument_group( title='Train Options', description=None) train_group.add_argument( @@ -103,14 +114,35 @@ def default_argument_parser(parser=None): train_group.add_argument( "--dump-config", metavar="FILE", help="dump config to `this` file.") + # test group test_group = parser.add_argument_group( title='Test Options', description=None) - test_group.add_argument( "--decode_cfg", metavar="DECODE_CONFIG_FILE", help="decode config file.") + test_group.add_argument( + "--result_file", type=str, help="path of save the asr result") + test_group.add_argument( + "--audio_file", type=str, help="path of the input audio file") + + # quant & export + quant_group = parser.add_argument_group( + title='Quant Options', description=None) + quant_group.add_argument( + "--audio_scp", type=str, help="path of the input audio scp file") + quant_group.add_argument( + "--num_utts", + type=int, + default=200, + help="num utts for quant calibrition.") + quant_group.add_argument( + "--export_path", + type=str, + default='export.jit.quant', + help="path of the jit model to save") + # profile group profile_group = parser.add_argument_group( title='Benchmark Options', description=None) profile_group.add_argument( @@ -131,3 +163,28 @@ def default_argument_parser(parser=None): help='max iteration for benchmark.') return parser + + +def config_from_args(args): + # https://yaml.org/type/float.html + config = CfgNode(new_allowed=True) + + if args.config: + config.merge_from_file(args.config) + + if args.decode_cfg: + decode_confs = CfgNode(new_allowed=True) + decode_confs.merge_from_file(args.decode_cfg) + config.decode = decode_confs + + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + return config + + +def maybe_dump_config(dump_path, config): + if dump_path: + with open(dump_path, 'w') as f: + print(config, file=f) + print(f"save config to {dump_path}")