format code

pull/1012/head
Hui Zhang 4 years ago
parent 7b3a901b08
commit fb853167d3

@ -575,10 +575,10 @@ class U2Tester(U2Trainer):
@paddle.no_grad() @paddle.no_grad()
def align(self): def align(self):
ctc_utils.ctc_align(self.config, ctc_utils.ctc_align(self.config, self.model, self.align_loader,
self.model, self.align_loader, self.config.decoding.batch_size, self.config.decoding.batch_size,
self.config.collator.stride_ms, self.config.collator.stride_ms, self.vocab_list,
self.vocab_list, self.args.result_file) self.args.result_file)
def load_inferspec(self): def load_inferspec(self):
"""infer model and input spec. """infer model and input spec.

@ -528,10 +528,10 @@ class U2Tester(U2Trainer):
@paddle.no_grad() @paddle.no_grad()
def align(self): def align(self):
ctc_utils.ctc_align(self.config, ctc_utils.ctc_align(self.config, self.model, self.align_loader,
self.model, self.align_loader, self.config.decoding.batch_size, self.config.decoding.batch_size,
self.config.collator.stride_ms, self.config.collator.stride_ms, self.vocab_list,
self.vocab_list, self.args.result_file) self.args.result_file)
def load_inferspec(self): def load_inferspec(self):
"""infer model and input spec. """infer model and input spec.

@ -543,10 +543,10 @@ class U2STTester(U2STTrainer):
@paddle.no_grad() @paddle.no_grad()
def align(self): def align(self):
ctc_utils.ctc_align(self.config, ctc_utils.ctc_align(self.config, self.model, self.align_loader,
self.model, self.align_loader, self.config.decoding.batch_size, self.config.decoding.batch_size,
self.config.collator.stride_ms, self.config.collator.stride_ms, self.vocab_list,
self.vocab_list, self.args.result_file) self.args.result_file)
def load_inferspec(self): def load_inferspec(self):
"""infer model and input spec. """infer model and input spec.

@ -14,10 +14,12 @@
# Modified from espnet(https://github.com/espnet/espnet) # Modified from espnet(https://github.com/espnet/espnet)
import io import io
import json import json
import h5py import h5py
import kaldiio import kaldiio
import numpy as np import numpy as np
class CMVN(): class CMVN():
"Apply Global/Spk CMVN/iverserCMVN." "Apply Global/Spk CMVN/iverserCMVN."
@ -158,11 +160,14 @@ class UtteranceCMVN():
return x return x
class GlobalCMVN(): class GlobalCMVN():
"Apply Global CMVN" "Apply Global CMVN"
def __init__(self, cmvn_path, norm_means=True, norm_vars=True, std_floor=1.0e-20): def __init__(self,
cmvn_path,
norm_means=True,
norm_vars=True,
std_floor=1.0e-20):
self.cmvn_path = cmvn_path self.cmvn_path = cmvn_path
self.norm_means = norm_means self.norm_means = norm_means
self.norm_vars = norm_vars self.norm_vars = norm_vars
@ -189,4 +194,4 @@ class GlobalCMVN():
if self.norm_vars: if self.norm_vars:
x = np.divide(x, self.std) x = np.divide(x, self.std)
return x return x

@ -17,6 +17,7 @@ import numpy
import scipy import scipy
import soundfile import soundfile
import soxbindings as sox import soxbindings as sox
from paddlespeech.s2t.io.reader import SoundHDF5File from paddlespeech.s2t.io.reader import SoundHDF5File
@ -171,6 +172,7 @@ class SpeedPerturbationSox():
upper={self.upper}, upper={self.upper},
keep_length={self.keep_length}, keep_length={self.keep_length},
sample_rate={self.sr})""" sample_rate={self.sr})"""
else: else:
return f"""{self.__class__.__name__}( return f"""{self.__class__.__name__}(
utt2ratio={self.utt2ratio_file}, utt2ratio={self.utt2ratio_file},

@ -46,8 +46,7 @@ import_alias = dict(
wpe="paddlespeech.s2t.transform.wpe:WPE", wpe="paddlespeech.s2t.transform.wpe:WPE",
channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector", channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector",
fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi", fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi",
cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN" cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN")
)
class Transformation(): class Transformation():

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet) # Modified from wenet(https://github.com/wenet-e2e/wenet)
from typing import List
from pathlib import Path from pathlib import Path
from typing import List
import numpy as np import numpy as np
import paddle import paddle

