support multi-gpu training with webdataset

pull/2062/head
huangyuxin 3 years ago
parent 8f5e61090b
commit c7a7b113c8

@ -50,26 +50,41 @@ test_manifest: data/manifest.test
########################################### ###########################################
# Dataloader # # Dataloader #
########################################### ###########################################
vocab_filepath: data/lang_char/vocab.txt use_stream_data: True
unit_type: 'char' unit_type: 'char'
vocab_filepath: data/lang_char/vocab.txt
cmvn_file: data/mean_std.json
preprocess_config: conf/preprocess.yaml preprocess_config: conf/preprocess.yaml
spm_model_prefix: '' spm_model_prefix: ''
feat_dim: 80 feat_dim: 80
stride_ms: 10.0 stride_ms: 10.0
window_ms: 25.0 window_ms: 25.0
dither: 0.1
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 64 batch_size: 64
minlen_in: 10
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
minlen_out: 0
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug resample_rate: 16000
batch_count: auto shuffle_size: 10000
batch_bins: 0 sort_size: 500
batch_frames_in: 0 num_workers: 4
batch_frames_out: 0 prefetch_factor: 100
batch_frames_inout: 0 dist_sampler: True
num_workers: 0
subsampling_factor: 1
num_encs: 1 num_encs: 1
augment_conf:
max_w: 80
w_inplace: True
w_mode: "PIL"
max_f: 30
num_f_mask: 2
f_inplace: True
f_replace_with_zero: False
max_t: 40
num_t_mask: 2
t_inplace: True
t_replace_with_zero: False
########################################### ###########################################
@ -78,7 +93,7 @@ num_encs: 1
n_epoch: 240 n_epoch: 240
accum_grad: 16 accum_grad: 16
global_grad_clip: 5.0 global_grad_clip: 5.0
log_interval: 100 log_interval: 1
checkpoint: checkpoint:
kbest_n: 50 kbest_n: 50
latest_n: 5 latest_n: 5

@ -41,7 +41,8 @@ from .filters import (
spec_aug, spec_aug,
sort, sort,
padding, padding,
cmvn cmvn,
placeholder,
) )
from webdataset.handlers import ( from webdataset.handlers import (
ignore_and_continue, ignore_and_continue,

@ -758,17 +758,34 @@ def _compute_fbank(source,
compute_fbank = pipelinefilter(_compute_fbank) compute_fbank = pipelinefilter(_compute_fbank)
def _spec_aug(source, num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80): def _spec_aug(source,
max_w=5,
w_inplace=True,
w_mode="PIL",
max_f=30,
num_f_mask=2,
f_inplace=True,
f_replace_with_zero=False,
max_t=40,
num_t_mask=2,
t_inplace=True,
t_replace_with_zero=False,):
""" Do spec augmentation """ Do spec augmentation
Inplace operation Inplace operation
Args: Args:
source: Iterable[{fname, feat, label}] source: Iterable[{fname, feat, label}]
num_t_mask: number of time mask to apply max_w: max width of time warp
w_inplace: whether to inplace the original data while time warping
w_mode: time warp mode
max_f: max width of freq mask
num_f_mask: number of freq mask to apply num_f_mask: number of freq mask to apply
f_inplace: whether to inplace the original data while frequency masking
f_replace_with_zero: use zero to mask
max_t: max width of time mask max_t: max width of time mask
max_f: max width of freq mask num_t_mask: number of time mask to apply
max_w: max width of time warp t_inplace: whether to inplace the original data while time masking
t_replace_with_zero: use zero to mask
Returns Returns
Iterable[{fname, feat, label}] Iterable[{fname, feat, label}]
@ -776,9 +793,9 @@ def _spec_aug(source, num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80):
for sample in source: for sample in source:
x = sample['feat'] x = sample['feat']
x = x.numpy() x = x.numpy()
x = time_warp(x, max_time_warp=max_w, inplace = True, mode= "PIL") x = time_warp(x, max_time_warp=max_w, inplace = w_inplace, mode= w_mode)
x = freq_mask(x, F = max_f, n_mask = num_f_mask, inplace = True, replace_with_zero = False) x = freq_mask(x, F = max_f, n_mask = num_f_mask, inplace = f_inplace, replace_with_zero = f_replace_with_zero)
x = time_mask(x, T = max_t, n_mask = num_t_mask, inplace = True, replace_with_zero = False) x = time_mask(x, T = max_t, n_mask = num_t_mask, inplace = t_inplace, replace_with_zero = t_replace_with_zero)
sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32) sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32)
yield sample yield sample
@ -910,3 +927,9 @@ def _cmvn(source, cmvn_file):
label_lengths) label_lengths)
cmvn = pipelinefilter(_cmvn) cmvn = pipelinefilter(_cmvn)
def _placeholder(source):
for data in source:
yield data
placeholder = pipelinefilter(_placeholder)

@ -89,6 +89,12 @@ class DataPipeline(IterableDataset, PipelineStage):
def append(self, f): def append(self, f):
"""Append a pipeline stage (modifies the object).""" """Append a pipeline stage (modifies the object)."""
self.pipeline.append(f) self.pipeline.append(f)
return self
def append_list(self, *args):
for arg in args:
self.pipeline.append(arg)
return self
def compose(self, *args): def compose(self, *args):
"""Append a pipeline stage to a copy of the pipeline and returns the copy.""" """Append a pipeline stage to a copy of the pipeline and returns the copy."""

@ -24,6 +24,8 @@ from .filters import pipelinefilter
from .paddle_utils import IterableDataset from .paddle_utils import IterableDataset
from ..utils.log import Logger
logger = Logger(__name__)
def expand_urls(urls): def expand_urls(urls):
if isinstance(urls, str): if isinstance(urls, str):
urllist = urls.split("::") urllist = urls.split("::")

@ -65,6 +65,7 @@ class Logger(object):
def __init__(self, name: str=None): def __init__(self, name: str=None):
name = 'PaddleAudio' if not name else name name = 'PaddleAudio' if not name else name
self.name = name
self.logger = logging.getLogger(name) self.logger = logging.getLogger(name)
for key, conf in log_config.items(): for key, conf in log_config.items():
@ -101,7 +102,7 @@ class Logger(object):
if not self.is_enable: if not self.is_enable:
return return
self.logger.log(log_level, msg) self.logger.log(log_level, self.name + " | " + msg)
@contextlib.contextmanager @contextlib.contextmanager
def use_terminator(self, terminator: str): def use_terminator(self, terminator: str):

