parent
8f5e61090b
commit
c7a7b113c8
@ -1,13 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -1,54 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
import numpy as np
|
||||
|
||||
|
||||
def delta(feat, window):
|
||||
assert window > 0
|
||||
delta_feat = np.zeros_like(feat)
|
||||
for i in range(1, window + 1):
|
||||
delta_feat[:-i] += i * feat[i:]
|
||||
delta_feat[i:] += -i * feat[:-i]
|
||||
delta_feat[-i:] += i * feat[-1]
|
||||
delta_feat[:i] += -i * feat[0]
|
||||
delta_feat /= 2 * sum(i**2 for i in range(1, window + 1))
|
||||
return delta_feat
|
||||
|
||||
|
||||
def add_deltas(x, window=2, order=2):
|
||||
"""
|
||||
Args:
|
||||
x (np.ndarray): speech feat, (T, D).
|
||||
|
||||
Return:
|
||||
np.ndarray: (T, (1+order)*D)
|
||||
"""
|
||||
feats = [x]
|
||||
for _ in range(order):
|
||||
feats.append(delta(feats[-1], window))
|
||||
return np.concatenate(feats, axis=1)
|
||||
|
||||
|
||||
class AddDeltas():
|
||||
def __init__(self, window=2, order=2):
|
||||
self.window = window
|
||||
self.order = order
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(window={window}, order={order}".format(
|
||||
name=self.__class__.__name__, window=self.window, order=self.order)
|
||||
|
||||
def __call__(self, x):
|
||||
return add_deltas(x, window=self.window, order=self.order)
|
@ -1,57 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
import numpy
|
||||
|
||||
|
||||
class ChannelSelector():
|
||||
"""Select 1ch from multi-channel signal"""
|
||||
|
||||
def __init__(self, train_channel="random", eval_channel=0, axis=1):
|
||||
self.train_channel = train_channel
|
||||
self.eval_channel = eval_channel
|
||||
self.axis = axis
|
||||
|
||||
def __repr__(self):
|
||||
return ("{name}(train_channel={train_channel}, "
|
||||
"eval_channel={eval_channel}, axis={axis})".format(
|
||||
name=self.__class__.__name__,
|
||||
train_channel=self.train_channel,
|
||||
eval_channel=self.eval_channel,
|
||||
axis=self.axis, ))
|
||||
|
||||
def __call__(self, x, train=True):
|
||||
# Assuming x: [Time, Channel] by default
|
||||
|
||||
if x.ndim <= self.axis:
|
||||
# If the dimension is insufficient, then unsqueeze
|
||||
# (e.g [Time] -> [Time, 1])
|
||||
ind = tuple(
|
||||
slice(None) if i < x.ndim else None
|
||||
for i in range(self.axis + 1))
|
||||
x = x[ind]
|
||||
|
||||
if train:
|
||||
channel = self.train_channel
|
||||
else:
|
||||
channel = self.eval_channel
|
||||
|
||||
if channel == "random":
|
||||
ch = numpy.random.randint(0, x.shape[self.axis])
|
||||
else:
|
||||
ch = channel
|
||||
|
||||
ind = tuple(
|
||||
slice(None) if i != self.axis else ch for i in range(x.ndim))
|
||||
return x[ind]
|
@ -1,201 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
import io
|
||||
import json
|
||||
|
||||
import h5py
|
||||
import kaldiio
|
||||
import numpy as np
|
||||
|
||||
|
||||
class CMVN():
|
||||
"Apply Global/Spk CMVN/iverserCMVN."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stats,
|
||||
norm_means=True,
|
||||
norm_vars=False,
|
||||
filetype="mat",
|
||||
utt2spk=None,
|
||||
spk2utt=None,
|
||||
reverse=False,
|
||||
std_floor=1.0e-20, ):
|
||||
self.stats_file = stats
|
||||
self.norm_means = norm_means
|
||||
self.norm_vars = norm_vars
|
||||
self.reverse = reverse
|
||||
|
||||
if isinstance(stats, dict):
|
||||
stats_dict = dict(stats)
|
||||
else:
|
||||
# Use for global CMVN
|
||||
if filetype == "mat":
|
||||
stats_dict = {None: kaldiio.load_mat(stats)}
|
||||
# Use for global CMVN
|
||||
elif filetype == "npy":
|
||||
stats_dict = {None: np.load(stats)}
|
||||
# Use for speaker CMVN
|
||||
elif filetype == "ark":
|
||||
self.accept_uttid = True
|
||||
stats_dict = dict(kaldiio.load_ark(stats))
|
||||
# Use for speaker CMVN
|
||||
elif filetype == "hdf5":
|
||||
self.accept_uttid = True
|
||||
stats_dict = h5py.File(stats)
|
||||
else:
|
||||
raise ValueError("Not supporting filetype={}".format(filetype))
|
||||
|
||||
if utt2spk is not None:
|
||||
self.utt2spk = {}
|
||||
with io.open(utt2spk, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
utt, spk = line.rstrip().split(None, 1)
|
||||
self.utt2spk[utt] = spk
|
||||
elif spk2utt is not None:
|
||||
self.utt2spk = {}
|
||||
with io.open(spk2utt, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
spk, utts = line.rstrip().split(None, 1)
|
||||
for utt in utts.split():
|
||||
self.utt2spk[utt] = spk
|
||||
else:
|
||||
self.utt2spk = None
|
||||
|
||||
# Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1),
|
||||
# and the first vector contains the sum of feats and the second is
|
||||
# the sum of squares. The last value of the first, i.e. stats[0,-1],
|
||||
# is the number of samples for this statistics.
|
||||
self.bias = {}
|
||||
self.scale = {}
|
||||
for spk, stats in stats_dict.items():
|
||||
assert len(stats) == 2, stats.shape
|
||||
|
||||
count = stats[0, -1]
|
||||
|
||||
# If the feature has two or more dimensions
|
||||
if not (np.isscalar(count) or isinstance(count, (int, float))):
|
||||
# The first is only used
|
||||
count = count.flatten()[0]
|
||||
|
||||
mean = stats[0, :-1] / count
|
||||
# V(x) = E(x^2) - (E(x))^2
|
||||
var = stats[1, :-1] / count - mean * mean
|
||||
std = np.maximum(np.sqrt(var), std_floor)
|
||||
self.bias[spk] = -mean
|
||||
self.scale[spk] = 1 / std
|
||||
|
||||
def __repr__(self):
|
||||
return ("{name}(stats_file={stats_file}, "
|
||||
"norm_means={norm_means}, norm_vars={norm_vars}, "
|
||||
"reverse={reverse})".format(
|
||||
name=self.__class__.__name__,
|
||||
stats_file=self.stats_file,
|
||||
norm_means=self.norm_means,
|
||||
norm_vars=self.norm_vars,
|
||||
reverse=self.reverse, ))
|
||||
|
||||
def __call__(self, x, uttid=None):
|
||||
if self.utt2spk is not None:
|
||||
spk = self.utt2spk[uttid]
|
||||
else:
|
||||
spk = uttid
|
||||
|
||||
if not self.reverse:
|
||||
# apply cmvn
|
||||
if self.norm_means:
|
||||
x = np.add(x, self.bias[spk])
|
||||
if self.norm_vars:
|
||||
x = np.multiply(x, self.scale[spk])
|
||||
|
||||
else:
|
||||
# apply reverse cmvn
|
||||
if self.norm_vars:
|
||||
x = np.divide(x, self.scale[spk])
|
||||
if self.norm_means:
|
||||
x = np.subtract(x, self.bias[spk])
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class UtteranceCMVN():
|
||||
"Apply Utterance CMVN"
|
||||
|
||||
def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20):
|
||||
self.norm_means = norm_means
|
||||
self.norm_vars = norm_vars
|
||||
self.std_floor = std_floor
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format(
|
||||
name=self.__class__.__name__,
|
||||
norm_means=self.norm_means,
|
||||
norm_vars=self.norm_vars, )
|
||||
|
||||
def __call__(self, x, uttid=None):
|
||||
# x: [Time, Dim]
|
||||
square_sums = (x**2).sum(axis=0)
|
||||
mean = x.mean(axis=0)
|
||||
|
||||
if self.norm_means:
|
||||
x = np.subtract(x, mean)
|
||||
|
||||
if self.norm_vars:
|
||||
var = square_sums / x.shape[0] - mean**2
|
||||
std = np.maximum(np.sqrt(var), self.std_floor)
|
||||
x = np.divide(x, std)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class GlobalCMVN():
|
||||
"Apply Global CMVN"
|
||||
|
||||
def __init__(self,
|
||||
cmvn_path,
|
||||
norm_means=True,
|
||||
norm_vars=True,
|
||||
std_floor=1.0e-20):
|
||||
# cmvn_path: Option[str, dict]
|
||||
cmvn = cmvn_path
|
||||
self.cmvn = cmvn
|
||||
self.norm_means = norm_means
|
||||
self.norm_vars = norm_vars
|
||||
self.std_floor = std_floor
|
||||
if isinstance(cmvn, dict):
|
||||
cmvn_stats = cmvn
|
||||
else:
|
||||
with open(cmvn) as f:
|
||||
cmvn_stats = json.load(f)
|
||||
self.count = cmvn_stats['frame_num']
|
||||
self.mean = np.array(cmvn_stats['mean_stat']) / self.count
|
||||
self.square_sums = np.array(cmvn_stats['var_stat'])
|
||||
self.var = self.square_sums / self.count - self.mean**2
|
||||
self.std = np.maximum(np.sqrt(self.var), self.std_floor)
|
||||
|
||||
def __repr__(self):
|
||||
return f"""{self.__class__.__name__}(
|
||||
cmvn_path={self.cmvn},
|
||||
norm_means={self.norm_means},
|
||||
norm_vars={self.norm_vars},)"""
|
||||
|
||||
def __call__(self, x, uttid=None):
|
||||
# x: [Time, Dim]
|
||||
if self.norm_means:
|
||||
x = np.subtract(x, self.mean)
|
||||
|
||||
if self.norm_vars:
|
||||
x = np.divide(x, self.std)
|
||||
return x
|
@ -1,86 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
import inspect
|
||||
|
||||
from paddlespeech.s2t.transform.transform_interface import TransformInterface
|
||||
from paddlespeech.s2t.utils.check_kwargs import check_kwargs
|
||||
|
||||
|
||||
class FuncTrans(TransformInterface):
|
||||
"""Functional Transformation
|
||||
|
||||
WARNING:
|
||||
Builtin or C/C++ functions may not work properly
|
||||
because this class heavily depends on the `inspect` module.
|
||||
|
||||
Usage:
|
||||
|
||||
>>> def foo_bar(x, a=1, b=2):
|
||||
... '''Foo bar
|
||||
... :param x: input
|
||||
... :param int a: default 1
|
||||
... :param int b: default 2
|
||||
... '''
|
||||
... return x + a - b
|
||||
|
||||
|
||||
>>> class FooBar(FuncTrans):
|
||||
... _func = foo_bar
|
||||
... __doc__ = foo_bar.__doc__
|
||||
"""
|
||||
|
||||
_func = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
check_kwargs(self.func, kwargs)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.func(x, **self.kwargs)
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
fname = cls._func.__name__.replace("_", "-")
|
||||
group = parser.add_argument_group(fname + " transformation setting")
|
||||
for k, v in cls.default_params().items():
|
||||
# TODO(karita): get help and choices from docstring?
|
||||
attr = k.replace("_", "-")
|
||||
group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
|
||||
return parser
|
||||
|
||||
@property
|
||||
def func(self):
|
||||
return type(self)._func
|
||||
|
||||
@classmethod
|
||||
def default_params(cls):
|
||||
try:
|
||||
d = dict(inspect.signature(cls._func).parameters)
|
||||
except ValueError:
|
||||
d = dict()
|
||||
return {
|
||||
k: v.default
|
||||
for k, v in d.items() if v.default != inspect.Parameter.empty
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
params = self.default_params()
|
||||
params.update(**self.kwargs)
|
||||
ret = self.__class__.__name__ + "("
|
||||
if len(params) == 0:
|
||||
return ret + ")"
|
||||
for k, v in params.items():
|
||||
ret += "{}={}, ".format(k, v)
|
||||
return ret[:-2] + ")"
|
@ -1,471 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
import librosa
|
||||
import numpy
|
||||
import scipy
|
||||
import soundfile
|
||||
|
||||
from paddlespeech.s2t.io.reader import SoundHDF5File
|
||||
|
||||
|
||||
class SpeedPerturbation():
|
||||
"""SpeedPerturbation
|
||||
|
||||
The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
|
||||
and sox-speed just to resample the input,
|
||||
i.e pitch and tempo are changed both.
|
||||
|
||||
"Why use speed option instead of tempo -s in SoX for speed perturbation"
|
||||
https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8
|
||||
|
||||
Warning:
|
||||
This function is very slow because of resampling.
|
||||
I recommmend to apply speed-perturb outside the training using sox.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lower=0.9,
|
||||
upper=1.1,
|
||||
utt2ratio=None,
|
||||
keep_length=True,
|
||||
res_type="kaiser_best",
|
||||
seed=None, ):
|
||||
self.res_type = res_type
|
||||
self.keep_length = keep_length
|
||||
self.state = numpy.random.RandomState(seed)
|
||||
|
||||
if utt2ratio is not None:
|
||||
self.utt2ratio = {}
|
||||
# Use the scheduled ratio for each utterances
|
||||
self.utt2ratio_file = utt2ratio
|
||||
self.lower = None
|
||||
self.upper = None
|
||||
self.accept_uttid = True
|
||||
|
||||
with open(utt2ratio, "r") as f:
|
||||
for line in f:
|
||||
utt, ratio = line.rstrip().split(None, 1)
|
||||
ratio = float(ratio)
|
||||
self.utt2ratio[utt] = ratio
|
||||
else:
|
||||
self.utt2ratio = None
|
||||
# The ratio is given on runtime randomly
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
|
||||
def __repr__(self):
|
||||
if self.utt2ratio is None:
|
||||
return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format(
|
||||
self.__class__.__name__,
|
||||
self.lower,
|
||||
self.upper,
|
||||
self.keep_length,
|
||||
self.res_type, )
|
||||
else:
|
||||
return "{}({}, res_type={})".format(
|
||||
self.__class__.__name__, self.utt2ratio_file, self.res_type)
|
||||
|
||||
def __call__(self, x, uttid=None, train=True):
|
||||
if not train:
|
||||
return x
|
||||
x = x.astype(numpy.float32)
|
||||
if self.accept_uttid:
|
||||
ratio = self.utt2ratio[uttid]
|
||||
else:
|
||||
ratio = self.state.uniform(self.lower, self.upper)
|
||||
|
||||
# Note1: resample requires the sampling-rate of input and output,
|
||||
# but actually only the ratio is used.
|
||||
y = librosa.resample(
|
||||
x, orig_sr=ratio, target_sr=1, res_type=self.res_type)
|
||||
|
||||
if self.keep_length:
|
||||
diff = abs(len(x) - len(y))
|
||||
if len(y) > len(x):
|
||||
# Truncate noise
|
||||
y = y[diff // 2:-((diff + 1) // 2)]
|
||||
elif len(y) < len(x):
|
||||
# Assume the time-axis is the first: (Time, Channel)
|
||||
pad_width = [(diff // 2, (diff + 1) // 2)] + [
|
||||
(0, 0) for _ in range(y.ndim - 1)
|
||||
]
|
||||
y = numpy.pad(
|
||||
y, pad_width=pad_width, constant_values=0, mode="constant")
|
||||
return y
|
||||
|
||||
|
||||
class SpeedPerturbationSox():
|
||||
"""SpeedPerturbationSox
|
||||
|
||||
The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
|
||||
and sox-speed just to resample the input,
|
||||
i.e pitch and tempo are changed both.
|
||||
|
||||
To speed up or slow down the sound of a file,
|
||||
use speed to modify the pitch and the duration of the file.
|
||||
This raises the speed and reduces the time.
|
||||
The default factor is 1.0 which makes no change to the audio.
|
||||
2.0 doubles speed, thus time length is cut by a half and pitch is one interval higher.
|
||||
|
||||
"Why use speed option instead of tempo -s in SoX for speed perturbation"
|
||||
https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8
|
||||
|
||||
tempo option:
|
||||
sox -t wav input.wav -t wav output.tempo0.9.wav tempo -s 0.9
|
||||
|
||||
speed option:
|
||||
sox -t wav input.wav -t wav output.speed0.9.wav speed 0.9
|
||||
|
||||
If we use speed option like above, the pitch of audio also will be changed,
|
||||
but the tempo option does not change the pitch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lower=0.9,
|
||||
upper=1.1,
|
||||
utt2ratio=None,
|
||||
keep_length=True,
|
||||
sr=16000,
|
||||
seed=None, ):
|
||||
self.sr = sr
|
||||
self.keep_length = keep_length
|
||||
self.state = numpy.random.RandomState(seed)
|
||||
|
||||
try:
|
||||
import soxbindings as sox
|
||||
except ImportError:
|
||||
try:
|
||||
from paddlespeech.s2t.utils import dynamic_pip_install
|
||||
package = "sox"
|
||||
dynamic_pip_install.install(package)
|
||||
package = "soxbindings"
|
||||
if sys.platform != "win32":
|
||||
dynamic_pip_install.install(package)
|
||||
import soxbindings as sox
|
||||
except Exception:
|
||||
raise RuntimeError(
|
||||
"Can not install soxbindings on your system.")
|
||||
self.sox = sox
|
||||
|
||||
if utt2ratio is not None:
|
||||
self.utt2ratio = {}
|
||||
# Use the scheduled ratio for each utterances
|
||||
self.utt2ratio_file = utt2ratio
|
||||
self.lower = None
|
||||
self.upper = None
|
||||
self.accept_uttid = True
|
||||
|
||||
with open(utt2ratio, "r") as f:
|
||||
for line in f:
|
||||
utt, ratio = line.rstrip().split(None, 1)
|
||||
ratio = float(ratio)
|
||||
self.utt2ratio[utt] = ratio
|
||||
else:
|
||||
self.utt2ratio = None
|
||||
# The ratio is given on runtime randomly
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
|
||||
def __repr__(self):
|
||||
if self.utt2ratio is None:
|
||||
return f"""{self.__class__.__name__}(
|
||||
lower={self.lower},
|
||||
upper={self.upper},
|
||||
keep_length={self.keep_length},
|
||||
sample_rate={self.sr})"""
|
||||
|
||||
else:
|
||||
return f"""{self.__class__.__name__}(
|
||||
utt2ratio={self.utt2ratio_file},
|
||||
sample_rate={self.sr})"""
|
||||
|
||||
def __call__(self, x, uttid=None, train=True):
|
||||
if not train:
|
||||
return x
|
||||
|
||||
x = x.astype(numpy.float32)
|
||||
if self.accept_uttid:
|
||||
ratio = self.utt2ratio[uttid]
|
||||
else:
|
||||
ratio = self.state.uniform(self.lower, self.upper)
|
||||
|
||||
tfm = self.sox.Transformer()
|
||||
tfm.set_globals(multithread=False)
|
||||
tfm.speed(ratio)
|
||||
y = tfm.build_array(input_array=x, sample_rate_in=self.sr)
|
||||
|
||||
if self.keep_length:
|
||||
diff = abs(len(x) - len(y))
|
||||
if len(y) > len(x):
|
||||
# Truncate noise
|
||||
y = y[diff // 2:-((diff + 1) // 2)]
|
||||
elif len(y) < len(x):
|
||||
# Assume the time-axis is the first: (Time, Channel)
|
||||
pad_width = [(diff // 2, (diff + 1) // 2)] + [
|
||||
(0, 0) for _ in range(y.ndim - 1)
|
||||
]
|
||||
y = numpy.pad(
|
||||
y, pad_width=pad_width, constant_values=0, mode="constant")
|
||||
|
||||
if y.ndim == 2 and x.ndim == 1:
|
||||
# (T, C) -> (T)
|
||||
y = y.sequence(1)
|
||||
return y
|
||||
|
||||
|
||||
class BandpassPerturbation():
|
||||
"""BandpassPerturbation
|
||||
|
||||
Randomly dropout along the frequency axis.
|
||||
|
||||
The original idea comes from the following:
|
||||
"randomly-selected frequency band was cut off under the constraint of
|
||||
leaving at least 1,000 Hz band within the range of less than 4,000Hz."
|
||||
(The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
|
||||
everyday home environments using multiple microphone arrays;
|
||||
http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1, )):
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
self.state = numpy.random.RandomState(seed)
|
||||
# x_stft: (Time, Channel, Freq)
|
||||
self.axes = axes
|
||||
|
||||
def __repr__(self):
|
||||
return "{}(lower={}, upper={})".format(self.__class__.__name__,
|
||||
self.lower, self.upper)
|
||||
|
||||
def __call__(self, x_stft, uttid=None, train=True):
|
||||
if not train:
|
||||
return x_stft
|
||||
|
||||
if x_stft.ndim == 1:
|
||||
raise RuntimeError("Input in time-freq domain: "
|
||||
"(Time, Channel, Freq) or (Time, Freq)")
|
||||
|
||||
ratio = self.state.uniform(self.lower, self.upper)
|
||||
axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes]
|
||||
shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]
|
||||
|
||||
mask = self.state.randn(*shape) > ratio
|
||||
x_stft *= mask
|
||||
return x_stft
|
||||
|
||||
|
||||
class VolumePerturbation():
|
||||
def __init__(self,
|
||||
lower=-1.6,
|
||||
upper=1.6,
|
||||
utt2ratio=None,
|
||||
dbunit=True,
|
||||
seed=None):
|
||||
self.dbunit = dbunit
|
||||
self.utt2ratio_file = utt2ratio
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
self.state = numpy.random.RandomState(seed)
|
||||
|
||||
if utt2ratio is not None:
|
||||
# Use the scheduled ratio for each utterances
|
||||
self.utt2ratio = {}
|
||||
self.lower = None
|
||||
self.upper = None
|
||||
self.accept_uttid = True
|
||||
|
||||
with open(utt2ratio, "r") as f:
|
||||
for line in f:
|
||||
utt, ratio = line.rstrip().split(None, 1)
|
||||
ratio = float(ratio)
|
||||
self.utt2ratio[utt] = ratio
|
||||
else:
|
||||
# The ratio is given on runtime randomly
|
||||
self.utt2ratio = None
|
||||
|
||||
def __repr__(self):
|
||||
if self.utt2ratio is None:
|
||||
return "{}(lower={}, upper={}, dbunit={})".format(
|
||||
self.__class__.__name__, self.lower, self.upper, self.dbunit)
|
||||
else:
|
||||
return '{}("{}", dbunit={})'.format(
|
||||
self.__class__.__name__, self.utt2ratio_file, self.dbunit)
|
||||
|
||||
def __call__(self, x, uttid=None, train=True):
|
||||
if not train:
|
||||
return x
|
||||
|
||||
x = x.astype(numpy.float32)
|
||||
|
||||
if self.accept_uttid:
|
||||
ratio = self.utt2ratio[uttid]
|
||||
else:
|
||||
ratio = self.state.uniform(self.lower, self.upper)
|
||||
if self.dbunit:
|
||||
ratio = 10**(ratio / 20)
|
||||
return x * ratio
|
||||
|
||||
|
||||
class NoiseInjection():
|
||||
"""Add isotropic noise"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
utt2noise=None,
|
||||
lower=-20,
|
||||
upper=-5,
|
||||
utt2ratio=None,
|
||||
filetype="list",
|
||||
dbunit=True,
|
||||
seed=None, ):
|
||||
self.utt2noise_file = utt2noise
|
||||
self.utt2ratio_file = utt2ratio
|
||||
self.filetype = filetype
|
||||
self.dbunit = dbunit
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
self.state = numpy.random.RandomState(seed)
|
||||
|
||||
if utt2ratio is not None:
|
||||
# Use the scheduled ratio for each utterances
|
||||
self.utt2ratio = {}
|
||||
with open(utt2noise, "r") as f:
|
||||
for line in f:
|
||||
utt, snr = line.rstrip().split(None, 1)
|
||||
snr = float(snr)
|
||||
self.utt2ratio[utt] = snr
|
||||
else:
|
||||
# The ratio is given on runtime randomly
|
||||
self.utt2ratio = None
|
||||
|
||||
if utt2noise is not None:
|
||||
self.utt2noise = {}
|
||||
if filetype == "list":
|
||||
with open(utt2noise, "r") as f:
|
||||
for line in f:
|
||||
utt, filename = line.rstrip().split(None, 1)
|
||||
signal, rate = soundfile.read(filename, dtype="int16")
|
||||
# Load all files in memory
|
||||
self.utt2noise[utt] = (signal, rate)
|
||||
|
||||
elif filetype == "sound.hdf5":
|
||||
self.utt2noise = SoundHDF5File(utt2noise, "r")
|
||||
else:
|
||||
raise ValueError(filetype)
|
||||
else:
|
||||
self.utt2noise = None
|
||||
|
||||
if utt2noise is not None and utt2ratio is not None:
|
||||
if set(self.utt2ratio) != set(self.utt2noise):
|
||||
raise RuntimeError("The uttids mismatch between {} and {}".
|
||||
format(utt2ratio, utt2noise))
|
||||
|
||||
def __repr__(self):
|
||||
if self.utt2ratio is None:
|
||||
return "{}(lower={}, upper={}, dbunit={})".format(
|
||||
self.__class__.__name__, self.lower, self.upper, self.dbunit)
|
||||
else:
|
||||
return '{}("{}", dbunit={})'.format(
|
||||
self.__class__.__name__, self.utt2ratio_file, self.dbunit)
|
||||
|
||||
def __call__(self, x, uttid=None, train=True):
|
||||
if not train:
|
||||
return x
|
||||
x = x.astype(numpy.float32)
|
||||
|
||||
# 1. Get ratio of noise to signal in sound pressure level
|
||||
if uttid is not None and self.utt2ratio is not None:
|
||||
ratio = self.utt2ratio[uttid]
|
||||
else:
|
||||
ratio = self.state.uniform(self.lower, self.upper)
|
||||
|
||||
if self.dbunit:
|
||||
ratio = 10**(ratio / 20)
|
||||
scale = ratio * numpy.sqrt((x**2).mean())
|
||||
|
||||
# 2. Get noise
|
||||
if self.utt2noise is not None:
|
||||
# Get noise from the external source
|
||||
if uttid is not None:
|
||||
noise, rate = self.utt2noise[uttid]
|
||||
else:
|
||||
# Randomly select the noise source
|
||||
noise = self.state.choice(list(self.utt2noise.values()))
|
||||
# Normalize the level
|
||||
noise /= numpy.sqrt((noise**2).mean())
|
||||
|
||||
# Adjust the noise length
|
||||
diff = abs(len(x) - len(noise))
|
||||
offset = self.state.randint(0, diff)
|
||||
if len(noise) > len(x):
|
||||
# Truncate noise
|
||||
noise = noise[offset:-(diff - offset)]
|
||||
else:
|
||||
noise = numpy.pad(
|
||||
noise, pad_width=[offset, diff - offset], mode="wrap")
|
||||
|
||||
else:
|
||||
# Generate white noise
|
||||
noise = self.state.normal(0, 1, x.shape)
|
||||
|
||||
# 3. Add noise to signal
|
||||
return x + noise * scale
|
||||
|
||||
|
||||
class RIRConvolve():
|
||||
def __init__(self, utt2rir, filetype="list"):
|
||||
self.utt2rir_file = utt2rir
|
||||
self.filetype = filetype
|
||||
|
||||
self.utt2rir = {}
|
||||
if filetype == "list":
|
||||
with open(utt2rir, "r") as f:
|
||||
for line in f:
|
||||
utt, filename = line.rstrip().split(None, 1)
|
||||
signal, rate = soundfile.read(filename, dtype="int16")
|
||||
self.utt2rir[utt] = (signal, rate)
|
||||
|
||||
elif filetype == "sound.hdf5":
|
||||
self.utt2rir = SoundHDF5File(utt2rir, "r")
|
||||
else:
|
||||
raise NotImplementedError(filetype)
|
||||
|
||||
def __repr__(self):
|
||||
return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)
|
||||
|
||||
def __call__(self, x, uttid=None, train=True):
|
||||
if not train:
|
||||
return x
|
||||
|
||||
x = x.astype(numpy.float32)
|
||||
|
||||
if x.ndim != 1:
|
||||
# Must be single channel
|
||||
raise RuntimeError(
|
||||
"Input x must be one dimensional array, but got {}".format(
|
||||
x.shape))
|
||||
|
||||
rir, rate = self.utt2rir[uttid]
|
||||
if rir.ndim == 2:
|
||||
# FIXME(kamo): Use chainer.convolution_1d?
|
||||
# return [Time, Channel]
|
||||
return numpy.stack(
|
||||
[scipy.convolve(x, r, mode="same") for r in rir], axis=-1)
|
||||
else:
|
||||
return scipy.convolve(x, rir, mode="same")
|
@ -1,214 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
"""Spec Augment module for preprocessing i.e., data augmentation"""
|
||||
import random
|
||||
|
||||
import numpy
|
||||
from PIL import Image
|
||||
from PIL.Image import BICUBIC
|
||||
|
||||
from paddlespeech.s2t.transform.functional import FuncTrans
|
||||
|
||||
|
||||
def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
|
||||
"""time warp for spec augment
|
||||
|
||||
move random center frame by the random width ~ uniform(-window, window)
|
||||
:param numpy.ndarray x: spectrogram (time, freq)
|
||||
:param int max_time_warp: maximum time frames to warp
|
||||
:param bool inplace: overwrite x with the result
|
||||
:param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp"
|
||||
(slow, differentiable)
|
||||
:returns numpy.ndarray: time warped spectrogram (time, freq)
|
||||
"""
|
||||
window = max_time_warp
|
||||
if window == 0:
|
||||
return x
|
||||
|
||||
if mode == "PIL":
|
||||
t = x.shape[0]
|
||||
if t - window <= window:
|
||||
return x
|
||||
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
|
||||
center = random.randrange(window, t - window)
|
||||
warped = random.randrange(center - window, center +
|
||||
window) + 1 # 1 ... t - 1
|
||||
|
||||
left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
|
||||
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
|
||||
BICUBIC)
|
||||
if inplace:
|
||||
x[:warped] = left
|
||||
x[warped:] = right
|
||||
return x
|
||||
return numpy.concatenate((left, right), 0)
|
||||
elif mode == "sparse_image_warp":
|
||||
import paddle
|
||||
|
||||
from espnet.utils import spec_augment
|
||||
|
||||
# TODO(karita): make this differentiable again
|
||||
return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
|
||||
else:
|
||||
raise NotImplementedError("unknown resize mode: " + mode +
|
||||
", choose one from (PIL, sparse_image_warp).")
|
||||
|
||||
|
||||
class TimeWarp(FuncTrans):
|
||||
_func = time_warp
|
||||
__doc__ = time_warp.__doc__
|
||||
|
||||
def __call__(self, x, train):
|
||||
if not train:
|
||||
return x
|
||||
return super().__call__(x)
|
||||
|
||||
|
||||
def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
|
||||
"""freq mask for spec agument
|
||||
|
||||
:param numpy.ndarray x: (time, freq)
|
||||
:param int n_mask: the number of masks
|
||||
:param bool inplace: overwrite
|
||||
:param bool replace_with_zero: pad zero on mask if true else use mean
|
||||
"""
|
||||
if inplace:
|
||||
cloned = x
|
||||
else:
|
||||
cloned = x.copy()
|
||||
|
||||
num_mel_channels = cloned.shape[1]
|
||||
fs = numpy.random.randint(0, F, size=(n_mask, 2))
|
||||
|
||||
for f, mask_end in fs:
|
||||
f_zero = random.randrange(0, num_mel_channels - f)
|
||||
mask_end += f_zero
|
||||
|
||||
# avoids randrange error if values are equal and range is empty
|
||||
if f_zero == f_zero + f:
|
||||
continue
|
||||
|
||||
if replace_with_zero:
|
||||
cloned[:, f_zero:mask_end] = 0
|
||||
else:
|
||||
cloned[:, f_zero:mask_end] = cloned.mean()
|
||||
return cloned
|
||||
|
||||
|
||||
class FreqMask(FuncTrans):
|
||||
_func = freq_mask
|
||||
__doc__ = freq_mask.__doc__
|
||||
|
||||
def __call__(self, x, train):
|
||||
if not train:
|
||||
return x
|
||||
return super().__call__(x)
|
||||
|
||||
|
||||
def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
|
||||
"""freq mask for spec agument
|
||||
|
||||
:param numpy.ndarray spec: (time, freq)
|
||||
:param int n_mask: the number of masks
|
||||
:param bool inplace: overwrite
|
||||
:param bool replace_with_zero: pad zero on mask if true else use mean
|
||||
"""
|
||||
if inplace:
|
||||
cloned = spec
|
||||
else:
|
||||
cloned = spec.copy()
|
||||
len_spectro = cloned.shape[0]
|
||||
ts = numpy.random.randint(0, T, size=(n_mask, 2))
|
||||
for t, mask_end in ts:
|
||||
# avoid randint range error
|
||||
if len_spectro - t <= 0:
|
||||
continue
|
||||
t_zero = random.randrange(0, len_spectro - t)
|
||||
|
||||
# avoids randrange error if values are equal and range is empty
|
||||
if t_zero == t_zero + t:
|
||||
continue
|
||||
|
||||
mask_end += t_zero
|
||||
if replace_with_zero:
|
||||
cloned[t_zero:mask_end] = 0
|
||||
else:
|
||||
cloned[t_zero:mask_end] = cloned.mean()
|
||||
return cloned
|
||||
|
||||
|
||||
class TimeMask(FuncTrans):
|
||||
_func = time_mask
|
||||
__doc__ = time_mask.__doc__
|
||||
|
||||
def __call__(self, x, train):
|
||||
if not train:
|
||||
return x
|
||||
return super().__call__(x)
|
||||
|
||||
|
||||
def spec_augment(
|
||||
x,
|
||||
resize_mode="PIL",
|
||||
max_time_warp=80,
|
||||
max_freq_width=27,
|
||||
n_freq_mask=2,
|
||||
max_time_width=100,
|
||||
n_time_mask=2,
|
||||
inplace=True,
|
||||
replace_with_zero=True, ):
|
||||
"""spec agument
|
||||
|
||||
apply random time warping and time/freq masking
|
||||
default setting is based on LD (Librispeech double) in Table 2
|
||||
https://arxiv.org/pdf/1904.08779.pdf
|
||||
|
||||
:param numpy.ndarray x: (time, freq)
|
||||
:param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp"
|
||||
(slow, differentiable)
|
||||
:param int max_time_warp: maximum frames to warp the center frame in spectrogram (W)
|
||||
:param int freq_mask_width: maximum width of the random freq mask (F)
|
||||
:param int n_freq_mask: the number of the random freq mask (m_F)
|
||||
:param int time_mask_width: maximum width of the random time mask (T)
|
||||
:param int n_time_mask: the number of the random time mask (m_T)
|
||||
:param bool inplace: overwrite intermediate array
|
||||
:param bool replace_with_zero: pad zero on mask if true else use mean
|
||||
"""
|
||||
assert isinstance(x, numpy.ndarray)
|
||||
assert x.ndim == 2
|
||||
x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
|
||||
x = freq_mask(
|
||||
x,
|
||||
max_freq_width,
|
||||
n_freq_mask,
|
||||
inplace=inplace,
|
||||
replace_with_zero=replace_with_zero, )
|
||||
x = time_mask(
|
||||
x,
|
||||
max_time_width,
|
||||
n_time_mask,
|
||||
inplace=inplace,
|
||||
replace_with_zero=replace_with_zero, )
|
||||
return x
|
||||
|
||||
|
||||
class SpecAugment(FuncTrans):
|
||||
_func = spec_augment
|
||||
__doc__ = spec_augment.__doc__
|
||||
|
||||
def __call__(self, x, train):
|
||||
if not train:
|
||||
return x
|
||||
return super().__call__(x)
|
@ -1,475 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
import librosa
|
||||
import numpy as np
|
||||
import paddle
|
||||
from python_speech_features import logfbank
|
||||
|
||||
import paddlespeech.audio.compliance.kaldi as kaldi
|
||||
|
||||
|
||||
def stft(x,
|
||||
n_fft,
|
||||
n_shift,
|
||||
win_length=None,
|
||||
window="hann",
|
||||
center=True,
|
||||
pad_mode="reflect"):
|
||||
# x: [Time, Channel]
|
||||
if x.ndim == 1:
|
||||
single_channel = True
|
||||
# x: [Time] -> [Time, Channel]
|
||||
x = x[:, None]
|
||||
else:
|
||||
single_channel = False
|
||||
x = x.astype(np.float32)
|
||||
|
||||
# FIXME(kamo): librosa.stft can't use multi-channel?
|
||||
# x: [Time, Channel, Freq]
|
||||
x = np.stack(
|
||||
[
|
||||
librosa.stft(
|
||||
y=x[:, ch],
|
||||
n_fft=n_fft,
|
||||
hop_length=n_shift,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
center=center,
|
||||
pad_mode=pad_mode, ).T for ch in range(x.shape[1])
|
||||
],
|
||||
axis=1, )
|
||||
|
||||
if single_channel:
|
||||
# x: [Time, Channel, Freq] -> [Time, Freq]
|
||||
x = x[:, 0]
|
||||
return x
|
||||
|
||||
|
||||
def istft(x, n_shift, win_length=None, window="hann", center=True):
|
||||
# x: [Time, Channel, Freq]
|
||||
if x.ndim == 2:
|
||||
single_channel = True
|
||||
# x: [Time, Freq] -> [Time, Channel, Freq]
|
||||
x = x[:, None, :]
|
||||
else:
|
||||
single_channel = False
|
||||
|
||||
# x: [Time, Channel]
|
||||
x = np.stack(
|
||||
[
|
||||
librosa.istft(
|
||||
stft_matrix=x[:, ch].T, # [Time, Freq] -> [Freq, Time]
|
||||
hop_length=n_shift,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
center=center, ) for ch in range(x.shape[1])
|
||||
],
|
||||
axis=1, )
|
||||
|
||||
if single_channel:
|
||||
# x: [Time, Channel] -> [Time]
|
||||
x = x[:, 0]
|
||||
return x
|
||||
|
||||
|
||||
def stft2logmelspectrogram(x_stft,
|
||||
fs,
|
||||
n_mels,
|
||||
n_fft,
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
eps=1e-10):
|
||||
# x_stft: (Time, Channel, Freq) or (Time, Freq)
|
||||
fmin = 0 if fmin is None else fmin
|
||||
fmax = fs / 2 if fmax is None else fmax
|
||||
|
||||
# spc: (Time, Channel, Freq) or (Time, Freq)
|
||||
spc = np.abs(x_stft)
|
||||
# mel_basis: (Mel_freq, Freq)
|
||||
mel_basis = librosa.filters.mel(
|
||||
sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
|
||||
# lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq)
|
||||
lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
|
||||
|
||||
return lmspc
|
||||
|
||||
|
||||
def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"):
|
||||
# x: (Time, Channel) -> spc: (Time, Channel, Freq)
|
||||
spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window))
|
||||
return spc
|
||||
|
||||
|
||||
def logmelspectrogram(
|
||||
x,
|
||||
fs,
|
||||
n_mels,
|
||||
n_fft,
|
||||
n_shift,
|
||||
win_length=None,
|
||||
window="hann",
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
eps=1e-10,
|
||||
pad_mode="reflect", ):
|
||||
# stft: (Time, Channel, Freq) or (Time, Freq)
|
||||
x_stft = stft(
|
||||
x,
|
||||
n_fft=n_fft,
|
||||
n_shift=n_shift,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
pad_mode=pad_mode, )
|
||||
|
||||
return stft2logmelspectrogram(
|
||||
x_stft,
|
||||
fs=fs,
|
||||
n_mels=n_mels,
|
||||
n_fft=n_fft,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
eps=eps)
|
||||
|
||||
|
||||
class Spectrogram():
|
||||
def __init__(self, n_fft, n_shift, win_length=None, window="hann"):
|
||||
self.n_fft = n_fft
|
||||
self.n_shift = n_shift
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
|
||||
def __repr__(self):
|
||||
return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
|
||||
"win_length={win_length}, window={window})".format(
|
||||
name=self.__class__.__name__,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window, ))
|
||||
|
||||
def __call__(self, x):
|
||||
return spectrogram(
|
||||
x,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window, )
|
||||
|
||||
|
||||
class LogMelSpectrogram():
|
||||
def __init__(
|
||||
self,
|
||||
fs,
|
||||
n_mels,
|
||||
n_fft,
|
||||
n_shift,
|
||||
win_length=None,
|
||||
window="hann",
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
eps=1e-10, ):
|
||||
self.fs = fs
|
||||
self.n_mels = n_mels
|
||||
self.n_fft = n_fft
|
||||
self.n_shift = n_shift
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.eps = eps
|
||||
|
||||
def __repr__(self):
|
||||
return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
|
||||
"n_shift={n_shift}, win_length={win_length}, window={window}, "
|
||||
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
|
||||
name=self.__class__.__name__,
|
||||
fs=self.fs,
|
||||
n_mels=self.n_mels,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
fmin=self.fmin,
|
||||
fmax=self.fmax,
|
||||
eps=self.eps, ))
|
||||
|
||||
def __call__(self, x):
|
||||
return logmelspectrogram(
|
||||
x,
|
||||
fs=self.fs,
|
||||
n_mels=self.n_mels,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window, )
|
||||
|
||||
|
||||
class Stft2LogMelSpectrogram():
|
||||
def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
|
||||
self.fs = fs
|
||||
self.n_mels = n_mels
|
||||
self.n_fft = n_fft
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.eps = eps
|
||||
|
||||
def __repr__(self):
|
||||
return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
|
||||
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
|
||||
name=self.__class__.__name__,
|
||||
fs=self.fs,
|
||||
n_mels=self.n_mels,
|
||||
n_fft=self.n_fft,
|
||||
fmin=self.fmin,
|
||||
fmax=self.fmax,
|
||||
eps=self.eps, ))
|
||||
|
||||
def __call__(self, x):
|
||||
return stft2logmelspectrogram(
|
||||
x,
|
||||
fs=self.fs,
|
||||
n_mels=self.n_mels,
|
||||
n_fft=self.n_fft,
|
||||
fmin=self.fmin,
|
||||
fmax=self.fmax, )
|
||||
|
||||
|
||||
class Stft():
|
||||
def __init__(
|
||||
self,
|
||||
n_fft,
|
||||
n_shift,
|
||||
win_length=None,
|
||||
window="hann",
|
||||
center=True,
|
||||
pad_mode="reflect", ):
|
||||
self.n_fft = n_fft
|
||||
self.n_shift = n_shift
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
self.center = center
|
||||
self.pad_mode = pad_mode
|
||||
|
||||
def __repr__(self):
|
||||
return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
|
||||
"win_length={win_length}, window={window},"
|
||||
"center={center}, pad_mode={pad_mode})".format(
|
||||
name=self.__class__.__name__,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
center=self.center,
|
||||
pad_mode=self.pad_mode, ))
|
||||
|
||||
def __call__(self, x):
|
||||
return stft(
|
||||
x,
|
||||
self.n_fft,
|
||||
self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
center=self.center,
|
||||
pad_mode=self.pad_mode, )
|
||||
|
||||
|
||||
class IStft():
|
||||
def __init__(self, n_shift, win_length=None, window="hann", center=True):
|
||||
self.n_shift = n_shift
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
self.center = center
|
||||
|
||||
def __repr__(self):
|
||||
return ("{name}(n_shift={n_shift}, "
|
||||
"win_length={win_length}, window={window},"
|
||||
"center={center})".format(
|
||||
name=self.__class__.__name__,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
center=self.center, ))
|
||||
|
||||
def __call__(self, x):
|
||||
return istft(
|
||||
x,
|
||||
self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
center=self.center, )
|
||||
|
||||
|
||||
class LogMelSpectrogramKaldi():
|
||||
def __init__(
|
||||
self,
|
||||
fs=16000,
|
||||
n_mels=80,
|
||||
n_shift=160, # unit:sample, 10ms
|
||||
win_length=400, # unit:sample, 25ms
|
||||
energy_floor=0.0,
|
||||
dither=0.1):
|
||||
"""
|
||||
The Kaldi implementation of LogMelSpectrogram
|
||||
Args:
|
||||
fs (int): sample rate of the audio
|
||||
n_mels (int): number of mel filter banks
|
||||
n_shift (int): number of points in a frame shift
|
||||
win_length (int): number of points in a frame windows
|
||||
energy_floor (float): Floor on energy in Spectrogram computation (absolute)
|
||||
dither (float): Dithering constant
|
||||
|
||||
Returns:
|
||||
LogMelSpectrogramKaldi
|
||||
"""
|
||||
|
||||
self.fs = fs
|
||||
self.n_mels = n_mels
|
||||
num_point_ms = fs / 1000
|
||||
self.n_frame_length = win_length / num_point_ms
|
||||
self.n_frame_shift = n_shift / num_point_ms
|
||||
self.energy_floor = energy_floor
|
||||
self.dither = dither
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(fs={fs}, n_mels={n_mels}, "
|
||||
"n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, "
|
||||
"dither={dither}))".format(
|
||||
name=self.__class__.__name__,
|
||||
fs=self.fs,
|
||||
n_mels=self.n_mels,
|
||||
n_frame_shift=self.n_frame_shift,
|
||||
n_frame_length=self.n_frame_length,
|
||||
dither=self.dither, ))
|
||||
|
||||
def __call__(self, x, train):
|
||||
"""
|
||||
Args:
|
||||
x (np.ndarray): shape (Ti,)
|
||||
train (bool): True, train mode.
|
||||
|
||||
Raises:
|
||||
ValueError: not support (Ti, C)
|
||||
|
||||
Returns:
|
||||
np.ndarray: (T, D)
|
||||
"""
|
||||
dither = self.dither if train else 0.0
|
||||
if x.ndim != 1:
|
||||
raise ValueError("Not support x: [Time, Channel]")
|
||||
waveform = paddle.to_tensor(np.expand_dims(x, 0), dtype=paddle.float32)
|
||||
mat = kaldi.fbank(
|
||||
waveform,
|
||||
n_mels=self.n_mels,
|
||||
frame_length=self.n_frame_length,
|
||||
frame_shift=self.n_frame_shift,
|
||||
dither=dither,
|
||||
energy_floor=self.energy_floor,
|
||||
sr=self.fs)
|
||||
mat = np.squeeze(mat.numpy())
|
||||
return mat
|
||||
|
||||
|
||||
class LogMelSpectrogramKaldi_decay():
|
||||
def __init__(
|
||||
self,
|
||||
fs=16000,
|
||||
n_mels=80,
|
||||
n_fft=512, # fft point
|
||||
n_shift=160, # unit:sample, 10ms
|
||||
win_length=400, # unit:sample, 25ms
|
||||
window="povey",
|
||||
fmin=20,
|
||||
fmax=None,
|
||||
eps=1e-10,
|
||||
dither=1.0):
|
||||
self.fs = fs
|
||||
self.n_mels = n_mels
|
||||
self.n_fft = n_fft
|
||||
if n_shift > win_length:
|
||||
raise ValueError("Stride size must not be greater than "
|
||||
"window size.")
|
||||
self.n_shift = n_shift / fs # unit: ms
|
||||
self.win_length = win_length / fs # unit: ms
|
||||
|
||||
self.window = window
|
||||
self.fmin = fmin
|
||||
if fmax is None:
|
||||
fmax_ = fmax if fmax else self.fs / 2
|
||||
elif fmax > int(self.fs / 2):
|
||||
raise ValueError("fmax must not be greater than half of "
|
||||
"sample rate.")
|
||||
self.fmax = fmax_
|
||||
|
||||
self.eps = eps
|
||||
self.remove_dc_offset = True
|
||||
self.preemph = 0.97
|
||||
self.dither = dither # only work in train mode
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
|
||||
"n_shift={n_shift}, win_length={win_length}, preemph={preemph}, window={window}, "
|
||||
"fmin={fmin}, fmax={fmax}, eps={eps}, dither={dither}))".format(
|
||||
name=self.__class__.__name__,
|
||||
fs=self.fs,
|
||||
n_mels=self.n_mels,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
preemph=self.preemph,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
fmin=self.fmin,
|
||||
fmax=self.fmax,
|
||||
eps=self.eps,
|
||||
dither=self.dither, ))
|
||||
|
||||
def __call__(self, x, train):
|
||||
"""
|
||||
|
||||
Args:
|
||||
x (np.ndarray): shape (Ti,)
|
||||
train (bool): True, train mode.
|
||||
|
||||
Raises:
|
||||
ValueError: not support (Ti, C)
|
||||
|
||||
Returns:
|
||||
np.ndarray: (T, D)
|
||||
"""
|
||||
dither = self.dither if train else 0.0
|
||||
if x.ndim != 1:
|
||||
raise ValueError("Not support x: [Time, Channel]")
|
||||
|
||||
if x.dtype in np.sctypes['float']:
|
||||
# PCM32 -> PCM16
|
||||
bits = np.iinfo(np.int16).bits
|
||||
x = x * 2**(bits - 1)
|
||||
|
||||
# logfbank need PCM16 input
|
||||
y = logfbank(
|
||||
signal=x,
|
||||
samplerate=self.fs,
|
||||
winlen=self.win_length, # unit ms
|
||||
winstep=self.n_shift, # unit ms
|
||||
nfilt=self.n_mels,
|
||||
nfft=self.n_fft,
|
||||
lowfreq=self.fmin,
|
||||
highfreq=self.fmax,
|
||||
dither=dither,
|
||||
remove_dc_offset=self.remove_dc_offset,
|
||||
preemph=self.preemph,
|
||||
wintype=self.window)
|
||||
return y
|
@ -1,35 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
|
||||
|
||||
class TransformInterface:
|
||||
"""Transform Interface"""
|
||||
|
||||
def __call__(self, x):
|
||||
raise NotImplementedError("__call__ method is not implemented")
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
return parser
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + "()"
|
||||
|
||||
|
||||
class Identity(TransformInterface):
|
||||
"""Identity Function"""
|
||||
|
||||
def __call__(self, x):
|
||||
return x
|
@ -1,158 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
"""Transformation module."""
|
||||
import copy
|
||||
import io
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Sequence
|
||||
from inspect import signature
|
||||
|
||||
import yaml
|
||||
|
||||
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
|
||||
|
||||
import_alias = dict(
|
||||
identity="paddlespeech.s2t.transform.transform_interface:Identity",
|
||||
time_warp="paddlespeech.s2t.transform.spec_augment:TimeWarp",
|
||||
time_mask="paddlespeech.s2t.transform.spec_augment:TimeMask",
|
||||
freq_mask="paddlespeech.s2t.transform.spec_augment:FreqMask",
|
||||
spec_augment="paddlespeech.s2t.transform.spec_augment:SpecAugment",
|
||||
speed_perturbation="paddlespeech.s2t.transform.perturb:SpeedPerturbation",
|
||||
speed_perturbation_sox="paddlespeech.s2t.transform.perturb:SpeedPerturbationSox",
|
||||
volume_perturbation="paddlespeech.s2t.transform.perturb:VolumePerturbation",
|
||||
noise_injection="paddlespeech.s2t.transform.perturb:NoiseInjection",
|
||||
bandpass_perturbation="paddlespeech.s2t.transform.perturb:BandpassPerturbation",
|
||||
rir_convolve="paddlespeech.s2t.transform.perturb:RIRConvolve",
|
||||
delta="paddlespeech.s2t.transform.add_deltas:AddDeltas",
|
||||
cmvn="paddlespeech.s2t.transform.cmvn:CMVN",
|
||||
utterance_cmvn="paddlespeech.s2t.transform.cmvn:UtteranceCMVN",
|
||||
fbank="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogram",
|
||||
spectrogram="paddlespeech.s2t.transform.spectrogram:Spectrogram",
|
||||
stft="paddlespeech.s2t.transform.spectrogram:Stft",
|
||||
istft="paddlespeech.s2t.transform.spectrogram:IStft",
|
||||
stft2fbank="paddlespeech.s2t.transform.spectrogram:Stft2LogMelSpectrogram",
|
||||
wpe="paddlespeech.s2t.transform.wpe:WPE",
|
||||
channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector",
|
||||
fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi",
|
||||
cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN")
|
||||
|
||||
|
||||
class Transformation():
|
||||
"""Apply some functions to the mini-batch
|
||||
|
||||
Examples:
|
||||
>>> kwargs = {"process": [{"type": "fbank",
|
||||
... "n_mels": 80,
|
||||
... "fs": 16000},
|
||||
... {"type": "cmvn",
|
||||
... "stats": "data/train/cmvn.ark",
|
||||
... "norm_vars": True},
|
||||
... {"type": "delta", "window": 2, "order": 2}]}
|
||||
>>> transform = Transformation(kwargs)
|
||||
>>> bs = 10
|
||||
>>> xs = [np.random.randn(100, 80).astype(np.float32)
|
||||
... for _ in range(bs)]
|
||||
>>> xs = transform(xs)
|
||||
"""
|
||||
|
||||
def __init__(self, conffile=None):
|
||||
if conffile is not None:
|
||||
if isinstance(conffile, dict):
|
||||
self.conf = copy.deepcopy(conffile)
|
||||
else:
|
||||
with io.open(conffile, encoding="utf-8") as f:
|
||||
self.conf = yaml.safe_load(f)
|
||||
assert isinstance(self.conf, dict), type(self.conf)
|
||||
else:
|
||||
self.conf = {"mode": "sequential", "process": []}
|
||||
|
||||
self.functions = OrderedDict()
|
||||
if self.conf.get("mode", "sequential") == "sequential":
|
||||
for idx, process in enumerate(self.conf["process"]):
|
||||
assert isinstance(process, dict), type(process)
|
||||
opts = dict(process)
|
||||
process_type = opts.pop("type")
|
||||
class_obj = dynamic_import(process_type, import_alias)
|
||||
# TODO(karita): assert issubclass(class_obj, TransformInterface)
|
||||
try:
|
||||
self.functions[idx] = class_obj(**opts)
|
||||
except TypeError:
|
||||
try:
|
||||
signa = signature(class_obj)
|
||||
except ValueError:
|
||||
# Some function, e.g. built-in function, are failed
|
||||
pass
|
||||
else:
|
||||
logging.error("Expected signature: {}({})".format(
|
||||
class_obj.__name__, signa))
|
||||
raise
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Not supporting mode={}".format(self.conf["mode"]))
|
||||
|
||||
def __repr__(self):
|
||||
rep = "\n" + "\n".join(" {}: {}".format(k, v)
|
||||
for k, v in self.functions.items())
|
||||
return "{}({})".format(self.__class__.__name__, rep)
|
||||
|
||||
def __call__(self, xs, uttid_list=None, **kwargs):
|
||||
"""Return new mini-batch
|
||||
|
||||
:param Union[Sequence[np.ndarray], np.ndarray] xs:
|
||||
:param Union[Sequence[str], str] uttid_list:
|
||||
:return: batch:
|
||||
:rtype: List[np.ndarray]
|
||||
"""
|
||||
if not isinstance(xs, Sequence):
|
||||
is_batch = False
|
||||
xs = [xs]
|
||||
else:
|
||||
is_batch = True
|
||||
|
||||
if isinstance(uttid_list, str):
|
||||
uttid_list = [uttid_list for _ in range(len(xs))]
|
||||
|
||||
if self.conf.get("mode", "sequential") == "sequential":
|
||||
for idx in range(len(self.conf["process"])):
|
||||
func = self.functions[idx]
|
||||
# TODO(karita): use TrainingTrans and UttTrans to check __call__ args
|
||||
# Derive only the args which the func has
|
||||
try:
|
||||
param = signature(func).parameters
|
||||
except ValueError:
|
||||
# Some function, e.g. built-in function, are failed
|
||||
param = {}
|
||||
_kwargs = {k: v for k, v in kwargs.items() if k in param}
|
||||
try:
|
||||
if uttid_list is not None and "uttid" in param:
|
||||
xs = [
|
||||
func(x, u, **_kwargs)
|
||||
for x, u in zip(xs, uttid_list)
|
||||
]
|
||||
else:
|
||||
xs = [func(x, **_kwargs) for x in xs]
|
||||
except Exception:
|
||||
logging.fatal("Catch a exception from {}th func: {}".format(
|
||||
idx, func))
|
||||
raise
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Not supporting mode={}".format(self.conf["mode"]))
|
||||
|
||||
if is_batch:
|
||||
return xs
|
||||
else:
|
||||
return xs[0]
|
@ -1,58 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
from nara_wpe.wpe import wpe
|
||||
|
||||
|
||||
class WPE(object):
|
||||
def __init__(self,
|
||||
taps=10,
|
||||
delay=3,
|
||||
iterations=3,
|
||||
psd_context=0,
|
||||
statistics_mode="full"):
|
||||
self.taps = taps
|
||||
self.delay = delay
|
||||
self.iterations = iterations
|
||||
self.psd_context = psd_context
|
||||
self.statistics_mode = statistics_mode
|
||||
|
||||
def __repr__(self):
|
||||
return ("{name}(taps={taps}, delay={delay}"
|
||||
"iterations={iterations}, psd_context={psd_context}, "
|
||||
"statistics_mode={statistics_mode})".format(
|
||||
name=self.__class__.__name__,
|
||||
taps=self.taps,
|
||||
delay=self.delay,
|
||||
iterations=self.iterations,
|
||||
psd_context=self.psd_context,
|
||||
statistics_mode=self.statistics_mode, ))
|
||||
|
||||
def __call__(self, xs):
|
||||
"""Return enhanced
|
||||
|
||||
:param np.ndarray xs: (Time, Channel, Frequency)
|
||||
:return: enhanced_xs
|
||||
:rtype: np.ndarray
|
||||
|
||||
"""
|
||||
# nara_wpe.wpe: (F, C, T)
|
||||
xs = wpe(
|
||||
xs.transpose((2, 1, 0)),
|
||||
taps=self.taps,
|
||||
delay=self.delay,
|
||||
iterations=self.iterations,
|
||||
psd_context=self.psd_context,
|
||||
statistics_mode=self.statistics_mode, )
|
||||
return xs.transpose(2, 1, 0)
|
Loading…
Reference in new issue