@ -1,7 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""remove longshort data from manifest""" """remove longshort data from manifest"""
import logging
import argparse import argparse
import logging
import jsonlines import jsonlines
from paddlespeech.s2t.utils.cli_utils import get_commandline_args from paddlespeech.s2t.utils.cli_utils import get_commandline_args
@ -23,17 +24,19 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--verbose", "-V", default=0, type=int, help="Verbose option") "--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument( parser.add_argument(
"--iaxis", default=0, type=int, help="multi inputs index, 0 is the first") "--iaxis",
parser.add_argument( default=0,
"--oaxis", default=0, type=int, help="multi outputs index, 0 is the first") type=int,
parser.add_argument( help="multi inputs index, 0 is the first")
"--maxframes", default=2000, type=int, help="maxframes")
parser.add_argument(
"--minframes", default=10, type=int, help="minframes")
parser.add_argument( parser.add_argument(
"--maxchars", default=200, type=int, help="max tokens") "--oaxis",
parser.add_argument( default=0,
"--minchars", default=0, type=int, help="min tokens") type=int,
help="multi outputs index, 0 is the first")
parser.add_argument("--maxframes", default=2000, type=int, help="maxframes")
parser.add_argument("--minframes", default=10, type=int, help="minframes")
parser.add_argument("--maxchars", default=200, type=int, help="max tokens")
parser.add_argument("--minchars", default=0, type=int, help="min tokens")
parser.add_argument( parser.add_argument(
"--stride_ms", default=10, type=int, help="stride in ms unit.") "--stride_ms", default=10, type=int, help="stride in ms unit.")
parser.add_argument( parser.add_argument(
@ -54,7 +57,7 @@ def filter_input(args, line):
nframe = tmp['shape'][0] * 1000 / args.stride_ms nframe = tmp['shape'][0] * 1000 / args.stride_ms
else: else:
nframe = tmp['shape'][0] nframe = tmp['shape'][0]
if nframe < args.minframes or nframe > args.maxframes: if nframe < args.minframes or nframe > args.maxframes:
return True return True
else: else:
@ -67,7 +70,7 @@ def filter_output(args, line):
return True return True
else: else:
return False return False
def main(): def main():
args = get_parser().parse_args() args = get_parser().parse_args()
@ -78,15 +81,15 @@ def main():
else: else:
logging.basicConfig(level=logging.WARN, format=logfmt) logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args()) logging.info(get_commandline_args())
with jsonlines.open(args.rspecifier, 'r') as reader: with jsonlines.open(args.rspecifier, 'r') as reader:
lines = list(reader) lines = list(reader)
logging.info(f"Example: {len(lines)}") logging.info(f"Example: {len(lines)}")
feat = lines[0]['input'][args.iaxis]['feat'] feat = lines[0]['input'][args.iaxis]['feat']
args.soud = False args.soud = False
if feat.split('.')[-1] not in 'ark, scp': if feat.split('.')[-1] not in 'ark, scp':
args.sound = True args.sound = True
count = 0 count = 0
filter = 0 filter = 0
with jsonlines.open(args.wspecifier_or_wxfilename, 'w') as writer: with jsonlines.open(args.wspecifier_or_wxfilename, 'w') as writer:
@ -98,5 +101,6 @@ def main():
count += 1 count += 1
logging.info(f"Example after filter: {count}\{filter}") logging.info(f"Example after filter: {count}\{filter}")
if __name__ == '__main__': if __name__ == '__main__':
main() main()

Loading…
Cancel
Save