@ -93,9 +93,6 @@ def pad_sequence(sequences: List[paddle.Tensor],
for i, tensor in enumerate(sequences): for i, tensor in enumerate(sequences):
length = tensor.shape[0] length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor # use index notation to prevent duplicate references to the tensor
logger.info(
f"length {length}, out_tensor {out_tensor.shape}, tensor {tensor.shape}"
)
if batch_first: if batch_first:
# TODO (Hui Zhang): set_value op not supprot `end==start` # TODO (Hui Zhang): set_value op not supprot `end==start`
# TODO (Hui Zhang): set_value op not support int16 # TODO (Hui Zhang): set_value op not support int16

@ -26,6 +26,7 @@ from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope from paddlespeech.s2t.training.reporter import ObsScope
@ -106,7 +107,8 @@ class U2Trainer(Trainer):
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
self.model.eval() self.model.eval()
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") if not self.use_streamdata:
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list) valid_losses = defaultdict(list)
num_seen_utts = 1 num_seen_utts = 1
total_loss = 0.0 total_loss = 0.0
@ -132,7 +134,7 @@ class U2Trainer(Trainer):
msg = f"Valid: Rank: {dist.get_rank()}, " msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration) msg += "step: {}, ".format(self.iteration)
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) #msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v) msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items()) for k, v in valid_dump.items())
logger.info(msg) logger.info(msg)
@ -152,7 +154,8 @@ class U2Trainer(Trainer):
self.before_train() self.before_train()
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") if not self.use_streamdata:
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch: while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"): with Timer("Epoch-Train Time Cost: {}"):
self.model.train() self.model.train()
@ -170,7 +173,8 @@ class U2Trainer(Trainer):
self.train_batch(batch_index, batch, msg) self.train_batch(batch_index, batch, msg)
self.after_train_batch() self.after_train_batch()
report('iter', batch_index + 1) report('iter', batch_index + 1)
report('total', len(self.train_loader)) if not self.use_streamdata:
report('total', len(self.train_loader))
report('reader_cost', dataload_time) report('reader_cost', dataload_time)
observation['batch_cost'] = observation[ observation['batch_cost'] = observation[
'reader_cost'] + observation['step_cost'] 'reader_cost'] + observation['step_cost']
@ -218,92 +222,188 @@ class U2Trainer(Trainer):
def setup_dataloader(self): def setup_dataloader(self):
config = self.config.clone() config = self.config.clone()
self.use_streamdata = config.get("use_stream_data", False)
if self.train: if self.train:
# train/valid dataset, return token ids # train/valid dataset, return token ids
self.train_loader = BatchDataLoader( if self.use_streamdata:
json_file=config.train_manifest, self.train_loader = StreamDataLoader(
train_mode=True, manifest_file=config.train_manifest,
sortagrad=config.sortagrad, train_mode=True,
batch_size=config.batch_size, unit_type=config.unit_type,
maxlen_in=config.maxlen_in, batch_size=config.batch_size,
maxlen_out=config.maxlen_out, num_mel_bins=config.feat_dim,
minibatches=config.minibatches, frame_length=config.window_ms,
mini_batch_size=self.args.ngpu, frame_shift=config.stride_ms,
batch_count=config.batch_count, dither=config.dither,
batch_bins=config.batch_bins, minlen_in=config.minlen_in,
batch_frames_in=config.batch_frames_in, maxlen_in=config.maxlen_in,
batch_frames_out=config.batch_frames_out, minlen_out=config.minlen_out,
batch_frames_inout=config.batch_frames_inout, maxlen_out=config.maxlen_out,
preprocess_conf=config.preprocess_config, resample_rate=config.resample_rate,
n_iter_processes=config.num_workers, augment_conf=config.augment_conf, # dict
subsampling_factor=1, shuffle_size=config.shuffle_size,
num_encs=1, sort_size=config.sort_size,
dist_sampler=config.get('dist_sampler', False), n_iter_processes=config.num_workers,
shortest_first=False) prefetch_factor=config.prefetch_factor,
dist_sampler=config.get('dist_sampler', False),
self.valid_loader = BatchDataLoader( cmvn_file=config.cmvn_file,
json_file=config.dev_manifest, vocab_filepath=config.vocab_filepath,
train_mode=False, )
sortagrad=False, self.valid_loader = StreamDataLoader(
batch_size=config.batch_size, manifest_file=config.dev_manifest,
maxlen_in=float('inf'), train_mode=False,
maxlen_out=float('inf'), unit_type=config.unit_type,
minibatches=0, batch_size=config.batch_size,
mini_batch_size=self.args.ngpu, num_mel_bins=config.feat_dim,
batch_count='auto', frame_length=config.window_ms,
batch_bins=0, frame_shift=config.stride_ms,
batch_frames_in=0, dither=config.dither,
batch_frames_out=0, minlen_in=config.minlen_in,
batch_frames_inout=0, maxlen_in=config.maxlen_in,
preprocess_conf=config.preprocess_config, minlen_out=config.minlen_out,
n_iter_processes=config.num_workers, maxlen_out=config.maxlen_out,
subsampling_factor=1, resample_rate=config.resample_rate,
num_encs=1, augment_conf=config.augment_conf, # dict
dist_sampler=config.get('dist_sampler', False), shuffle_size=config.shuffle_size,
shortest_first=False) sort_size=config.sort_size,
n_iter_processes=config.num_workers,
prefetch_factor=config.prefetch_factor,
dist_sampler=config.get('dist_sampler', False),
cmvn_file=config.cmvn_file,
vocab_filepath=config.vocab_filepath,
)
else:
self.train_loader = BatchDataLoader(
json_file=config.train_manifest,
train_mode=True,
sortagrad=config.sortagrad,
batch_size=config.batch_size,
maxlen_in=config.maxlen_in,
maxlen_out=config.maxlen_out,
minibatches=config.minibatches,
mini_batch_size=self.args.ngpu,
batch_count=config.batch_count,
batch_bins=config.batch_bins,
batch_frames_in=config.batch_frames_in,
batch_frames_out=config.batch_frames_out,
batch_frames_inout=config.batch_frames_inout,
preprocess_conf=config.preprocess_config,
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1,
dist_sampler=config.get('dist_sampler', False),
shortest_first=False)
self.valid_loader = BatchDataLoader(
json_file=config.dev_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=self.args.ngpu,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config,
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1,
dist_sampler=config.get('dist_sampler', False),
shortest_first=False)
logger.info("Setup train/valid Dataloader!") logger.info("Setup train/valid Dataloader!")
else: else:
decode_batch_size = config.get('decode', dict()).get( decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1) 'decode_batch_size', 1)
# test dataset, return raw text # test dataset, return raw text
self.test_loader = BatchDataLoader( if self.use_streamdata:
json_file=config.test_manifest, self.test_loader = StreamDataLoader(
train_mode=False, manifest_file=config.test_manifest,
sortagrad=False, train_mode=False,
batch_size=decode_batch_size, unit_type=config.unit_type,
maxlen_in=float('inf'), batch_size=config.batch_size,
maxlen_out=float('inf'), num_mel_bins=config.feat_dim,
minibatches=0, frame_length=config.window_ms,
mini_batch_size=1, frame_shift=config.stride_ms,
batch_count='auto', dither=0.0,
batch_bins=0, minlen_in=0.0,
batch_frames_in=0, maxlen_in=float('inf'),
batch_frames_out=0, minlen_out=0,
batch_frames_inout=0, maxlen_out=float('inf'),
preprocess_conf=config.preprocess_config, resample_rate=config.resample_rate,
n_iter_processes=1, augment_conf=config.augment_conf, # dict
subsampling_factor=1, shuffle_size=config.shuffle_size,
num_encs=1) sort_size=config.sort_size,
n_iter_processes=config.num_workers,
self.align_loader = BatchDataLoader( prefetch_factor=config.prefetch_factor,
json_file=config.test_manifest, dist_sampler=config.get('dist_sampler', False),
train_mode=False, cmvn_file=config.cmvn_file,
sortagrad=False, vocab_filepath=config.vocab_filepath,
batch_size=decode_batch_size, )
maxlen_in=float('inf'), self.align_loader = StreamDataLoader(
maxlen_out=float('inf'), manifest_file=config.test_manifest,
minibatches=0, train_mode=False,
mini_batch_size=1, unit_type=config.unit_type,
batch_count='auto', batch_size=config.batch_size,
batch_bins=0, num_mel_bins=config.feat_dim,
batch_frames_in=0, frame_length=config.window_ms,
batch_frames_out=0, frame_shift=config.stride_ms,
batch_frames_inout=0, dither=0.0,
preprocess_conf=config.preprocess_config, minlen_in=0.0,
n_iter_processes=1, maxlen_in=float('inf'),
subsampling_factor=1, minlen_out=0,
num_encs=1) maxlen_out=float('inf'),
resample_rate=config.resample_rate,
augment_conf=config.augment_conf, # dict
shuffle_size=config.shuffle_size,
sort_size=config.sort_size,
n_iter_processes=config.num_workers,
prefetch_factor=config.prefetch_factor,
dist_sampler=config.get('dist_sampler', False),
cmvn_file=config.cmvn_file,
vocab_filepath=config.vocab_filepath,
)
else:
self.test_loader = BatchDataLoader(
json_file=config.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config,
n_iter_processes=1,
subsampling_factor=1,
num_encs=1)
self.align_loader = BatchDataLoader(
json_file=config.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config,
n_iter_processes=1,
subsampling_factor=1,
num_encs=1)
logger.info("Setup test/align Dataloader!") logger.info("Setup test/align Dataloader!")
def setup_model(self): def setup_model(self):
@ -452,7 +552,8 @@ class U2Tester(U2Trainer):
def test(self): def test(self):
assert self.args.result_file assert self.args.result_file
self.model.eval() self.model.eval()
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") if not self.use_streamdata:
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
stride_ms = self.config.stride_ms stride_ms = self.config.stride_ms
error_rate_type = None error_rate_type = None

