[s2t] fix cli args to config (#3194)

* fix cli args to config

* fix train cli
pull/3202/head
Hui Zhang 2 years ago committed by GitHub
parent e3dcfa8815
commit 225737d4e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -74,6 +74,9 @@ def build_vocab(manifest_paths="",
spm_vocab_size=0,
spm_model_prefix="",
spm_character_coverage=0.9995):
manifest_paths = [manifest_paths] if isinstance(manifest_paths,
str) else manifest_paths
fout = open(vocab_path, 'w', encoding='utf-8')
fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC
fout.write(UNK + '\n') # <unk> must be 1

@ -58,6 +58,7 @@ def format_data(
unit_type="char",
vocab_path="examples/librispeech/data/vocab.txt",
spm_model_prefix=""):
manifest_paths = [manifest_paths] if isinstance(manifest_paths, str) else manifest_paths
fout = open(output_path, 'w', encoding='utf-8')

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alignment for U2 model."""
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments
@ -32,26 +32,10 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
maybe_dump_config(args.dump_config, config)
main(config, args)

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export for U2 model."""
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments
@ -32,22 +32,10 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save jit model to
parser.add_argument(
"--export_path", type=str, help="path of the jit model to save")
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
maybe_dump_config(args.dump_config, config)
main(config, args)

@ -15,14 +15,15 @@
import paddle
from kaldiio import ReadHelper
from paddleslim import PTQ
from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
@ -173,32 +174,7 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_scp", type=str, help="path of the input audio file")
parser.add_argument(
"--num_utts",
type=int,
default=200,
help="num utts for quant calibrition.")
parser.add_argument(
"--export_path",
type=str,
default='export.jit.quant',
help="path of the input audio file")
args = parser.parse_args()
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
main(config, args)

@ -14,10 +14,10 @@
"""Evaluation for U2 model."""
import cProfile
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments
@ -34,27 +34,12 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
maybe_dump_config(args.dump_config, config)
# Setting for profiling
pr = cProfile.Profile()

@ -16,15 +16,14 @@ import os
import sys
from pathlib import Path
import distutils
import numpy as np
import paddle
import soundfile
from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
@ -125,27 +124,7 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_file", type=str, help="path of the input audio file")
parser.add_argument(
"--debug",
type=distutils.util.strtobool,
default=False,
help="for debug.")
args = parser.parse_args()
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
main(config, args)

@ -15,14 +15,12 @@
import cProfile
import os
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments
# from paddlespeech.s2t.exps.u2.trainer import U2Trainer as Trainer
def main_sp(config, args):
exp = Trainer(config, args)
@ -39,17 +37,9 @@ if __name__ == "__main__":
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
maybe_dump_config(args.dump_path, config)
# Setting for profiling
pr = cProfile.Profile()

@ -13,6 +13,9 @@
# limitations under the License.
import argparse
import distutils
from yacs.config import CfgNode
class ExtendAction(argparse.Action):
"""
@ -68,7 +71,15 @@ def default_argument_parser(parser=None):
parser.register('action', 'extend', ExtendAction)
parser.add_argument(
'--conf', type=open, action=LoadFromFile, help="config file.")
parser.add_argument(
"--debug",
type=distutils.util.strtobool,
default=False,
help="logging with debug mode.")
parser.add_argument(
"--dump_path", type=str, default=None, help="path to dump config file.")
# train group
train_group = parser.add_argument_group(
title='Train Options', description=None)
train_group.add_argument(
@ -103,14 +114,35 @@ def default_argument_parser(parser=None):
train_group.add_argument(
"--dump-config", metavar="FILE", help="dump config to `this` file.")
# test group
test_group = parser.add_argument_group(
title='Test Options', description=None)
test_group.add_argument(
"--decode_cfg",
metavar="DECODE_CONFIG_FILE",
help="decode config file.")
test_group.add_argument(
"--result_file", type=str, help="path of save the asr result")
test_group.add_argument(
"--audio_file", type=str, help="path of the input audio file")
# quant & export
quant_group = parser.add_argument_group(
title='Quant Options', description=None)
quant_group.add_argument(
"--audio_scp", type=str, help="path of the input audio scp file")
quant_group.add_argument(
"--num_utts",
type=int,
default=200,
help="num utts for quant calibrition.")
quant_group.add_argument(
"--export_path",
type=str,
default='export.jit.quant',
help="path of the jit model to save")
# profile group
profile_group = parser.add_argument_group(
title='Benchmark Options', description=None)
profile_group.add_argument(
@ -131,3 +163,28 @@ def default_argument_parser(parser=None):
help='max iteration for benchmark.')
return parser
def config_from_args(args):
    """Assemble the experiment configuration from parsed CLI arguments.

    Args:
        args: Parsed argument namespace. The attributes ``config``,
            ``decode_cfg`` and ``opts`` are read; each one is skipped
            when it is falsy.

    Returns:
        A ``CfgNode`` on which ``freeze()`` has been called. The file named
        by ``args.config`` is merged first, the file named by
        ``args.decode_cfg`` is loaded into a separate node and attached as
        ``decode``, and ``args.opts`` is merged last.
    """
    # https://yaml.org/type/float.html
    cfg = CfgNode(new_allowed=True)

    # Main configuration file.
    main_file = args.config
    if main_file:
        cfg.merge_from_file(main_file)

    # Decoding options live in their own file and are attached as a sub-node.
    decode_file = args.decode_cfg
    if decode_file:
        decode_node = CfgNode(new_allowed=True)
        decode_node.merge_from_file(decode_file)
        cfg.decode = decode_node

    # Command-line overrides are merged after both files.
    overrides = args.opts
    if overrides:
        cfg.merge_from_list(overrides)

    cfg.freeze()
    return cfg
def maybe_dump_config(dump_path, config):
    """Write the printed form of ``config`` to ``dump_path`` if a path is given.

    Args:
        dump_path: Destination file path. When falsy (``None`` or an empty
            string) nothing is written and nothing is printed.
        config: Object to dump; it is written via ``print``, so its ``str()``
            form followed by a newline ends up in the file.

    Returns:
        None.
    """
    if not dump_path:
        return

    # Explicit UTF-8 keeps the dump independent of the platform's default
    # locale encoding, so non-ASCII config values cannot fail to encode.
    with open(dump_path, 'w', encoding='utf-8') as f:
        print(config, file=f)
    print(f"save config to {dump_path}")

Loading…
Cancel
Save