[s2t] fix cli args to config (#3194)

* fix cli args to config

* fix train cli
pull/3202/head
Hui Zhang 2 years ago committed by GitHub
parent e3dcfa8815
commit 225737d4e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -74,6 +74,9 @@ def build_vocab(manifest_paths="",
spm_vocab_size=0, spm_vocab_size=0,
spm_model_prefix="", spm_model_prefix="",
spm_character_coverage=0.9995): spm_character_coverage=0.9995):
manifest_paths = [manifest_paths] if isinstance(manifest_paths,
str) else manifest_paths
fout = open(vocab_path, 'w', encoding='utf-8') fout = open(vocab_path, 'w', encoding='utf-8')
fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC
fout.write(UNK + '\n') # <unk> must be 1 fout.write(UNK + '\n') # <unk> must be 1

@ -58,6 +58,7 @@ def format_data(
unit_type="char", unit_type="char",
vocab_path="examples/librispeech/data/vocab.txt", vocab_path="examples/librispeech/data/vocab.txt",
spm_model_prefix=""): spm_model_prefix=""):
manifest_paths = [manifest_paths] if isinstance(manifest_paths, str) else manifest_paths
fout = open(output_path, 'w', encoding='utf-8') fout = open(output_path, 'w', encoding='utf-8')

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Alignment for U2 model.""" """Alignment for U2 model."""
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments from paddlespeech.utils.argparse import print_arguments
@ -32,26 +32,10 @@ def main(config, args):
if __name__ == "__main__": if __name__ == "__main__":
parser = default_argument_parser() parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args() args = parser.parse_args()
print_arguments(args, globals()) print_arguments(args, globals())
# https://yaml.org/type/float.html config = config_from_args(args)
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config) print(config)
if args.dump_config: maybe_dump_config(args.dump_config, config)
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args) main(config, args)

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Export for U2 model.""" """Export for U2 model."""
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments from paddlespeech.utils.argparse import print_arguments
@ -32,22 +32,10 @@ def main(config, args):
if __name__ == "__main__": if __name__ == "__main__":
parser = default_argument_parser() parser = default_argument_parser()
# save jit model to
parser.add_argument(
"--export_path", type=str, help="path of the jit model to save")
args = parser.parse_args() args = parser.parse_args()
print_arguments(args, globals()) print_arguments(args, globals())
# https://yaml.org/type/float.html config = config_from_args(args)
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config) print(config)
if args.dump_config: maybe_dump_config(args.dump_config, config)
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args) main(config, args)

@ -15,14 +15,15 @@
import paddle import paddle
from kaldiio import ReadHelper from kaldiio import ReadHelper
from paddleslim import PTQ from paddleslim import PTQ
from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -173,32 +174,7 @@ def main(config, args):
if __name__ == "__main__": if __name__ == "__main__":
parser = default_argument_parser() parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_scp", type=str, help="path of the input audio file")
parser.add_argument(
"--num_utts",
type=int,
default=200,
help="num utts for quant calibrition.")
parser.add_argument(
"--export_path",
type=str,
default='export.jit.quant',
help="path of the input audio file")
args = parser.parse_args() args = parser.parse_args()
config = CfgNode(new_allowed=True) config = config_from_args(args)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
main(config, args) main(config, args)

@ -14,10 +14,10 @@
"""Evaluation for U2 model.""" """Evaluation for U2 model."""
import cProfile import cProfile
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments from paddlespeech.utils.argparse import print_arguments
@ -34,27 +34,12 @@ def main(config, args):
if __name__ == "__main__": if __name__ == "__main__":
parser = default_argument_parser() parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args() args = parser.parse_args()
print_arguments(args, globals()) print_arguments(args, globals())
# https://yaml.org/type/float.html config = config_from_args(args)
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config) print(config)
if args.dump_config: maybe_dump_config(args.dump_config, config)
with open(args.dump_config, 'w') as f:
print(config, file=f)
# Setting for profiling # Setting for profiling
pr = cProfile.Profile() pr = cProfile.Profile()