@ -28,6 +28,9 @@ from paddlespeech.s2t.io.dataset import TransformDataset
from paddlespeech.s2t.io.reader import LoadInputsAndTargets from paddlespeech.s2t.io.reader import LoadInputsAndTargets
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
import paddlespeech.audio.stream_data as stream_data
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
__all__ = ["BatchDataLoader"] __all__ = ["BatchDataLoader"]
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -56,6 +59,90 @@ def batch_collate(x):
""" """
return x[0] return x[0]
class StreamDataLoader():
def __init__(self,
manifest_file: str,
train_mode: bool,
unit_type: str='char',
batch_size: int=0,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0,
minlen_in: float=0.0,
maxlen_in: float=float('inf'),
minlen_out: float=0.0,
maxlen_out: float=float('inf'),
resample_rate: int=16000,
augment_conf: dict=None,
shuffle_size: int=10000,
sort_size: int=1000,
n_iter_processes: int=1,
prefetch_factor: int=2,
dist_sampler: bool=False,
cmvn_file="data/mean_std.json",
vocab_filepath='data/lang_char/vocab.txt'):
self.manifest_file = manifest_file
self.train_model = train_mode
self.batch_size = batch_size
self.prefetch_factor = prefetch_factor
self.dist_sampler = dist_sampler
self.n_iter_processes = n_iter_processes
text_featurizer = TextFeaturizer(unit_type, vocab_filepath)
symbol_table = text_featurizer.vocab_dict
self.feat_dim = num_mel_bins
self.vocab_size = text_featurizer.vocab_size
# The list of shard
shardlist = []
with open(manifest_file, "r") as f:
for line in f.readlines():
shardlist.append(line.strip())
if self.dist_sampler:
base_dataset = stream_data.DataPipeline(
stream_data.SimpleShardList(shardlist),
stream_data.split_by_node,
stream_data.split_by_worker,
stream_data.tarfile_to_samples(stream_data.reraise_exception)
)
else:
base_dataset = stream_data.DataPipeline(
stream_data.SimpleShardList(shardlist),
stream_data.split_by_worker,
stream_data.tarfile_to_samples(stream_data.reraise_exception)
)
self.dataset = base_dataset.append_list(
stream_data.tokenize(symbol_table),
stream_data.data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_in),
stream_data.resample(resample_rate=resample_rate),
stream_data.compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
stream_data.spec_aug(**augment_conf) if train_mode else stream_data.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
stream_data.shuffle(shuffle_size),
stream_data.sort(sort_size=sort_size),
stream_data.batched(batch_size),
stream_data.padding(),
stream_data.cmvn(cmvn_file)
)
self.loader = stream_data.WebLoader(
self.dataset,
num_workers=self.n_iter_processes,
prefetch_factor = self.prefetch_factor,
batch_size=None
)
def __iter__(self):
return self.loader.__iter__()
def __call__(self):
return self.__iter__()
def __len__(self):
logger.info("Stream dataloader does not support calculate the length of the dataset")
return -1
class BatchDataLoader(): class BatchDataLoader():
def __init__(self, def __init__(self,

@ -19,7 +19,7 @@ import numpy as np
import soundfile import soundfile
from .utility import feat_type from .utility import feat_type
from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
# from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline as Transformation # from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline as Transformation

@ -1,13 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -1,54 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import numpy as np
def delta(feat, window):
assert window > 0
delta_feat = np.zeros_like(feat)
for i in range(1, window + 1):
delta_feat[:-i] += i * feat[i:]
delta_feat[i:] += -i * feat[:-i]
delta_feat[-i:] += i * feat[-1]
delta_feat[:i] += -i * feat[0]
delta_feat /= 2 * sum(i**2 for i in range(1, window + 1))
return delta_feat
def add_deltas(x, window=2, order=2):
"""
Args:
x (np.ndarray): speech feat, (T, D).
Return:
np.ndarray: (T, (1+order)*D)
"""
feats = [x]
for _ in range(order):
feats.append(delta(feats[-1], window))
return np.concatenate(feats, axis=1)
class AddDeltas():
def __init__(self, window=2, order=2):
self.window = window
self.order = order
def __repr__(self):
return "{name}(window={window}, order={order}".format(
name=self.__class__.__name__, window=self.window, order=self.order)
def __call__(self, x):
return add_deltas(x, window=self.window, order=self.order)

@ -1,57 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import numpy
class ChannelSelector():
"""Select 1ch from multi-channel signal"""
def __init__(self, train_channel="random", eval_channel=0, axis=1):
self.train_channel = train_channel
self.eval_channel = eval_channel
self.axis = axis
def __repr__(self):
return ("{name}(train_channel={train_channel}, "
"eval_channel={eval_channel}, axis={axis})".format(
name=self.__class__.__name__,
train_channel=self.train_channel,
eval_channel=self.eval_channel,
axis=self.axis, ))
def __call__(self, x, train=True):
# Assuming x: [Time, Channel] by default
if x.ndim <= self.axis:
# If the dimension is insufficient, then unsqueeze
# (e.g [Time] -> [Time, 1])
ind = tuple(
slice(None) if i < x.ndim else None
for i in range(self.axis + 1))
x = x[ind]
if train:
channel = self.train_channel
else:
channel = self.eval_channel
if channel == "random":
ch = numpy.random.randint(0, x.shape[self.axis])
else:
ch = channel
ind = tuple(
slice(None) if i != self.axis else ch for i in range(x.ndim))
return x[ind]

@ -1,201 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import io
import json
import h5py
import kaldiio
import numpy as np
class CMVN():
"Apply Global/Spk CMVN/iverserCMVN."
def __init__(
self,
stats,
norm_means=True,
norm_vars=False,
filetype="mat",
utt2spk=None,
spk2utt=None,
reverse=False,
std_floor=1.0e-20, ):
self.stats_file = stats
self.norm_means = norm_means
self.norm_vars = norm_vars
self.reverse = reverse
if isinstance(stats, dict):
stats_dict = dict(stats)
else:
# Use for global CMVN
if filetype == "mat":
stats_dict = {None: kaldiio.load_mat(stats)}
# Use for global CMVN
elif filetype == "npy":
stats_dict = {None: np.load(stats)}
# Use for speaker CMVN
elif filetype == "ark":
self.accept_uttid = True
stats_dict = dict(kaldiio.load_ark(stats))
# Use for speaker CMVN
elif filetype == "hdf5":
self.accept_uttid = True
stats_dict = h5py.File(stats)
else:
raise ValueError("Not supporting filetype={}".format(filetype))
if utt2spk is not None:
self.utt2spk = {}
with io.open(utt2spk, "r", encoding="utf-8") as f:
for line in f:
utt, spk = line.rstrip().split(None, 1)
self.utt2spk[utt] = spk
elif spk2utt is not None:
self.utt2spk = {}
with io.open(spk2utt, "r", encoding="utf-8") as f:
for line in f:
spk, utts = line.rstrip().split(None, 1)
for utt in utts.split():
self.utt2spk[utt] = spk
else:
self.utt2spk = None
# Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1),
# and the first vector contains the sum of feats and the second is
# the sum of squares. The last value of the first, i.e. stats[0,-1],
# is the number of samples for this statistics.
self.bias = {}
self.scale = {}
for spk, stats in stats_dict.items():
assert len(stats) == 2, stats.shape
count = stats[0, -1]
# If the feature has two or more dimensions
if not (np.isscalar(count) or isinstance(count, (int, float))):
# The first is only used
count = count.flatten()[0]
mean = stats[0, :-1] / count
# V(x) = E(x^2) - (E(x))^2
var = stats[1, :-1] / count - mean * mean
std = np.maximum(np.sqrt(var), std_floor)
self.bias[spk] = -mean
self.scale[spk] = 1 / std
def __repr__(self):
return ("{name}(stats_file={stats_file}, "
"norm_means={norm_means}, norm_vars={norm_vars}, "
"reverse={reverse})".format(
name=self.__class__.__name__,
stats_file=self.stats_file,
norm_means=self.norm_means,
norm_vars=self.norm_vars,
reverse=self.reverse, ))
def __call__(self, x, uttid=None):
if self.utt2spk is not None:
spk = self.utt2spk[uttid]
else:
spk = uttid
if not self.reverse:
# apply cmvn
if self.norm_means:
x = np.add(x, self.bias[spk])
if self.norm_vars:
x = np.multiply(x, self.scale[spk])
else:
# apply reverse cmvn
if self.norm_vars:
x = np.divide(x, self.scale[spk])
if self.norm_means:
x = np.subtract(x, self.bias[spk])
return x
class UtteranceCMVN():
"Apply Utterance CMVN"
def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20):
self.norm_means = norm_means
self.norm_vars = norm_vars
self.std_floor = std_floor
def __repr__(self):
return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format(
name=self.__class__.__name__,
norm_means=self.norm_means,
norm_vars=self.norm_vars, )
def __call__(self, x, uttid=None):
# x: [Time, Dim]
square_sums = (x**2).sum(axis=0)
mean = x.mean(axis=0)
if self.norm_means:
x = np.subtract(x, mean)
if self.norm_vars:
var = square_sums / x.shape[0] - mean**2
std = np.maximum(np.sqrt(var), self.std_floor)
x = np.divide(x, std)
return x
class GlobalCMVN():
"Apply Global CMVN"
def __init__(self,
cmvn_path,
norm_means=True,
norm_vars=True,
std_floor=1.0e-20):
# cmvn_path: Option[str, dict]
cmvn = cmvn_path
self.cmvn = cmvn
self.norm_means = norm_means
self.norm_vars = norm_vars
self.std_floor = std_floor
if isinstance(cmvn, dict):
cmvn_stats = cmvn
else:
with open(cmvn) as f:
cmvn_stats = json.load(f)
self.count = cmvn_stats['frame_num']
self.mean = np.array(cmvn_stats['mean_stat']) / self.count
self.square_sums = np.array(cmvn_stats['var_stat'])
self.var = self.square_sums / self.count - self.mean**2
self.std = np.maximum(np.sqrt(self.var), self.std_floor)
def __repr__(self):
return f"""{self.__class__.__name__}(
cmvn_path={self.cmvn},
norm_means={self.norm_means},
norm_vars={self.norm_vars},)"""
def __call__(self, x, uttid=None):
# x: [Time, Dim]
if self.norm_means:
x = np.subtract(x, self.mean)
if self.norm_vars:
x = np.divide(x, self.std)
return x

@ -1,86 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import inspect
from paddlespeech.s2t.transform.transform_interface import TransformInterface
from paddlespeech.s2t.utils.check_kwargs import check_kwargs
class FuncTrans(TransformInterface):
"""Functional Transformation
WARNING:
Builtin or C/C++ functions may not work properly
because this class heavily depends on the `inspect` module.
Usage:
>>> def foo_bar(x, a=1, b=2):
... '''Foo bar
... :param x: input
... :param int a: default 1
... :param int b: default 2
... '''
... return x + a - b
>>> class FooBar(FuncTrans):
... _func = foo_bar
... __doc__ = foo_bar.__doc__
"""
_func = None
def __init__(self, **kwargs):
self.kwargs = kwargs
check_kwargs(self.func, kwargs)
def __call__(self, x):
return self.func(x, **self.kwargs)
@classmethod
def add_arguments(cls, parser):
fname = cls._func.__name__.replace("_", "-")
group = parser.add_argument_group(fname + " transformation setting")
for k, v in cls.default_params().items():
# TODO(karita): get help and choices from docstring?
attr = k.replace("_", "-")
group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
return parser
@property
def func(self):
return type(self)._func
@classmethod
def default_params(cls):
try:
d = dict(inspect.signature(cls._func).parameters)
except ValueError:
d = dict()
return {
k: v.default
for k, v in d.items() if v.default != inspect.Parameter.empty
}
def __repr__(self):
params = self.default_params()
params.update(**self.kwargs)
ret = self.__class__.__name__ + "("
if len(params) == 0:
return ret + ")"
for k, v in params.items():
ret += "{}={}, ".format(k, v)
return ret[:-2] + ")"

@ -1,471 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import librosa
import numpy
import scipy
import soundfile
from paddlespeech.s2t.io.reader import SoundHDF5File
class SpeedPerturbation():
"""SpeedPerturbation
The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
and sox-speed just to resample the input,
i.e pitch and tempo are changed both.
"Why use speed option instead of tempo -s in SoX for speed perturbation"
https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8
Warning:
This function is very slow because of resampling.
I recommmend to apply speed-perturb outside the training using sox.
"""
def __init__(
self,
lower=0.9,
upper=1.1,
utt2ratio=None,
keep_length=True,
res_type="kaiser_best",
seed=None, ):
self.res_type = res_type
self.keep_length = keep_length
self.state = numpy.random.RandomState(seed)
if utt2ratio is not None:
self.utt2ratio = {}
# Use the scheduled ratio for each utterances
self.utt2ratio_file = utt2ratio
self.lower = None
self.upper = None
self.accept_uttid = True
with open(utt2ratio, "r") as f:
for line in f:
utt, ratio = line.rstrip().split(None, 1)
ratio = float(ratio)
self.utt2ratio[utt] = ratio
else:
self.utt2ratio = None
# The ratio is given on runtime randomly
self.lower = lower
self.upper = upper
def __repr__(self):
if self.utt2ratio is None:
return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format(
self.__class__.__name__,
self.lower,
self.upper,
self.keep_length,
self.res_type, )
else:
return "{}({}, res_type={})".format(
self.__class__.__name__, self.utt2ratio_file, self.res_type)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
if self.accept_uttid:
ratio = self.utt2ratio[uttid]
else:
ratio = self.state.uniform(self.lower, self.upper)
# Note1: resample requires the sampling-rate of input and output,
# but actually only the ratio is used.
y = librosa.resample(
x, orig_sr=ratio, target_sr=1, res_type=self.res_type)
if self.keep_length:
diff = abs(len(x) - len(y))
if len(y) > len(x):
# Truncate noise
y = y[diff // 2:-((diff + 1) // 2)]
elif len(y) < len(x):
# Assume the time-axis is the first: (Time, Channel)
pad_width = [(diff // 2, (diff + 1) // 2)] + [
(0, 0) for _ in range(y.ndim - 1)
]
y = numpy.pad(
y, pad_width=pad_width, constant_values=0, mode="constant")
return y
class SpeedPerturbationSox():
"""SpeedPerturbationSox
The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
and sox-speed just to resample the input,
i.e pitch and tempo are changed both.
To speed up or slow down the sound of a file,
use speed to modify the pitch and the duration of the file.
This raises the speed and reduces the time.
The default factor is 1.0 which makes no change to the audio.
2.0 doubles speed, thus time length is cut by a half and pitch is one interval higher.
"Why use speed option instead of tempo -s in SoX for speed perturbation"
https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8
tempo option:
sox -t wav input.wav -t wav output.tempo0.9.wav tempo -s 0.9
speed option:
sox -t wav input.wav -t wav output.speed0.9.wav speed 0.9
If we use speed option like above, the pitch of audio also will be changed,
but the tempo option does not change the pitch.
"""
def __init__(
self,
lower=0.9,
upper=1.1,
utt2ratio=None,
keep_length=True,
sr=16000,
seed=None, ):
self.sr = sr
self.keep_length = keep_length
self.state = numpy.random.RandomState(seed)
try:
import soxbindings as sox
except ImportError:
try:
from paddlespeech.s2t.utils import dynamic_pip_install
package = "sox"
dynamic_pip_install.install(package)
package = "soxbindings"
if sys.platform != "win32":
dynamic_pip_install.install(package)
import soxbindings as sox
except Exception:
raise RuntimeError(
"Can not install soxbindings on your system.")
self.sox = sox
if utt2ratio is not None:
self.utt2ratio = {}
# Use the scheduled ratio for each utterances
self.utt2ratio_file = utt2ratio
self.lower = None
self.upper = None
self.accept_uttid = True
with open(utt2ratio, "r") as f:
for line in f:
utt, ratio = line.rstrip().split(None, 1)
ratio = float(ratio)
self.utt2ratio[utt] = ratio
else:
self.utt2ratio = None
# The ratio is given on runtime randomly
self.lower = lower
self.upper = upper
def __repr__(self):
if self.utt2ratio is None:
return f"""{self.__class__.__name__}(
lower={self.lower},
upper={self.upper},
keep_length={self.keep_length},
sample_rate={self.sr})"""
else:
return f"""{self.__class__.__name__}(
utt2ratio={self.utt2ratio_file},
sample_rate={self.sr})"""
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
if self.accept_uttid:
ratio = self.utt2ratio[uttid]
else:
ratio = self.state.uniform(self.lower, self.upper)
tfm = self.sox.Transformer()
tfm.set_globals(multithread=False)
tfm.speed(ratio)
y = tfm.build_array(input_array=x, sample_rate_in=self.sr)
if self.keep_length:
diff = abs(len(x) - len(y))
if len(y) > len(x):
# Truncate noise
y = y[diff // 2:-((diff + 1) // 2)]
elif len(y) < len(x):
# Assume the time-axis is the first: (Time, Channel)
pad_width = [(diff // 2, (diff + 1) // 2)] + [
(0, 0) for _ in range(y.ndim - 1)
]
y = numpy.pad(
y, pad_width=pad_width, constant_values=0, mode="constant")
if y.ndim == 2 and x.ndim == 1:
# (T, C) -> (T)
y = y.sequence(1)
return y
class BandpassPerturbation():
"""BandpassPerturbation
Randomly dropout along the frequency axis.
The original idea comes from the following:
"randomly-selected frequency band was cut off under the constraint of
leaving at least 1,000 Hz band within the range of less than 4,000Hz."
(The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
everyday home environments using multiple microphone arrays;
http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)
"""
def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1, )):
self.lower = lower
self.upper = upper
self.state = numpy.random.RandomState(seed)
# x_stft: (Time, Channel, Freq)
self.axes = axes
def __repr__(self):
return "{}(lower={}, upper={})".format(self.__class__.__name__,
self.lower, self.upper)
def __call__(self, x_stft, uttid=None, train=True):
if not train:
return x_stft
if x_stft.ndim == 1:
raise RuntimeError("Input in time-freq domain: "
"(Time, Channel, Freq) or (Time, Freq)")
ratio = self.state.uniform(self.lower, self.upper)
axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes]
shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]
mask = self.state.randn(*shape) > ratio
x_stft *= mask
return x_stft
class VolumePerturbation():
def __init__(self,
lower=-1.6,
upper=1.6,
utt2ratio=None,
dbunit=True,
seed=None):
self.dbunit = dbunit
self.utt2ratio_file = utt2ratio
self.lower = lower
self.upper = upper
self.state = numpy.random.RandomState(seed)
if utt2ratio is not None:
# Use the scheduled ratio for each utterances
self.utt2ratio = {}
self.lower = None
self.upper = None
self.accept_uttid = True
with open(utt2ratio, "r") as f:
for line in f:
utt, ratio = line.rstrip().split(None, 1)
ratio = float(ratio)
self.utt2ratio[utt] = ratio
else:
# The ratio is given on runtime randomly
self.utt2ratio = None
def __repr__(self):
if self.utt2ratio is None:
return "{}(lower={}, upper={}, dbunit={})".format(
self.__class__.__name__, self.lower, self.upper, self.dbunit)
else:
return '{}("{}", dbunit={})'.format(
self.__class__.__name__, self.utt2ratio_file, self.dbunit)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
if self.accept_uttid:
ratio = self.utt2ratio[uttid]
else:
ratio = self.state.uniform(self.lower, self.upper)
if self.dbunit:
ratio = 10**(ratio / 20)
return x * ratio
class NoiseInjection():
"""Add isotropic noise"""
def __init__(
self,
utt2noise=None,
lower=-20,
upper=-5,
utt2ratio=None,
filetype="list",
dbunit=True,
seed=None, ):
self.utt2noise_file = utt2noise
self.utt2ratio_file = utt2ratio
self.filetype = filetype
self.dbunit = dbunit
self.lower = lower
self.upper = upper
self.state = numpy.random.RandomState(seed)
if utt2ratio is not None:
# Use the scheduled ratio for each utterances
self.utt2ratio = {}
with open(utt2noise, "r") as f:
for line in f:
utt, snr = line.rstrip().split(None, 1)
snr = float(snr)
self.utt2ratio[utt] = snr
else:
# The ratio is given on runtime randomly
self.utt2ratio = None
if utt2noise is not None:
self.utt2noise = {}
if filetype == "list":
with open(utt2noise, "r") as f:
for line in f:
utt, filename = line.rstrip().split(None, 1)
signal, rate = soundfile.read(filename, dtype="int16")
# Load all files in memory
self.utt2noise[utt] = (signal, rate)
elif filetype == "sound.hdf5":
self.utt2noise = SoundHDF5File(utt2noise, "r")
else:
raise ValueError(filetype)
else:
self.utt2noise = None
if utt2noise is not None and utt2ratio is not None:
if set(self.utt2ratio) != set(self.utt2noise):
raise RuntimeError("The uttids mismatch between {} and {}".
format(utt2ratio, utt2noise))
def __repr__(self):
if self.utt2ratio is None:
return "{}(lower={}, upper={}, dbunit={})".format(
self.__class__.__name__, self.lower, self.upper, self.dbunit)
else:
return '{}("{}", dbunit={})'.format(
self.__class__.__name__, self.utt2ratio_file, self.dbunit)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
# 1. Get ratio of noise to signal in sound pressure level
if uttid is not None and self.utt2ratio is not None:
ratio = self.utt2ratio[uttid]
else:
ratio = self.state.uniform(self.lower, self.upper)
if self.dbunit:
ratio = 10**(ratio / 20)
scale = ratio * numpy.sqrt((x**2).mean())
# 2. Get noise
if self.utt2noise is not None:
# Get noise from the external source
if uttid is not None:
noise, rate = self.utt2noise[uttid]
else:
# Randomly select the noise source
noise = self.state.choice(list(self.utt2noise.values()))
# Normalize the level
noise /= numpy.sqrt((noise**2).mean())
# Adjust the noise length
diff = abs(len(x) - len(noise))
offset = self.state.randint(0, diff)
if len(noise) > len(x):
# Truncate noise
noise = noise[offset:-(diff - offset)]
else:
noise = numpy.pad(
noise, pad_width=[offset, diff - offset], mode="wrap")
else:
# Generate white noise
noise = self.state.normal(0, 1, x.shape)
# 3. Add noise to signal
return x + noise * scale
class RIRConvolve():
def __init__(self, utt2rir, filetype="list"):
self.utt2rir_file = utt2rir
self.filetype = filetype
self.utt2rir = {}
if filetype == "list":
with open(utt2rir, "r") as f:
for line in f:
utt, filename = line.rstrip().split(None, 1)
signal, rate = soundfile.read(filename, dtype="int16")
self.utt2rir[utt] = (signal, rate)
elif filetype == "sound.hdf5":
self.utt2rir = SoundHDF5File(utt2rir, "r")
else:
raise NotImplementedError(filetype)
def __repr__(self):
return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
if x.ndim != 1:
# Must be single channel
raise RuntimeError(
"Input x must be one dimensional array, but got {}".format(
x.shape))
rir, rate = self.utt2rir[uttid]
if rir.ndim == 2:
# FIXME(kamo): Use chainer.convolution_1d?
# return [Time, Channel]
return numpy.stack(
[scipy.convolve(x, r, mode="same") for r in rir], axis=-1)
else:
return scipy.convolve(x, rir, mode="same")

@ -1,214 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""Spec Augment module for preprocessing i.e., data augmentation"""
import random
import numpy
from PIL import Image
from PIL.Image import BICUBIC
from paddlespeech.s2t.transform.functional import FuncTrans
def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
"""time warp for spec augment
move random center frame by the random width ~ uniform(-window, window)
:param numpy.ndarray x: spectrogram (time, freq)
:param int max_time_warp: maximum time frames to warp
:param bool inplace: overwrite x with the result
:param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp"
(slow, differentiable)
:returns numpy.ndarray: time warped spectrogram (time, freq)
"""
window = max_time_warp
if window == 0:
return x
if mode == "PIL":
t = x.shape[0]
if t - window <= window:
return x
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
center = random.randrange(window, t - window)
warped = random.randrange(center - window, center +
window) + 1 # 1 ... t - 1
left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
BICUBIC)
if inplace:
x[:warped] = left
x[warped:] = right
return x
return numpy.concatenate((left, right), 0)
elif mode == "sparse_image_warp":
import paddle
from espnet.utils import spec_augment
# TODO(karita): make this differentiable again
return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
else:
raise NotImplementedError("unknown resize mode: " + mode +
", choose one from (PIL, sparse_image_warp).")
class TimeWarp(FuncTrans):
_func = time_warp
__doc__ = time_warp.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)
def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
"""freq mask for spec agument
:param numpy.ndarray x: (time, freq)
:param int n_mask: the number of masks
:param bool inplace: overwrite
:param bool replace_with_zero: pad zero on mask if true else use mean
"""
if inplace:
cloned = x
else:
cloned = x.copy()
num_mel_channels = cloned.shape[1]
fs = numpy.random.randint(0, F, size=(n_mask, 2))
for f, mask_end in fs:
f_zero = random.randrange(0, num_mel_channels - f)
mask_end += f_zero
# avoids randrange error if values are equal and range is empty
if f_zero == f_zero + f:
continue
if replace_with_zero:
cloned[:, f_zero:mask_end] = 0
else:
cloned[:, f_zero:mask_end] = cloned.mean()
return cloned
class FreqMask(FuncTrans):
_func = freq_mask
__doc__ = freq_mask.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)
def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
"""freq mask for spec agument
:param numpy.ndarray spec: (time, freq)
:param int n_mask: the number of masks
:param bool inplace: overwrite
:param bool replace_with_zero: pad zero on mask if true else use mean
"""
if inplace:
cloned = spec
else:
cloned = spec.copy()
len_spectro = cloned.shape[0]
ts = numpy.random.randint(0, T, size=(n_mask, 2))
for t, mask_end in ts:
# avoid randint range error
if len_spectro - t <= 0:
continue
t_zero = random.randrange(0, len_spectro - t)
# avoids randrange error if values are equal and range is empty
if t_zero == t_zero + t:
continue
mask_end += t_zero
if replace_with_zero:
cloned[t_zero:mask_end] = 0
else:
cloned[t_zero:mask_end] = cloned.mean()
return cloned
class TimeMask(FuncTrans):
_func = time_mask
__doc__ = time_mask.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)
def spec_augment(
x,
resize_mode="PIL",
max_time_warp=80,
max_freq_width=27,
n_freq_mask=2,
max_time_width=100,
n_time_mask=2,
inplace=True,
replace_with_zero=True, ):
"""spec agument
apply random time warping and time/freq masking
default setting is based on LD (Librispeech double) in Table 2
https://arxiv.org/pdf/1904.08779.pdf
:param numpy.ndarray x: (time, freq)
:param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp"
(slow, differentiable)
:param int max_time_warp: maximum frames to warp the center frame in spectrogram (W)
:param int freq_mask_width: maximum width of the random freq mask (F)
:param int n_freq_mask: the number of the random freq mask (m_F)
:param int time_mask_width: maximum width of the random time mask (T)
:param int n_time_mask: the number of the random time mask (m_T)
:param bool inplace: overwrite intermediate array
:param bool replace_with_zero: pad zero on mask if true else use mean
"""
assert isinstance(x, numpy.ndarray)
assert x.ndim == 2
x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
x = freq_mask(
x,
max_freq_width,
n_freq_mask,
inplace=inplace,
replace_with_zero=replace_with_zero, )
x = time_mask(
x,
max_time_width,
n_time_mask,
inplace=inplace,
replace_with_zero=replace_with_zero, )
return x
class SpecAugment(FuncTrans):
_func = spec_augment
__doc__ = spec_augment.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)

@ -1,475 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import librosa
import numpy as np
import paddle
from python_speech_features import logfbank
import paddlespeech.audio.compliance.kaldi as kaldi
def stft(x,
n_fft,
n_shift,
win_length=None,
window="hann",
center=True,
pad_mode="reflect"):
# x: [Time, Channel]
if x.ndim == 1:
single_channel = True
# x: [Time] -> [Time, Channel]
x = x[:, None]
else:
single_channel = False
x = x.astype(np.float32)
# FIXME(kamo): librosa.stft can't use multi-channel?
# x: [Time, Channel, Freq]
x = np.stack(
[
librosa.stft(
y=x[:, ch],
n_fft=n_fft,
hop_length=n_shift,
win_length=win_length,
window=window,
center=center,
pad_mode=pad_mode, ).T for ch in range(x.shape[1])
],
axis=1, )
if single_channel:
# x: [Time, Channel, Freq] -> [Time, Freq]
x = x[:, 0]
return x
def istft(x, n_shift, win_length=None, window="hann", center=True):
# x: [Time, Channel, Freq]
if x.ndim == 2:
single_channel = True
# x: [Time, Freq] -> [Time, Channel, Freq]
x = x[:, None, :]
else:
single_channel = False
# x: [Time, Channel]
x = np.stack(
[
librosa.istft(
stft_matrix=x[:, ch].T, # [Time, Freq] -> [Freq, Time]
hop_length=n_shift,
win_length=win_length,
window=window,
center=center, ) for ch in range(x.shape[1])
],
axis=1, )
if single_channel:
# x: [Time, Channel] -> [Time]
x = x[:, 0]
return x
def stft2logmelspectrogram(x_stft,
fs,
n_mels,
n_fft,
fmin=None,
fmax=None,
eps=1e-10):
# x_stft: (Time, Channel, Freq) or (Time, Freq)
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
# spc: (Time, Channel, Freq) or (Time, Freq)
spc = np.abs(x_stft)
# mel_basis: (Mel_freq, Freq)
mel_basis = librosa.filters.mel(
sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
# lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq)
lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
return lmspc
def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"):
# x: (Time, Channel) -> spc: (Time, Channel, Freq)
spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window))
return spc
def logmelspectrogram(
x,
fs,
n_mels,
n_fft,
n_shift,
win_length=None,
window="hann",
fmin=None,
fmax=None,
eps=1e-10,
pad_mode="reflect", ):
# stft: (Time, Channel, Freq) or (Time, Freq)
x_stft = stft(
x,
n_fft=n_fft,
n_shift=n_shift,
win_length=win_length,
window=window,
pad_mode=pad_mode, )
return stft2logmelspectrogram(
x_stft,
fs=fs,
n_mels=n_mels,
n_fft=n_fft,
fmin=fmin,
fmax=fmax,
eps=eps)
class Spectrogram():
def __init__(self, n_fft, n_shift, win_length=None, window="hann"):
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
def __repr__(self):
return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
"win_length={win_length}, window={window})".format(
name=self.__class__.__name__,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window, ))
def __call__(self, x):
return spectrogram(
x,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window, )
class LogMelSpectrogram():
def __init__(
self,
fs,
n_mels,
n_fft,
n_shift,
win_length=None,
window="hann",
fmin=None,
fmax=None,
eps=1e-10, ):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.fmin = fmin
self.fmax = fmax
self.eps = eps
def __repr__(self):
return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"n_shift={n_shift}, win_length={win_length}, window={window}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps, ))
def __call__(self, x):
return logmelspectrogram(
x,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window, )
class Stft2LogMelSpectrogram():
def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
self.fmin = fmin
self.fmax = fmax
self.eps = eps
def __repr__(self):
return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps, ))
def __call__(self, x):
return stft2logmelspectrogram(
x,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
fmin=self.fmin,
fmax=self.fmax, )
class Stft():
def __init__(
self,
n_fft,
n_shift,
win_length=None,
window="hann",
center=True,
pad_mode="reflect", ):
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.center = center
self.pad_mode = pad_mode
def __repr__(self):
return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
"win_length={win_length}, window={window},"
"center={center}, pad_mode={pad_mode})".format(
name=self.__class__.__name__,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode, ))
def __call__(self, x):
return stft(
x,
self.n_fft,
self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode, )
class IStft():
def __init__(self, n_shift, win_length=None, window="hann", center=True):
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.center = center
def __repr__(self):
return ("{name}(n_shift={n_shift}, "
"win_length={win_length}, window={window},"
"center={center})".format(
name=self.__class__.__name__,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center, ))
def __call__(self, x):
return istft(
x,
self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center, )
class LogMelSpectrogramKaldi():
def __init__(
self,
fs=16000,
n_mels=80,
n_shift=160, # unit:sample, 10ms
win_length=400, # unit:sample, 25ms
energy_floor=0.0,
dither=0.1):
"""
The Kaldi implementation of LogMelSpectrogram
Args:
fs (int): sample rate of the audio
n_mels (int): number of mel filter banks
n_shift (int): number of points in a frame shift
win_length (int): number of points in a frame windows
energy_floor (float): Floor on energy in Spectrogram computation (absolute)
dither (float): Dithering constant
Returns:
LogMelSpectrogramKaldi
"""
self.fs = fs
self.n_mels = n_mels
num_point_ms = fs / 1000
self.n_frame_length = win_length / num_point_ms
self.n_frame_shift = n_shift / num_point_ms
self.energy_floor = energy_floor
self.dither = dither
def __repr__(self):
return (
"{name}(fs={fs}, n_mels={n_mels}, "
"n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, "
"dither={dither}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_frame_shift=self.n_frame_shift,
n_frame_length=self.n_frame_length,
dither=self.dither, ))
def __call__(self, x, train):
"""
Args:
x (np.ndarray): shape (Ti,)
train (bool): True, train mode.
Raises:
ValueError: not support (Ti, C)
Returns:
np.ndarray: (T, D)
"""
dither = self.dither if train else 0.0
if x.ndim != 1:
raise ValueError("Not support x: [Time, Channel]")
waveform = paddle.to_tensor(np.expand_dims(x, 0), dtype=paddle.float32)
mat = kaldi.fbank(
waveform,
n_mels=self.n_mels,
frame_length=self.n_frame_length,
frame_shift=self.n_frame_shift,
dither=dither,
energy_floor=self.energy_floor,
sr=self.fs)
mat = np.squeeze(mat.numpy())
return mat
class LogMelSpectrogramKaldi_decay():
def __init__(
self,
fs=16000,
n_mels=80,
n_fft=512, # fft point
n_shift=160, # unit:sample, 10ms
win_length=400, # unit:sample, 25ms
window="povey",
fmin=20,
fmax=None,
eps=1e-10,
dither=1.0):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
if n_shift > win_length:
raise ValueError("Stride size must not be greater than "
"window size.")
self.n_shift = n_shift / fs # unit: ms
self.win_length = win_length / fs # unit: ms
self.window = window
self.fmin = fmin
if fmax is None:
fmax_ = fmax if fmax else self.fs / 2
elif fmax > int(self.fs / 2):
raise ValueError("fmax must not be greater than half of "
"sample rate.")
self.fmax = fmax_
self.eps = eps
self.remove_dc_offset = True
self.preemph = 0.97
self.dither = dither # only work in train mode
def __repr__(self):
return (
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"n_shift={n_shift}, win_length={win_length}, preemph={preemph}, window={window}, "
"fmin={fmin}, fmax={fmax}, eps={eps}, dither={dither}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
n_shift=self.n_shift,
preemph=self.preemph,
win_length=self.win_length,
window=self.window,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps,
dither=self.dither, ))
def __call__(self, x, train):
"""
Args:
x (np.ndarray): shape (Ti,)
train (bool): True, train mode.
Raises:
ValueError: not support (Ti, C)
Returns:
np.ndarray: (T, D)
"""
dither = self.dither if train else 0.0
if x.ndim != 1:
raise ValueError("Not support x: [Time, Channel]")
if x.dtype in np.sctypes['float']:
# PCM32 -> PCM16
bits = np.iinfo(np.int16).bits
x = x * 2**(bits - 1)
# logfbank need PCM16 input
y = logfbank(
signal=x,
samplerate=self.fs,
winlen=self.win_length, # unit ms
winstep=self.n_shift, # unit ms
nfilt=self.n_mels,
nfft=self.n_fft,
lowfreq=self.fmin,
highfreq=self.fmax,
dither=dither,
remove_dc_offset=self.remove_dc_offset,
preemph=self.preemph,
wintype=self.window)
return y

@ -1,35 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
class TransformInterface:
"""Transform Interface"""
def __call__(self, x):
raise NotImplementedError("__call__ method is not implemented")
@classmethod
def add_arguments(cls, parser):
return parser
def __repr__(self):
return self.__class__.__name__ + "()"
class Identity(TransformInterface):
"""Identity Function"""
def __call__(self, x):
return x

@ -1,158 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""Transformation module."""
import copy
import io
import logging
from collections import OrderedDict
from collections.abc import Sequence
from inspect import signature
import yaml
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
import_alias = dict(
identity="paddlespeech.s2t.transform.transform_interface:Identity",
time_warp="paddlespeech.s2t.transform.spec_augment:TimeWarp",
time_mask="paddlespeech.s2t.transform.spec_augment:TimeMask",
freq_mask="paddlespeech.s2t.transform.spec_augment:FreqMask",
spec_augment="paddlespeech.s2t.transform.spec_augment:SpecAugment",
speed_perturbation="paddlespeech.s2t.transform.perturb:SpeedPerturbation",
speed_perturbation_sox="paddlespeech.s2t.transform.perturb:SpeedPerturbationSox",
volume_perturbation="paddlespeech.s2t.transform.perturb:VolumePerturbation",
noise_injection="paddlespeech.s2t.transform.perturb:NoiseInjection",
bandpass_perturbation="paddlespeech.s2t.transform.perturb:BandpassPerturbation",
rir_convolve="paddlespeech.s2t.transform.perturb:RIRConvolve",
delta="paddlespeech.s2t.transform.add_deltas:AddDeltas",
cmvn="paddlespeech.s2t.transform.cmvn:CMVN",
utterance_cmvn="paddlespeech.s2t.transform.cmvn:UtteranceCMVN",
fbank="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogram",
spectrogram="paddlespeech.s2t.transform.spectrogram:Spectrogram",
stft="paddlespeech.s2t.transform.spectrogram:Stft",
istft="paddlespeech.s2t.transform.spectrogram:IStft",
stft2fbank="paddlespeech.s2t.transform.spectrogram:Stft2LogMelSpectrogram",
wpe="paddlespeech.s2t.transform.wpe:WPE",
channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector",
fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi",
cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN")
class Transformation():
"""Apply some functions to the mini-batch
Examples:
>>> kwargs = {"process": [{"type": "fbank",
... "n_mels": 80,
... "fs": 16000},
... {"type": "cmvn",
... "stats": "data/train/cmvn.ark",
... "norm_vars": True},
... {"type": "delta", "window": 2, "order": 2}]}
>>> transform = Transformation(kwargs)
>>> bs = 10
>>> xs = [np.random.randn(100, 80).astype(np.float32)
... for _ in range(bs)]
>>> xs = transform(xs)
"""
def __init__(self, conffile=None):
if conffile is not None:
if isinstance(conffile, dict):
self.conf = copy.deepcopy(conffile)
else:
with io.open(conffile, encoding="utf-8") as f:
self.conf = yaml.safe_load(f)
assert isinstance(self.conf, dict), type(self.conf)
else:
self.conf = {"mode": "sequential", "process": []}
self.functions = OrderedDict()
if self.conf.get("mode", "sequential") == "sequential":
for idx, process in enumerate(self.conf["process"]):
assert isinstance(process, dict), type(process)
opts = dict(process)
process_type = opts.pop("type")
class_obj = dynamic_import(process_type, import_alias)
# TODO(karita): assert issubclass(class_obj, TransformInterface)
try:
self.functions[idx] = class_obj(**opts)
except TypeError:
try:
signa = signature(class_obj)
except ValueError:
# Some function, e.g. built-in function, are failed
pass
else:
logging.error("Expected signature: {}({})".format(
class_obj.__name__, signa))
raise
else:
raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"]))
def __repr__(self):
rep = "\n" + "\n".join(" {}: {}".format(k, v)
for k, v in self.functions.items())
return "{}({})".format(self.__class__.__name__, rep)
def __call__(self, xs, uttid_list=None, **kwargs):
"""Return new mini-batch
:param Union[Sequence[np.ndarray], np.ndarray] xs:
:param Union[Sequence[str], str] uttid_list:
:return: batch:
:rtype: List[np.ndarray]
"""
if not isinstance(xs, Sequence):
is_batch = False
xs = [xs]
else:
is_batch = True
if isinstance(uttid_list, str):
uttid_list = [uttid_list for _ in range(len(xs))]
if self.conf.get("mode", "sequential") == "sequential":
for idx in range(len(self.conf["process"])):
func = self.functions[idx]
# TODO(karita): use TrainingTrans and UttTrans to check __call__ args
# Derive only the args which the func has
try:
param = signature(func).parameters
except ValueError:
# Some function, e.g. built-in function, are failed
param = {}
_kwargs = {k: v for k, v in kwargs.items() if k in param}
try:
if uttid_list is not None and "uttid" in param:
xs = [
func(x, u, **_kwargs)
for x, u in zip(xs, uttid_list)
]
else:
xs = [func(x, **_kwargs) for x in xs]
except Exception:
logging.fatal("Catch a exception from {}th func: {}".format(
idx, func))
raise
else:
raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"]))
if is_batch:
return xs
else:
return xs[0]

@ -1,58 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from nara_wpe.wpe import wpe
class WPE(object):
def __init__(self,
taps=10,
delay=3,
iterations=3,
psd_context=0,
statistics_mode="full"):
self.taps = taps
self.delay = delay
self.iterations = iterations
self.psd_context = psd_context
self.statistics_mode = statistics_mode
def __repr__(self):
return ("{name}(taps={taps}, delay={delay}"
"iterations={iterations}, psd_context={psd_context}, "
"statistics_mode={statistics_mode})".format(
name=self.__class__.__name__,
taps=self.taps,
delay=self.delay,
iterations=self.iterations,
psd_context=self.psd_context,
statistics_mode=self.statistics_mode, ))
def __call__(self, xs):
"""Return enhanced
:param np.ndarray xs: (Time, Channel, Frequency)
:return: enhanced_xs
:rtype: np.ndarray
"""
# nara_wpe.wpe: (F, C, T)
xs = wpe(
xs.transpose((2, 1, 0)),
taps=self.taps,
delay=self.delay,
iterations=self.iterations,
psd_context=self.psd_context,
statistics_mode=self.statistics_mode, )
return xs.transpose(2, 1, 0)
Loading…
Cancel
Save