format code

pull/1012/head
Hui Zhang 3 years ago
parent 7b3a901b08
commit fb853167d3

@ -575,10 +575,10 @@ class U2Tester(U2Trainer):
@paddle.no_grad()
def align(self):
ctc_utils.ctc_align(self.config,
self.model, self.align_loader, self.config.decoding.batch_size,
self.config.collator.stride_ms,
self.vocab_list, self.args.result_file)
ctc_utils.ctc_align(self.config, self.model, self.align_loader,
self.config.decoding.batch_size,
self.config.collator.stride_ms, self.vocab_list,
self.args.result_file)
def load_inferspec(self):
"""infer model and input spec.

@ -528,10 +528,10 @@ class U2Tester(U2Trainer):
@paddle.no_grad()
def align(self):
ctc_utils.ctc_align(self.config,
self.model, self.align_loader, self.config.decoding.batch_size,
self.config.collator.stride_ms,
self.vocab_list, self.args.result_file)
ctc_utils.ctc_align(self.config, self.model, self.align_loader,
self.config.decoding.batch_size,
self.config.collator.stride_ms, self.vocab_list,
self.args.result_file)
def load_inferspec(self):
"""infer model and input spec.

@ -543,10 +543,10 @@ class U2STTester(U2STTrainer):
@paddle.no_grad()
def align(self):
ctc_utils.ctc_align(self.config,
self.model, self.align_loader, self.config.decoding.batch_size,
self.config.collator.stride_ms,
self.vocab_list, self.args.result_file)
ctc_utils.ctc_align(self.config, self.model, self.align_loader,
self.config.decoding.batch_size,
self.config.collator.stride_ms, self.vocab_list,
self.args.result_file)
def load_inferspec(self):
"""infer model and input spec.

@ -14,10 +14,12 @@
# Modified from espnet(https://github.com/espnet/espnet)
import io
import json
import h5py
import kaldiio
import numpy as np
class CMVN():
"Apply Global/Spk CMVN/iverserCMVN."
@ -158,11 +160,14 @@ class UtteranceCMVN():
return x
class GlobalCMVN():
"Apply Global CMVN"
def __init__(self, cmvn_path, norm_means=True, norm_vars=True, std_floor=1.0e-20):
def __init__(self,
cmvn_path,
norm_means=True,
norm_vars=True,
std_floor=1.0e-20):
self.cmvn_path = cmvn_path
self.norm_means = norm_means
self.norm_vars = norm_vars
@ -189,4 +194,4 @@ class GlobalCMVN():
if self.norm_vars:
x = np.divide(x, self.std)
return x
return x

@ -17,6 +17,7 @@ import numpy
import scipy
import soundfile
import soxbindings as sox
from paddlespeech.s2t.io.reader import SoundHDF5File
@ -171,6 +172,7 @@ class SpeedPerturbationSox():
upper={self.upper},
keep_length={self.keep_length},
sample_rate={self.sr})"""
else:
return f"""{self.__class__.__name__}(
utt2ratio={self.utt2ratio_file},

@ -46,8 +46,7 @@ import_alias = dict(
wpe="paddlespeech.s2t.transform.wpe:WPE",
channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector",
fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi",
cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN"
)
cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN")
class Transformation():

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
from typing import List
from pathlib import Path
from typing import List
import numpy as np
import paddle

@ -1,7 +1,8 @@
#!/usr/bin/env python3
"""remove longshort data from manifest"""
import logging
import argparse
import logging
import jsonlines
from paddlespeech.s2t.utils.cli_utils import get_commandline_args
@ -23,17 +24,19 @@ def get_parser():
parser.add_argument(
"--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument(
"--iaxis", default=0, type=int, help="multi inputs index, 0 is the first")
parser.add_argument(
"--oaxis", default=0, type=int, help="multi outputs index, 0 is the first")
parser.add_argument(
"--maxframes", default=2000, type=int, help="maxframes")
parser.add_argument(
"--minframes", default=10, type=int, help="minframes")
"--iaxis",
default=0,
type=int,
help="multi inputs index, 0 is the first")
parser.add_argument(
"--maxchars", default=200, type=int, help="max tokens")
parser.add_argument(
"--minchars", default=0, type=int, help="min tokens")
"--oaxis",
default=0,
type=int,
help="multi outputs index, 0 is the first")
parser.add_argument("--maxframes", default=2000, type=int, help="maxframes")
parser.add_argument("--minframes", default=10, type=int, help="minframes")
parser.add_argument("--maxchars", default=200, type=int, help="max tokens")
parser.add_argument("--minchars", default=0, type=int, help="min tokens")
parser.add_argument(
"--stride_ms", default=10, type=int, help="stride in ms unit.")
parser.add_argument(
@ -54,7 +57,7 @@ def filter_input(args, line):
nframe = tmp['shape'][0] * 1000 / args.stride_ms
else:
nframe = tmp['shape'][0]
if nframe < args.minframes or nframe > args.maxframes:
return True
else:
@ -67,7 +70,7 @@ def filter_output(args, line):
return True
else:
return False
def main():
args = get_parser().parse_args()
@ -78,15 +81,15 @@ def main():
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
with jsonlines.open(args.rspecifier, 'r') as reader:
lines = list(reader)
logging.info(f"Example: {len(lines)}")
feat = lines[0]['input'][args.iaxis]['feat']
args.soud = False
args.soud = False
if feat.split('.')[-1] not in 'ark, scp':
args.sound = True
count = 0
filter = 0
with jsonlines.open(args.wspecifier_or_wxfilename, 'w') as writer:
@ -98,5 +101,6 @@ def main():
count += 1
logging.info(f"Example after filter: {count}\{filter}")
if __name__ == '__main__':
main()
main()

Loading…
Cancel
Save