@ -16,15 +16,14 @@ import os
import sys import sys
from pathlib import Path from pathlib import Path
import distutils
import numpy as np import numpy as np
import paddle import paddle
import soundfile import soundfile
from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
@ -125,27 +124,7 @@ def main(config, args):
if __name__ == "__main__": if __name__ == "__main__":
parser = default_argument_parser() parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_file", type=str, help="path of the input audio file")
parser.add_argument(
"--debug",
type=distutils.util.strtobool,
default=False,
help="for debug.")
args = parser.parse_args() args = parser.parse_args()
config = CfgNode(new_allowed=True) config = config_from_args(args)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
main(config, args) main(config, args)

@ -15,14 +15,12 @@
import cProfile import cProfile
import os import os
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments from paddlespeech.utils.argparse import print_arguments
# from paddlespeech.s2t.exps.u2.trainer import U2Trainer as Trainer
def main_sp(config, args): def main_sp(config, args):
exp = Trainer(config, args) exp = Trainer(config, args)
@ -39,17 +37,9 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
print_arguments(args, globals()) print_arguments(args, globals())
# https://yaml.org/type/float.html config = config_from_args(args)
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config) print(config)
if args.dump_config: maybe_dump_config(args.dump_path, config)
with open(args.dump_config, 'w') as f:
print(config, file=f)
# Setting for profiling # Setting for profiling
pr = cProfile.Profile() pr = cProfile.Profile()

@ -13,6 +13,9 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import distutils
from yacs.config import CfgNode
class ExtendAction(argparse.Action): class ExtendAction(argparse.Action):
""" """
@ -68,7 +71,15 @@ def default_argument_parser(parser=None):
parser.register('action', 'extend', ExtendAction) parser.register('action', 'extend', ExtendAction)
parser.add_argument( parser.add_argument(
'--conf', type=open, action=LoadFromFile, help="config file.") '--conf', type=open, action=LoadFromFile, help="config file.")
parser.add_argument(
"--debug",
type=distutils.util.strtobool,
default=False,
help="logging with debug mode.")
parser.add_argument(
"--dump_path", type=str, default=None, help="path to dump config file.")
# train group
train_group = parser.add_argument_group( train_group = parser.add_argument_group(
title='Train Options', description=None) title='Train Options', description=None)
train_group.add_argument( train_group.add_argument(
@ -103,14 +114,35 @@ def default_argument_parser(parser=None):
train_group.add_argument( train_group.add_argument(
"--dump-config", metavar="FILE", help="dump config to `this` file.") "--dump-config", metavar="FILE", help="dump config to `this` file.")
# test group
test_group = parser.add_argument_group( test_group = parser.add_argument_group(
title='Test Options', description=None) title='Test Options', description=None)
test_group.add_argument( test_group.add_argument(
"--decode_cfg", "--decode_cfg",
metavar="DECODE_CONFIG_FILE", metavar="DECODE_CONFIG_FILE",
help="decode config file.") help="decode config file.")
test_group.add_argument(
"--result_file", type=str, help="path of save the asr result")
test_group.add_argument(
"--audio_file", type=str, help="path of the input audio file")
# quant & export
quant_group = parser.add_argument_group(
title='Quant Options', description=None)
quant_group.add_argument(
"--audio_scp", type=str, help="path of the input audio scp file")
quant_group.add_argument(
"--num_utts",
type=int,
default=200,
help="num utts for quant calibrition.")
quant_group.add_argument(
"--export_path",
type=str,
default='export.jit.quant',
help="path of the jit model to save")
# profile group
profile_group = parser.add_argument_group( profile_group = parser.add_argument_group(
title='Benchmark Options', description=None) title='Benchmark Options', description=None)
profile_group.add_argument( profile_group.add_argument(
@ -131,3 +163,28 @@ def default_argument_parser(parser=None):
help='max iteration for benchmark.') help='max iteration for benchmark.')
return parser return parser
def config_from_args(args):
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
return config
def maybe_dump_config(dump_path, config):
if dump_path:
with open(dump_path, 'w') as f:
print(config, file=f)
print(f"save config to {dump_path}")

Loading…
Cancel
Save