From fb853167d353f3b0e74e56dc1fbaa214fbbcb4fa Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 5 Nov 2021 10:35:37 +0000 Subject: [PATCH] format code --- paddlespeech/s2t/exps/u2/model.py | 8 ++--- paddlespeech/s2t/exps/u2_kaldi/model.py | 8 ++--- paddlespeech/s2t/exps/u2_st/model.py | 8 ++--- paddlespeech/s2t/transform/cmvn.py | 11 ++++-- paddlespeech/s2t/transform/perturb.py | 2 ++ paddlespeech/s2t/transform/transformation.py | 3 +- paddlespeech/s2t/utils/ctc_utils.py | 3 +- utils/remove_longshortdata.py | 38 +++++++++++--------- 8 files changed, 46 insertions(+), 35 deletions(-) diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 8dad5074..7eed9391 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -575,10 +575,10 @@ class U2Tester(U2Trainer): @paddle.no_grad() def align(self): - ctc_utils.ctc_align(self.config, - self.model, self.align_loader, self.config.decoding.batch_size, - self.config.collator.stride_ms, - self.vocab_list, self.args.result_file) + ctc_utils.ctc_align(self.config, self.model, self.align_loader, + self.config.decoding.batch_size, + self.config.collator.stride_ms, self.vocab_list, + self.args.result_file) def load_inferspec(self): """infer model and input spec. diff --git a/paddlespeech/s2t/exps/u2_kaldi/model.py b/paddlespeech/s2t/exps/u2_kaldi/model.py index 6c4365b8..d82034c8 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/model.py +++ b/paddlespeech/s2t/exps/u2_kaldi/model.py @@ -528,10 +528,10 @@ class U2Tester(U2Trainer): @paddle.no_grad() def align(self): - ctc_utils.ctc_align(self.config, - self.model, self.align_loader, self.config.decoding.batch_size, - self.config.collator.stride_ms, - self.vocab_list, self.args.result_file) + ctc_utils.ctc_align(self.config, self.model, self.align_loader, + self.config.decoding.batch_size, + self.config.collator.stride_ms, self.vocab_list, + self.args.result_file) def load_inferspec(self): """infer model and input spec. diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index 9141b361..91390afe 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -543,10 +543,10 @@ class U2STTester(U2STTrainer): @paddle.no_grad() def align(self): - ctc_utils.ctc_align(self.config, - self.model, self.align_loader, self.config.decoding.batch_size, - self.config.collator.stride_ms, - self.vocab_list, self.args.result_file) + ctc_utils.ctc_align(self.config, self.model, self.align_loader, + self.config.decoding.batch_size, + self.config.collator.stride_ms, self.vocab_list, + self.args.result_file) def load_inferspec(self): """infer model and input spec. diff --git a/paddlespeech/s2t/transform/cmvn.py b/paddlespeech/s2t/transform/cmvn.py index dc9ea87e..aa1e6b44 100644 --- a/paddlespeech/s2t/transform/cmvn.py +++ b/paddlespeech/s2t/transform/cmvn.py @@ -14,10 +14,12 @@ # Modified from espnet(https://github.com/espnet/espnet) import io import json + import h5py import kaldiio import numpy as np + class CMVN(): "Apply Global/Spk CMVN/iverserCMVN." @@ -158,11 +160,14 @@ class UtteranceCMVN(): return x - class GlobalCMVN(): "Apply Global CMVN" - def __init__(self, cmvn_path, norm_means=True, norm_vars=True, std_floor=1.0e-20): + def __init__(self, + cmvn_path, + norm_means=True, + norm_vars=True, + std_floor=1.0e-20): self.cmvn_path = cmvn_path self.norm_means = norm_means self.norm_vars = norm_vars @@ -189,4 +194,4 @@ class GlobalCMVN(): if self.norm_vars: x = np.divide(x, self.std) - return x \ No newline at end of file + return x diff --git a/paddlespeech/s2t/transform/perturb.py b/paddlespeech/s2t/transform/perturb.py index ee4c7ce0..873adb0b 100644 --- a/paddlespeech/s2t/transform/perturb.py +++ b/paddlespeech/s2t/transform/perturb.py @@ -17,6 +17,7 @@ import numpy import scipy import soundfile import soxbindings as sox + from paddlespeech.s2t.io.reader import SoundHDF5File @@ -171,6 +172,7 @@ class SpeedPerturbationSox(): upper={self.upper}, keep_length={self.keep_length}, sample_rate={self.sr})""" + else: return f"""{self.__class__.__name__}( utt2ratio={self.utt2ratio_file}, diff --git a/paddlespeech/s2t/transform/transformation.py b/paddlespeech/s2t/transform/transformation.py index bfe6c53d..381b0cdc 100644 --- a/paddlespeech/s2t/transform/transformation.py +++ b/paddlespeech/s2t/transform/transformation.py @@ -46,8 +46,7 @@ import_alias = dict( wpe="paddlespeech.s2t.transform.wpe:WPE", channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector", fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi", - cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN" -) + cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN") class Transformation(): diff --git a/paddlespeech/s2t/utils/ctc_utils.py b/paddlespeech/s2t/utils/ctc_utils.py index f5822e5d..886b7203 100644 --- a/paddlespeech/s2t/utils/ctc_utils.py +++ b/paddlespeech/s2t/utils/ctc_utils.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from wenet(https://github.com/wenet-e2e/wenet) -from typing import List from pathlib import Path +from typing import List + import numpy as np import paddle diff --git a/utils/remove_longshortdata.py b/utils/remove_longshortdata.py index dcc05b23..131b4a58 100755 --- a/utils/remove_longshortdata.py +++ b/utils/remove_longshortdata.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 """remove longshort data from manifest""" -import logging import argparse +import logging + import jsonlines from paddlespeech.s2t.utils.cli_utils import get_commandline_args @@ -23,17 +24,19 @@ def get_parser(): parser.add_argument( "--verbose", "-V", default=0, type=int, help="Verbose option") parser.add_argument( - "--iaxis", default=0, type=int, help="multi inputs index, 0 is the first") - parser.add_argument( - "--oaxis", default=0, type=int, help="multi outputs index, 0 is the first") - parser.add_argument( - "--maxframes", default=2000, type=int, help="maxframes") - parser.add_argument( - "--minframes", default=10, type=int, help="minframes") + "--iaxis", + default=0, + type=int, + help="multi inputs index, 0 is the first") parser.add_argument( - "--maxchars", default=200, type=int, help="max tokens") - parser.add_argument( - "--minchars", default=0, type=int, help="min tokens") + "--oaxis", + default=0, + type=int, + help="multi outputs index, 0 is the first") + parser.add_argument("--maxframes", default=2000, type=int, help="maxframes") + parser.add_argument("--minframes", default=10, type=int, help="minframes") + parser.add_argument("--maxchars", default=200, type=int, help="max tokens") + parser.add_argument("--minchars", default=0, type=int, help="min tokens") parser.add_argument( "--stride_ms", default=10, type=int, help="stride in ms unit.") parser.add_argument( @@ -54,7 +57,7 @@ def filter_input(args, line): nframe = tmp['shape'][0] * 1000 / args.stride_ms else: nframe = tmp['shape'][0] - + if nframe < args.minframes or nframe > args.maxframes: return True else: @@ -67,7 +70,7 @@ def filter_output(args, line): return True else: return False - + def main(): args = get_parser().parse_args() @@ -78,15 +81,15 @@ def main(): else: logging.basicConfig(level=logging.WARN, format=logfmt) logging.info(get_commandline_args()) - + with jsonlines.open(args.rspecifier, 'r') as reader: lines = list(reader) logging.info(f"Example: {len(lines)}") feat = lines[0]['input'][args.iaxis]['feat'] - args.soud = False + args.soud = False if feat.split('.')[-1] not in 'ark, scp': args.sound = True - + count = 0 filter = 0 with jsonlines.open(args.wspecifier_or_wxfilename, 'w') as writer: @@ -98,5 +101,6 @@ def main(): count += 1 logging.info(f"Example after filter: {count}\{filter}") + if __name__ == '__main__': - main() \ No newline at end of file + main()