commit
41526ca1b8
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,262 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Any
|
||||||
|
from typing import List
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
|
||||||
|
from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
|
||||||
|
from deepspeech.models.lm_interface import LMInterface
|
||||||
|
from deepspeech.modules.encoder import TransformerEncoder
|
||||||
|
from deepspeech.modules.mask import subsequent_mask
|
||||||
|
from deepspeech.utils.log import Log
|
||||||
|
|
||||||
|
logger = Log(__name__).getlog()
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
    """Transformer language model that can also act as a beam-search scorer.

    Token ids are embedded, optionally passed through dropout, run through a
    ``TransformerEncoder`` and projected to vocabulary logits.

    NOTE(review): the methods below call ``Tensor.size(dim)``, ``Tensor.view``
    and ``Tensor.to`` on paddle tensors; presumably these are provided by a
    project-level patch of ``paddle.Tensor`` -- confirm.
    """

    def __init__(self,
                 n_vocab: int,
                 pos_enc: str=None,
                 embed_unit: int=128,
                 att_unit: int=256,
                 head: int=2,
                 unit: int=1024,
                 layer: int=4,
                 dropout_rate: float=0.5,
                 emb_dropout_rate: float=0.0,
                 att_dropout_rate: float=0.0,
                 tie_weights: bool=False,
                 **kwargs):
        """Build the embedding, encoder and output projection.

        Args:
            n_vocab (int): Vocabulary size (embedding rows / output logits).
            pos_enc (str): ``"sinusoidal"`` or ``None`` (no positional encoding).
            embed_unit (int): Embedding dimension.
            att_unit (int): Encoder output dimension.
            head (int): Number of attention heads.
            unit (int): Hidden units of the encoder feed-forward layers.
            layer (int): Number of encoder blocks.
            dropout_rate (float): Dropout rate passed to the encoder.
            emb_dropout_rate (float): Dropout rate applied to the embeddings;
                ``0.0`` disables the dropout layer entirely.
            att_dropout_rate (float): Attention dropout rate of the encoder.
            tie_weights (bool): Reuse the embedding table as the output
                projection matrix. Requires ``att_unit == embed_unit``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``pos_enc`` is neither ``"sinusoidal"`` nor ``None``.
            AssertionError: If ``tie_weights`` is set but the embedding and
                encoder output dimensions differ.
        """
        nn.Layer.__init__(self)

        if pos_enc == "sinusoidal":
            pos_enc_layer_type = "abs_pos"
        elif pos_enc is None:
            pos_enc_layer_type = "no_pos"
        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")

        self.embed = nn.Embedding(n_vocab, embed_unit)

        if emb_dropout_rate == 0.0:
            self.embed_drop = None
        else:
            self.embed_drop = nn.Dropout(emb_dropout_rate)

        self.encoder = TransformerEncoder(
            input_size=embed_unit,
            output_size=att_unit,
            attention_heads=head,
            linear_units=unit,
            num_blocks=layer,
            dropout_rate=dropout_rate,
            attention_dropout_rate=att_dropout_rate,
            input_layer="linear",
            pos_enc_layer_type=pos_enc_layer_type,
            concat_after=False,
            static_chunk_size=1,
            use_dynamic_chunk=False,
            use_dynamic_left_chunk=False)

        self.decoder = nn.Linear(att_unit, n_vocab)

        logger.info("Tie weights set to {}".format(tie_weights))
        logger.info("Dropout set to {}".format(dropout_rate))
        logger.info("Emb Dropout set to {}".format(emb_dropout_rate))
        logger.info("Att Dropout set to {}".format(att_dropout_rate))

        self.tie_weights = tie_weights
        if tie_weights:
            assert (
                att_unit == embed_unit
            ), "Tie Weights: True need embedding and final dimensions to match"
            # paddle's nn.Linear keeps its weight as [in_features, out_features]
            # while nn.Embedding keeps [n_vocab, embed_unit]. Assigning the
            # embedding table to ``decoder.weight`` would therefore give the
            # projection the wrong orientation, so the sharing is done in
            # ``_project_to_vocab`` with a transposed matmul instead.
            # ``decoder.weight`` stays allocated but unused in that mode.

    def _embed_tokens(self, ids: paddle.Tensor) -> paddle.Tensor:
        """Embed token ids and apply the embedding dropout when configured."""
        emb = self.embed(ids)
        if self.embed_drop is not None:
            emb = self.embed_drop(emb)
        return emb

    def _project_to_vocab(self, h: paddle.Tensor) -> paddle.Tensor:
        """Map encoder outputs to vocabulary logits, honouring weight tying."""
        if self.tie_weights:
            logits = paddle.matmul(h, self.embed.weight, transpose_y=True)
            return logits + self.decoder.bias
        return self.decoder(h)

    def _target_mask(self, ys_in_pad):
        """Combine the non-zero-id mask of ``ys_in_pad`` with ``subsequent_mask``.

        Positions whose id is 0 are masked out; the result is the element-wise
        AND of that mask (with an inserted second-to-last axis) and the
        ``subsequent_mask`` built for the last-axis length.
        """
        ys_mask = ys_in_pad != 0
        m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(self, x: paddle.Tensor, t: paddle.Tensor
                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Compute LM loss value from buffer sequences.

        Args:
            x (paddle.Tensor): Input ids. (batch, len)
            t (paddle.Tensor): Target ids. (batch, len)

        Returns:
            tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)

        """
        # Positions holding id 0 are excluded from both the length and the loss.
        xm = x != 0
        xlen = xm.sum(axis=1)
        emb = self._embed_tokens(x)
        h, _ = self.encoder(emb, xlen)
        y = self._project_to_vocab(h)
        loss = F.cross_entropy(
            y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        mask = xm.to(dtype=loss.dtype)
        logp = loss * mask.view(-1)
        logp = logp.sum()
        count = mask.sum()
        return logp / count, logp, count

    # beam search API (see ScorerInterface)
    def score(self, y: paddle.Tensor, state: Any,
              x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]:
        """Score new token.

        Args:
            y (paddle.Tensor): 1D paddle.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (paddle.Tensor): encoder feature that generates ys
                (not used by this scorer).

        Returns:
            tuple[paddle.Tensor, Any]: Tuple of
                paddle.float32 scores for next token (n_vocab)
                and next state for ys

        """
        y = y.unsqueeze(0)
        emb = self._embed_tokens(y)

        h, _, cache = self.encoder.forward_one_step(
            emb, self._target_mask(y), cache=state)
        # Only the last position is needed to score the next token.
        h = self._project_to_vocab(h[:, -1])
        logp = F.log_softmax(h).squeeze(0)
        return logp, cache

    # batch beam search API (see BatchScorerInterface)
    def batch_score(self,
                    ys: paddle.Tensor,
                    states: List[Any],
                    xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (paddle.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat)
                (not used by this scorer).

        Returns:
            tuple[paddle.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                paddle.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        emb = self._embed_tokens(ys)

        # batch decoding
        h, _, states = self.encoder.forward_one_step(
            emb, self._target_mask(ys), cache=batch_state)
        h = self._project_to_vocab(h[:, -1])
        logp = F.log_softmax(h)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)]
                      for b in range(n_batch)]
        return logp, state_list
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test: load a checkpoint from the working directory and
    # score a growing prefix one call at a time, feeding the returned state
    # back into the next call.
    tlm = TransformerLM(
        n_vocab=5002,
        pos_enc=None,
        embed_unit=128,
        att_unit=512,
        head=8,
        unit=2048,
        layer=16,
        dropout_rate=0.5, )

    paddle.set_device("cpu")
    tlm.set_state_dict(paddle.load("transformerLM.pdparams"))

    tlm.eval()
    # Test the score
    state = None
    for prefix in ([5], [5, 10], [5, 10, 0]):
        tokens = paddle.to_tensor(np.array(prefix))
        output, state = tlm.score(tokens, state, None)

    print("output", output)
    """
    #Test the batch score
    batch_size = 2
    inp2 = np.array([[5], [10]])
    inp2 = paddle.to_tensor(inp2)
    output, states = tlm.batch_score(
        inp2, [(None,None,0)] * batch_size)
    inp3 = np.array([[100], [30]])
    inp3 = paddle.to_tensor(inp3)
    output, states = tlm.batch_score(
        inp3, states)
    print("output", output)
    #print("cache", cache)
    #np.save("output_pd.npy", output)
    """
|
@ -0,0 +1,82 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Language model interface."""
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from deepspeech.decoders.scorers.scorer_interface import ScorerInterface
|
||||||
|
from deepspeech.utils.dynamic_import import dynamic_import
|
||||||
|
|
||||||
|
|
||||||
|
class LMInterface(ScorerInterface):
    """Base interface that language-model implementations extend."""

    @staticmethod
    def add_arguments(parser):
        """Hook for registering LM options on a command line parser.

        The base implementation registers nothing and returns ``parser``
        unchanged.
        """
        return parser

    @classmethod
    def build(cls, n_vocab: int, **kwargs):
        """Create an instance of ``cls`` from python-level arguments.

        Args:
            n_vocab (int): The number of vocabulary.
            **kwargs: Collected into an ``argparse.Namespace`` that is handed
                to the constructor as its second positional argument.

        Returns:
            LMInterface: A new instance of ``cls``.

        """
        return cls(n_vocab, argparse.Namespace(**kwargs))

    def forward(self, x, t):
        """Compute LM loss value from buffer sequences.

        Subclasses must override this; the base implementation always raises.

        Args:
            x (Tensor): Input ids. (batch, len)
            t (Tensor): Target ids. (batch, len)

        Returns:
            tuple[Tensor, Tensor, Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Raises:
            NotImplementedError: Always, in this base class.

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)

        """
        raise NotImplementedError("forward method is not implemented")
|
||||||
|
|
||||||
|
|
||||||
|
# Aliases accepted by `dynamic_import_lm`, each mapped to a
# "module_name:class_name" import path.
predefined_lms = dict(
    transformer="deepspeech.models.lm.transformer:TransformerLM")
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_import_lm(module):
    """Resolve and return an LM class by import path or alias.

    Args:
        module (str): module_name:class_name or alias in `predefined_lms`

    Returns:
        type: LM class

    Raises:
        AssertionError: If the resolved class is not a subclass of
            ``LMInterface``.

    """
    lm_class = dynamic_import(module, predefined_lms)
    # Reject classes that do not follow the LM interface.
    assert issubclass(
        lm_class, LMInterface), f"{module} does not implement LMInterface"
    return lm_class
|
@ -0,0 +1,75 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""ST Interface module."""
|
||||||
|
from .asr_interface import ASRInterface
|
||||||
|
from deepspeech.utils.dynamic_import import dynamic_import
|
||||||
|
|
||||||
|
|
||||||
|
class STInterface(ASRInterface):
    """ST Interface model implementation.

    NOTE: This class is inherited from ASRInterface to enable joint translation
    and recognition when performing multi-task learning with the ASR task.

    """

    def translate(self,
                  x,
                  trans_args,
                  char_list=None,
                  rnnlm=None,
                  ensemble_models=None):
        """Translate x for evaluation.

        Subclasses must override this; the base implementation always raises.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :param list ensemble_models: presumably additional models to ensemble
            with (unused here; verify against implementations). Defaults to
            None rather than a shared mutable list.
        :return: N-best decoding results
        :rtype: list
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("translate method is not implemented")

    def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
        """Beam search implementation for batch.

        Subclasses must override this; the base implementation always raises.

        :param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("Batch decoding is not supported yet.")
|
||||||
|
|
||||||
|
|
||||||
|
# Aliases accepted by `dynamic_import_st`, each mapped to a
# "module_name:class_name" import path.
predefined_st = dict(transformer="deepspeech.models.u2_st:U2STModel")
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_import_st(module):
    """Resolve and return an ST model class by import path or alias.

    Args:
        module (str): module_name:class_name or alias in `predefined_st`

    Returns:
        type: ST class

    Raises:
        AssertionError: If the resolved class is not a subclass of
            ``STInterface``.

    """
    st_class = dynamic_import(module, predefined_st)
    # Reject classes that do not follow the ST interface.
    assert issubclass(
        st_class, STInterface), f"{module} does not implement STInterface"
    return st_class
|
@ -0,0 +1,15 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .u2_st import U2STInferModel
|
||||||
|
from .u2_st import U2STModel
|
@ -0,0 +1,418 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import copy
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from . import extension
|
||||||
|
|
||||||
|
|
||||||
|
class PlotAttentionReport(extension.Extension):
    """Plot attention reporter.

    Args:
        att_vis_fn (callable): Function of attention visualization. It is
            called with the converted batch (unpacked positionally when the
            batch is a tuple, by keyword otherwise).
        data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
        outdir (str): Directory to save figures (created when missing).
        converter (callable): Function to convert data; called as
            ``converter([batch], device)``.
        transform (callable): Called as ``transform(data, return_uttid=True)``;
            must return ``(batch, uttid_list)``.
        device: Device handed to ``converter``.
        reverse (bool): If True, input and output length are reversed.
        ikey (str): Key to access input (default ``"input"``).
        iaxis (int): Index into the ``ikey`` entry list (default 0).
        okey (str): Key to access output (default ``"output"``).
        oaxis (int): Index into the ``okey`` entry list (default 0).
        subsampling_factor (int): subsampling factor in encoder

    The lengths used for trimming are read from
    ``data[uttid][key][axis]["shape"][0]``.

    """

    def __init__(
            self,
            att_vis_fn,
            data,
            outdir,
            converter,
            transform,
            device,
            reverse=False,
            ikey="input",
            iaxis=0,
            okey="output",
            oaxis=0,
            subsampling_factor=1, ):
        self.att_vis_fn = att_vis_fn
        self.data = copy.deepcopy(data)
        # key is utterance ID; a second deep copy keeps the lookup table
        # independent from ``self.data``.
        self.data_dict = dict(copy.deepcopy(data))
        self.outdir = outdir
        self.converter = converter
        self.transform = transform
        self.device = device
        self.reverse = reverse
        self.ikey = ikey
        self.iaxis = iaxis
        self.okey = okey
        self.oaxis = oaxis
        self.factor = subsampling_factor
        # exist_ok avoids the check-then-create race when several processes
        # start with the same output directory.
        os.makedirs(self.outdir, exist_ok=True)

    def __call__(self, trainer):
        """Plot and save image file of att_ws matrix."""
        att_ws, uttid_list = self.get_attention_weights()
        if isinstance(att_ws, list):  # multi-encoder case
            # The last list entry holds the hierarchical attention (han).
            num_encs = len(att_ws) - 1
            # atts
            for i in range(num_encs):
                for idx, att_w in enumerate(att_ws[i]):
                    self._save_attention(
                        trainer, uttid_list[idx], att_w,
                        tag=".att%d" % (i + 1))
            # han
            for idx, att_w in enumerate(att_ws[num_encs]):
                self._save_attention(
                    trainer, uttid_list[idx], att_w, tag=".han",
                    han_mode=True)
        else:
            for idx, att_w in enumerate(att_ws):
                self._save_attention(trainer, uttid_list[idx], att_w)

    def _save_attention(self, trainer, uttid, att_w, tag="", han_mode=False):
        """Trim one utterance's weights and write its ``.npy`` and ``.png``."""
        # The path is built in a single formatting pass so that braces inside
        # ``outdir`` or the utterance id are never re-parsed as format fields.
        stem = f"{self.outdir}/{uttid}.ep.{trainer.updater.epoch}{tag}"
        att_w = self.trim_attention_weight(uttid, att_w)
        np.save(stem + ".npy", att_w)
        self._plot_and_save_attention(att_w, stem + ".png", han_mode=han_mode)

    def log_attentions(self, logger, step):
        """Add image files of att_ws matrix to the tensorboard."""
        att_ws, uttid_list = self.get_attention_weights()
        if isinstance(att_ws, list):  # multi-encoder case
            num_encs = len(att_ws) - 1
            # atts
            for i in range(num_encs):
                for idx, att_w in enumerate(att_ws[i]):
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    plot = self.draw_attention_plot(att_w)
                    logger.add_figure(
                        "%s_att%d" % (uttid_list[idx], i + 1),
                        plot.gcf(),
                        step, )
            # han
            for idx, att_w in enumerate(att_ws[num_encs]):
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                plot = self.draw_han_plot(att_w)
                logger.add_figure(
                    "%s_han" % (uttid_list[idx]),
                    plot.gcf(),
                    step, )
        else:
            for idx, att_w in enumerate(att_ws):
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                plot = self.draw_attention_plot(att_w)
                logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

    def get_attention_weights(self):
        """Return attention weights and the matching utterance ids.

        Returns:
            tuple: ``(att_ws, uttid_list)``. ``att_ws`` is whatever
            ``att_vis_fn`` returns; the rest of this class handles a list
            (multi-encoder case) or a batch of per-utterance arrays that are
            either 3-D (presumably one matrix per head) or 2-D.

        """
        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
        batch = self.converter([return_batch], self.device)
        if isinstance(batch, tuple):
            att_ws = self.att_vis_fn(*batch)
        else:
            att_ws = self.att_vis_fn(**batch)
        return att_ws, uttid_list

    def trim_attention_weight(self, uttid, att_w):
        """Cut ``att_w`` down to the real decoder/encoder lengths of ``uttid``.

        With ``self.reverse`` the input/output roles are swapped. The encoder
        length is divided by the subsampling factor when that is above 1.
        """
        if self.reverse:
            enc_key, enc_axis = self.okey, self.oaxis
            dec_key, dec_axis = self.ikey, self.iaxis
        else:
            enc_key, enc_axis = self.ikey, self.iaxis
            dec_key, dec_axis = self.okey, self.oaxis
        dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
        enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
        if self.factor > 1:
            enc_len //= self.factor
        if len(att_w.shape) == 3:
            att_w = att_w[:, :dec_len, :enc_len]
        else:
            att_w = att_w[:dec_len, :enc_len]
        return att_w

    def draw_attention_plot(self, att_w):
        """Plot the att_w matrix.

        A 3-D input is drawn as one subplot per leading-axis entry; anything
        else is drawn as a single image.

        Returns:
            matplotlib.pyplot: pyplot object with attention matrix image.

        """
        import matplotlib

        # Select the Agg backend before pyplot is imported.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        plt.clf()
        att_w = att_w.astype(np.float32)
        if len(att_w.shape) == 3:
            for h, aw in enumerate(att_w, 1):
                plt.subplot(1, len(att_w), h)
                plt.imshow(aw, aspect="auto")
                plt.xlabel("Encoder Index")
                plt.ylabel("Decoder Index")
        else:
            plt.imshow(att_w, aspect="auto")
            plt.xlabel("Encoder Index")
            plt.ylabel("Decoder Index")
        plt.tight_layout()
        return plt

    def draw_han_plot(self, att_w):
        """Plot the att_w matrix for hierarchical attention.

        Each column of the (per-subplot) matrix is drawn as one line labelled
        ``Att<i>``.

        Returns:
            matplotlib.pyplot: pyplot object with attention matrix image.

        """
        import matplotlib

        # Select the Agg backend before pyplot is imported.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        plt.clf()
        if len(att_w.shape) == 3:
            for h, aw in enumerate(att_w, 1):
                legends = []
                plt.subplot(1, len(att_w), h)
                for i in range(aw.shape[1]):
                    plt.plot(aw[:, i])
                    legends.append("Att{}".format(i))
                plt.ylim([0, 1.0])
                plt.xlim([0, aw.shape[0]])
                plt.grid(True)
                plt.ylabel("Attention Weight")
                plt.xlabel("Decoder Index")
                plt.legend(legends)
        else:
            legends = []
            for i in range(att_w.shape[1]):
                plt.plot(att_w[:, i])
                legends.append("Att{}".format(i))
            plt.ylim([0, 1.0])
            plt.xlim([0, att_w.shape[0]])
            plt.grid(True)
            plt.ylabel("Attention Weight")
            plt.xlabel("Decoder Index")
            plt.legend(legends)
        plt.tight_layout()
        return plt

    def _plot_and_save_attention(self, att_w, filename, han_mode=False):
        """Draw ``att_w`` (han or regular layout), save it, close the figure."""
        if han_mode:
            plt = self.draw_han_plot(att_w)
        else:
            plt = self.draw_attention_plot(att_w)
        plt.savefig(filename)
        plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
class PlotCTCReport(extension.Extension):
|
||||||
|
"""Plot CTC reporter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
|
||||||
|
Function of CTC visualization.
|
||||||
|
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
|
||||||
|
outdir (str): Directory to save figures.
|
||||||
|
converter (espnet.asr.*_backend.asr.CustomConverter):
|
||||||
|
Function to convert data.
|
||||||
|
device (int | torch.device): Device.
|
||||||
|
reverse (bool): If True, input and output length are reversed.
|
||||||
|
ikey (str): Key to access input
|
||||||
|
(for ASR/ST ikey="input", for MT ikey="output".)
|
||||||
|
iaxis (int): Dimension to access input
|
||||||
|
(for ASR/ST iaxis=0, for MT iaxis=1.)
|
||||||
|
okey (str): Key to access output
|
||||||
|
(for ASR/ST okey="input", MT okay="output".)
|
||||||
|
oaxis (int): Dimension to access output
|
||||||
|
(for ASR/ST oaxis=0, for MT oaxis=0.)
|
||||||
|
subsampling_factor (int): subsampling factor in encoder
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ctc_vis_fn,
|
||||||
|
data,
|
||||||
|
outdir,
|
||||||
|
converter,
|
||||||
|
transform,
|
||||||
|
device,
|
||||||
|
reverse=False,
|
||||||
|
ikey="input",
|
||||||
|
iaxis=0,
|
||||||
|
okey="output",
|
||||||
|
oaxis=0,
|
||||||
|
subsampling_factor=1, ):
|
||||||
|
self.ctc_vis_fn = ctc_vis_fn
|
||||||
|
self.data = copy.deepcopy(data)
|
||||||
|
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
|
||||||
|
# key is utterance ID
|
||||||
|
self.outdir = outdir
|
||||||
|
self.converter = converter
|
||||||
|
self.transform = transform
|
||||||
|
self.device = device
|
||||||
|
self.reverse = reverse
|
||||||
|
self.ikey = ikey
|
||||||
|
self.iaxis = iaxis
|
||||||
|
self.okey = okey
|
||||||
|
self.oaxis = oaxis
|
||||||
|
self.factor = subsampling_factor
|
||||||
|
if not os.path.exists(self.outdir):
|
||||||
|
os.makedirs(self.outdir)
|
||||||
|
|
||||||
|
def __call__(self, trainer):
|
||||||
|
"""Plot and save image file of ctc prob."""
|
||||||
|
ctc_probs, uttid_list = self.get_ctc_probs()
|
||||||
|
if isinstance(ctc_probs, list): # multi-encoder case
|
||||||
|
num_encs = len(ctc_probs) - 1
|
||||||
|
for i in range(num_encs):
|
||||||
|
for idx, ctc_prob in enumerate(ctc_probs[i]):
|
||||||
|
filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
|
||||||
|
self.outdir, uttid_list[idx], i + 1, )
|
||||||
|
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
||||||
|
np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
|
||||||
|
self.outdir, uttid_list[idx], i + 1, )
|
||||||
|
np.save(np_filename.format(trainer), ctc_prob)
|
||||||
|
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
|
||||||
|
else:
|
||||||
|
for idx, ctc_prob in enumerate(ctc_probs):
|
||||||
|
filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
|
||||||
|
uttid_list[idx], )
|
||||||
|
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
||||||
|
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
|
||||||
|
self.outdir, uttid_list[idx], )
|
||||||
|
np.save(np_filename.format(trainer), ctc_prob)
|
||||||
|
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
|
||||||
|
|
||||||
|
def log_ctc_probs(self, logger, step):
|
||||||
|
"""Add image files of ctc probs to the tensorboard."""
|
||||||
|
ctc_probs, uttid_list = self.get_ctc_probs()
|
||||||
|
if isinstance(ctc_probs, list): # multi-encoder case
|
||||||
|
num_encs = len(ctc_probs) - 1
|
||||||
|
for i in range(num_encs):
|
||||||
|
for idx, ctc_prob in enumerate(ctc_probs[i]):
|
||||||
|
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
||||||
|
plot = self.draw_ctc_plot(ctc_prob)
|
||||||
|
logger.add_figure(
|
||||||
|
"%s_ctc%d" % (uttid_list[idx], i + 1),
|
||||||
|
plot.gcf(),
|
||||||
|
step, )
|
||||||
|
else:
|
||||||
|
for idx, ctc_prob in enumerate(ctc_probs):
|
||||||
|
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
||||||
|
plot = self.draw_ctc_plot(ctc_prob)
|
||||||
|
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
|
||||||
|
|
||||||
|
def get_ctc_probs(self):
|
||||||
|
"""Return CTC probs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
numpy.ndarray: CTC probs. float. Its shape would be
|
||||||
|
differ from backend. (B, Tmax, vocab).
|
||||||
|
|
||||||
|
"""
|
||||||
|
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
|
||||||
|
batch = self.converter([return_batch], self.device)
|
||||||
|
if isinstance(batch, tuple):
|
||||||
|
probs = self.ctc_vis_fn(*batch)
|
||||||
|
else:
|
||||||
|
probs = self.ctc_vis_fn(**batch)
|
||||||
|
return probs, uttid_list
|
||||||
|
|
||||||
|
def trim_ctc_prob(self, uttid, prob):
    """Trim CTC posteriors according to input lengths.

    The length is read from ``self.data_dict[uttid][self.ikey][self.iaxis]
    ["shape"][0]`` and floor-divided by ``self.factor`` when that factor
    is greater than one; ``prob`` is cut to that many leading entries.
    """
    shape = self.data_dict[uttid][self.ikey][self.iaxis]["shape"]
    n_valid = int(shape[0])
    if self.factor > 1:
        n_valid = n_valid // self.factor
    return prob[:n_valid]
|
||||||
|
|
||||||
|
def draw_ctc_plot(self, ctc_prob):
    """Plot the ctc_prob matrix.

    Args:
        ctc_prob: 2-D array; it is unpacked as ``(n_frames, vocab)``.

    Returns:
        matplotlib.pyplot: pyplot object with CTC prob matrix image.
    """
    import matplotlib

    # The backend is selected before pyplot is imported.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    posteriors = ctc_prob.astype(np.float32)

    plt.clf()
    ranked_ids = np.argsort(posteriors, axis=1)
    n_frames, _vocab = posteriors.shape
    frame_axis = np.arange(n_frames)

    plt.figure(figsize=(20, 8))

    # NOTE: index 0 is reserved for blank
    for token_id in set(ranked_ids.reshape(-1).tolist()):
        if token_id != 0:
            plt.plot(frame_axis, posteriors[:, token_id])
            continue
        # The blank curve is drawn dotted and grey to set it apart.
        plt.plot(
            frame_axis,
            posteriors[:, 0],
            ":",
            label="<blank>",
            color="grey")

    plt.xlabel(u"Input [frame]", fontsize=12)
    plt.ylabel("Posteriors", fontsize=12)
    plt.xticks(list(range(0, int(n_frames) + 1, 10)))
    plt.yticks(list(range(0, 2, 1)))
    plt.tight_layout()
    return plt
|
||||||
|
|
||||||
|
def _plot_and_save_ctc(self, ctc_prob, filename):
    """Draw ``ctc_prob`` and write the resulting figure to ``filename``."""
    figure_api = self.draw_ctc_plot(ctc_prob)
    figure_api.savefig(filename)
    figure_api.close()
|
@ -0,0 +1,61 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from ..reporter import DictSummary
|
||||||
|
from .utils import get_trigger
|
||||||
|
|
||||||
|
|
||||||
|
class CompareValueTrigger:
    """Trigger invoked when key value getting bigger or lower than before.

    Args:
        key (str) : Key of value.
        compare_fn ((float, float) -> bool) : Function to compare the values.
        trigger (tuple(int, str)) : Trigger that decide the comparison interval.

    """

    def __init__(self, key, compare_fn, trigger=(1, "epoch")):
        self._key = key
        self._best_value = None
        self._interval_trigger = get_trigger(trigger)
        self._init_summary()
        self._compare_fn = compare_fn

    def __call__(self, trainer):
        """Get value related to the key and compare with current value."""
        observed = trainer.observation
        name = self._key
        if name in observed:
            self._summary.add({name: observed[name]})

        # Outside the comparison interval only accumulation happens.
        if not self._interval_trigger(trainer):
            return False

        current = float(self._summary.compute_mean()[name])  # copy to CPU
        self._init_summary()

        previous = self._best_value
        if previous is not None and self._compare_fn(previous, current):
            return True

        # First evaluation, or the comparison did not hold: remember the
        # latest value as the new reference.
        self._best_value = current
        return False

    def _init_summary(self):
        """Start a fresh accumulator for the next interval."""
        self._summary = DictSummary()
|
@ -0,0 +1,28 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .interval_trigger import IntervalTrigger
|
||||||
|
|
||||||
|
|
||||||
|
def never_fail_trigger(trainer):
    """Trigger that never fires: returns ``False`` for any ``trainer``."""
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_trigger(trigger):
    """Normalise a trigger specification into a callable.

    ``None`` maps to ``never_fail_trigger``, a callable is returned as is,
    and anything else is unpacked into ``IntervalTrigger(*trigger)``.
    """
    if trigger is None:
        return never_fail_trigger
    return trigger if callable(trigger) else IntervalTrigger(*trigger)
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,53 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def delta(feat, window):
    """Compute regression-based delta features along the first axis.

    Frames beyond either end are treated as copies of the first / last
    frame (edge replication), and the result is normalised by
    ``2 * sum(i**2 for i in 1..window)``.

    Args:
        feat (np.ndarray): input features; the first axis is differentiated.
        window (int): number of frames on each side; must be positive.

    Returns:
        np.ndarray: array with the same shape as ``feat``. Floating (or
        complex) inputs keep their dtype; other dtypes (e.g. integers)
        yield ``float64``, because the final normalisation is a true
        division that cannot be stored in an integer array.
    """
    assert window > 0
    feat = np.asarray(feat)
    # Integer accumulators would make the in-place division below fail.
    out_dtype = feat.dtype if np.issubdtype(feat.dtype,
                                            np.inexact) else np.float64
    delta_feat = np.zeros_like(feat, dtype=out_dtype)
    for i in range(1, window + 1):
        # Interior contributions: +i * x[t + i] and -i * x[t - i].
        delta_feat[:-i] += i * feat[i:]
        delta_feat[i:] += -i * feat[:-i]
        # Edge contributions: replicate the last / first frame.
        delta_feat[-i:] += i * feat[-1]
        delta_feat[:i] += -i * feat[0]
    delta_feat /= 2 * sum(i**2 for i in range(1, window + 1))
    return delta_feat
|
||||||
|
|
||||||
|
|
||||||
|
def add_deltas(x, window=2, order=2):
    """Append successive delta features to ``x``.

    Args:
        x (np.ndarray): speech feat, (T, D).
        window (int): window passed to ``delta`` at every order.
        order (int): how many times ``delta`` is applied; each
            application works on the previous result.

    Return:
        np.ndarray: (T, (1+order)*D)
    """
    stacked = [x]
    current = x
    for _ in range(order):
        current = delta(current, window)
        stacked.append(current)
    return np.concatenate(stacked, axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
class AddDeltas():
    """Callable wrapper around ``add_deltas`` with fixed settings.

    Args:
        window (int): window forwarded to ``add_deltas``.
        order (int): order forwarded to ``add_deltas``.
    """

    def __init__(self, window=2, order=2):
        self.window = window
        self.order = order

    def __repr__(self):
        # The closing parenthesis was previously missing from the template.
        return "{name}(window={window}, order={order})".format(
            name=self.__class__.__name__, window=self.window, order=self.order)

    def __call__(self, x):
        """Return ``x`` with its delta features appended."""
        return add_deltas(x, window=self.window, order=self.order)
|
@ -0,0 +1,56 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelSelector:
    """Select 1ch from multi-channel signal"""

    def __init__(self, train_channel="random", eval_channel=0, axis=1):
        self.train_channel = train_channel
        self.eval_channel = eval_channel
        self.axis = axis

    def __repr__(self):
        template = ("{name}(train_channel={train_channel}, "
                    "eval_channel={eval_channel}, axis={axis})")
        return template.format(
            name=self.__class__.__name__,
            train_channel=self.train_channel,
            eval_channel=self.eval_channel,
            axis=self.axis, )

    def __call__(self, x, train=True):
        """Pick one index along ``self.axis`` of ``x``.

        ``train`` chooses between ``train_channel`` and ``eval_channel``;
        the value ``"random"`` draws the index uniformly with
        ``numpy.random.randint``.
        """
        # Assuming x: [Time, Channel] by default
        axis = self.axis

        if x.ndim <= axis:
            # If the dimension is insufficient, then unsqueeze
            # (e.g [Time] -> [Time, 1])
            missing = axis + 1 - x.ndim
            x = x[(slice(None), ) * x.ndim + (None, ) * missing]

        channel = self.train_channel if train else self.eval_channel

        if channel == "random":
            ch = numpy.random.randint(0, x.shape[axis])
        else:
            ch = channel

        selector = tuple(
            ch if dim == axis else slice(None) for dim in range(x.ndim))
        return x[selector]
|
@ -0,0 +1,158 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import io
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import kaldiio
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class CMVN():
    """Apply Global/Spk CMVN/inverse CMVN.

    Args:
        stats: either a dict mapping a speaker key (``None`` for global
            statistics) to a stats matrix, or a path loaded according to
            ``filetype``.
        norm_means (bool): add the negated mean.
        norm_vars (bool): multiply by the inverse standard deviation.
        filetype (str): ``"mat"`` / ``"npy"`` (global stats) or
            ``"ark"`` / ``"hdf5"`` (per-key stats).
        utt2spk (str): optional path of an ``utt spk`` mapping file.
        spk2utt (str): optional path of a ``spk utt1 utt2 ...`` file; only
            read when ``utt2spk`` is not given.
        reverse (bool): undo the normalisation instead of applying it.
        std_floor (float): lower bound for the standard deviation.

    Raises:
        ValueError: if ``filetype`` is not one of the supported values.
    """

    def __init__(
            self,
            stats,
            norm_means=True,
            norm_vars=False,
            filetype="mat",
            utt2spk=None,
            spk2utt=None,
            reverse=False,
            std_floor=1.0e-20, ):
        self.stats_file = stats
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.reverse = reverse

        if isinstance(stats, dict):
            stats_dict = dict(stats)
        else:
            # Use for global CMVN
            if filetype == "mat":
                stats_dict = {None: kaldiio.load_mat(stats)}
            # Use for global CMVN
            elif filetype == "npy":
                stats_dict = {None: np.load(stats)}
            # Use for speaker CMVN
            elif filetype == "ark":
                self.accept_uttid = True
                stats_dict = dict(kaldiio.load_ark(stats))
            # Use for speaker CMVN
            elif filetype == "hdf5":
                self.accept_uttid = True
                # Statistics are only read, so open the file read-only.
                stats_dict = h5py.File(stats, "r")
            else:
                raise ValueError("Not supporting filetype={}".format(filetype))

        if utt2spk is not None:
            self.utt2spk = {}
            with io.open(utt2spk, "r", encoding="utf-8") as f:
                for line in f:
                    utt, spk = line.rstrip().split(None, 1)
                    self.utt2spk[utt] = spk
        elif spk2utt is not None:
            self.utt2spk = {}
            with io.open(spk2utt, "r", encoding="utf-8") as f:
                for line in f:
                    spk, utts = line.rstrip().split(None, 1)
                    for utt in utts.split():
                        self.utt2spk[utt] = spk
        else:
            self.utt2spk = None

        # Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1),
        # and the first vector contains the sum of feats and the second is
        # the sum of squares. The last value of the first, i.e. stats[0,-1],
        # is the number of samples for this statistics.
        self.bias = {}
        self.scale = {}
        for spk, stats in stats_dict.items():
            assert len(stats) == 2, stats.shape

            count = stats[0, -1]

            # If the feature has two or more dimensions
            if not (np.isscalar(count) or isinstance(count, (int, float))):
                # The first is only used
                count = count.flatten()[0]

            mean = stats[0, :-1] / count
            # V(x) = E(x^2) - (E(x))^2
            var = stats[1, :-1] / count - mean * mean
            # Rounding can push the variance slightly below zero; sqrt would
            # then give NaN, which np.maximum propagates past std_floor.
            var = np.maximum(var, 0.0)
            std = np.maximum(np.sqrt(var), std_floor)
            self.bias[spk] = -mean
            self.scale[spk] = 1 / std

    def __repr__(self):
        return ("{name}(stats_file={stats_file}, "
                "norm_means={norm_means}, norm_vars={norm_vars}, "
                "reverse={reverse})".format(
                    name=self.__class__.__name__,
                    stats_file=self.stats_file,
                    norm_means=self.norm_means,
                    norm_vars=self.norm_vars,
                    reverse=self.reverse, ))

    def __call__(self, x, uttid=None):
        """Normalise (or de-normalise when ``reverse``) the features ``x``.

        The statistics key is ``self.utt2spk[uttid]`` when a mapping was
        loaded, otherwise ``uttid`` itself (``None`` for global stats).
        """
        if self.utt2spk is not None:
            spk = self.utt2spk[uttid]
        else:
            spk = uttid

        if not self.reverse:
            # apply cmvn
            if self.norm_means:
                x = np.add(x, self.bias[spk])
            if self.norm_vars:
                x = np.multiply(x, self.scale[spk])

        else:
            # apply reverse cmvn: undo the steps in the opposite order
            if self.norm_vars:
                x = np.divide(x, self.scale[spk])
            if self.norm_means:
                x = np.subtract(x, self.bias[spk])

        return x
|
||||||
|
|
||||||
|
|
||||||
|
class UtteranceCMVN():
    """Apply Utterance CMVN.

    Mean and variance are computed from the utterance itself, along the
    first axis of the input.

    Args:
        norm_means (bool): subtract the per-dimension mean.
        norm_vars (bool): divide by the per-dimension standard deviation.
        std_floor (float): lower bound for the standard deviation.
    """

    def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20):
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.std_floor = std_floor

    def __repr__(self):
        return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format(
            name=self.__class__.__name__,
            norm_means=self.norm_means,
            norm_vars=self.norm_vars, )

    def __call__(self, x, uttid=None):
        """Return the normalised copy of ``x``; ``uttid`` is unused."""
        # x: [Time, Dim]
        # Statistics are taken before x is modified below.
        square_sums = (x**2).sum(axis=0)
        mean = x.mean(axis=0)

        if self.norm_means:
            x = np.subtract(x, mean)

        if self.norm_vars:
            var = square_sums / x.shape[0] - mean**2
            # Rounding can make the variance of (near-)constant features
            # slightly negative; clamp so sqrt cannot produce NaN, which
            # np.maximum would propagate past std_floor.
            var = np.maximum(var, 0.0)
            std = np.maximum(np.sqrt(var), self.std_floor)
            x = np.divide(x, std)

        return x
|
@ -0,0 +1,85 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from deepspeech.transform.transform_interface import TransformInterface
|
||||||
|
from deepspeech.utils.check_kwargs import check_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class FuncTrans(TransformInterface):
    """Functional Transformation

    WARNING:
        Builtin or C/C++ functions may not work properly
        because this class heavily depends on the `inspect` module.

    Usage:

        >>> def foo_bar(x, a=1, b=2):
        ...     '''Foo bar
        ...     :param x: input
        ...     :param int a: default 1
        ...     :param int b: default 2
        ...     '''
        ...     return x + a - b


        >>> class FooBar(FuncTrans):
        ...     _func = foo_bar
        ...     __doc__ = foo_bar.__doc__
    """

    # Subclasses set this to the plain function being wrapped.
    _func = None

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        check_kwargs(self.func, kwargs)

    def __call__(self, x):
        """Call the wrapped function on ``x`` with the stored kwargs."""
        return self.func(x, **self.kwargs)

    @classmethod
    def add_arguments(cls, parser):
        """Add one ``--<func>-<param>`` option per defaulted parameter.

        Underscores in the function and parameter names become dashes.
        Each option uses the parameter default and ``type(default)`` as
        the argparse type. Returns ``parser``.
        """
        fname = cls._func.__name__.replace("_", "-")
        group = parser.add_argument_group(fname + " transformation setting")
        for k, v in cls.default_params().items():
            # TODO(karita): get help and choices from docstring?
            attr = k.replace("_", "-")
            group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
        return parser

    @property
    def func(self):
        # Read from the class so the function is not bound to the instance.
        return type(self)._func

    @classmethod
    def default_params(cls):
        """Return ``{name: default}`` for parameters that have a default.

        An empty dict is returned when ``inspect.signature`` raises
        ``ValueError`` for the wrapped function.
        """
        try:
            d = dict(inspect.signature(cls._func).parameters)
        except ValueError:
            d = dict()
        # ``Parameter.empty`` is a sentinel: compare by identity so that
        # defaults with element-wise ``!=`` (e.g. arrays) are handled.
        return {
            k: v.default
            for k, v in d.items() if v.default is not inspect.Parameter.empty
        }

    def __repr__(self):
        params = self.default_params()
        params.update(**self.kwargs)
        ret = self.__class__.__name__ + "("
        if len(params) == 0:
            return ret + ")"
        for k, v in params.items():
            ret += "{}={}, ".format(k, v)
        # Drop the trailing ", " before closing.
        return ret[:-2] + ")"
|
@ -0,0 +1,350 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import librosa
|
||||||
|
import numpy
|
||||||
|
import scipy
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
from deepspeech.io.reader import SoundHDF5File
|
||||||
|
|
||||||
|
|
||||||
|
class SpeedPerturbation():
    """SpeedPerturbation

    The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
    and sox-speed just to resample the input,
    i.e pitch and tempo are changed both.

    "Why use speed option instead of tempo -s in SoX for speed perturbation"
    https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8

    Warning:
        This function is very slow because of resampling.
        I recommend to apply speed-perturb outside the training using sox.

    Args:
        lower (float): lower bound of the random ratio.
        upper (float): upper bound of the random ratio.
        utt2ratio (str): optional path of an ``utt ratio`` file; when
            given, the ratio is looked up per utterance and
            ``lower`` / ``upper`` are set to ``None``.
        keep_length (bool): crop / zero-pad the result to the input length.
        res_type (str): forwarded to ``librosa.resample``.
        seed: seed for the internal ``numpy.random.RandomState``.
    """

    def __init__(
            self,
            lower=0.9,
            upper=1.1,
            utt2ratio=None,
            keep_length=True,
            res_type="kaiser_best",
            seed=None, ):
        self.res_type = res_type
        self.keep_length = keep_length
        self.state = numpy.random.RandomState(seed)

        if utt2ratio is not None:
            self.utt2ratio = {}
            # Use the scheduled ratio for each utterances
            self.utt2ratio_file = utt2ratio
            self.lower = None
            self.upper = None
            self.accept_uttid = True

            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, ratio = line.rstrip().split(None, 1)
                    ratio = float(ratio)
                    self.utt2ratio[utt] = ratio
        else:
            self.utt2ratio = None
            # The ratio is given on runtime randomly
            self.lower = lower
            self.upper = upper

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format(
                self.__class__.__name__,
                self.lower,
                self.upper,
                self.keep_length,
                self.res_type, )
        else:
            return "{}({}, res_type={})".format(
                self.__class__.__name__, self.utt2ratio_file, self.res_type)

    def __call__(self, x, uttid=None, train=True):
        """Resample ``x`` by a speed ratio; identity when ``train`` is false."""
        if not train:
            return x

        x = x.astype(numpy.float32)
        # ``accept_uttid`` only exists when a ratio table was loaded, so
        # decide on the table itself to avoid an AttributeError in the
        # default (random ratio) configuration.
        if self.utt2ratio is not None:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        # Note1: resample requires the sampling-rate of input and output,
        #        but actually only the ratio is used.
        # Keyword arguments are used because recent librosa releases make
        # the sample rates keyword-only.
        y = librosa.resample(
            x, orig_sr=ratio, target_sr=1, res_type=self.res_type)

        if self.keep_length:
            diff = abs(len(x) - len(y))
            if len(y) > len(x):
                # Truncate noise
                y = y[diff // 2:-((diff + 1) // 2)]
            elif len(y) < len(x):
                # Assume the time-axis is the first: (Time, Channel)
                pad_width = [(diff // 2, (diff + 1) // 2)] + [
                    (0, 0) for _ in range(y.ndim - 1)
                ]
                y = numpy.pad(
                    y, pad_width=pad_width, constant_values=0, mode="constant")
        return y
|
||||||
|
|
||||||
|
|
||||||
|
class BandpassPerturbation():
    """BandpassPerturbation

    Randomly dropout along the frequency axis.

    The original idea comes from the following:
        "randomly-selected frequency band was cut off under the constraint of
         leaving at least 1,000 Hz band within the range of less than 4,000Hz."
        (The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
         everyday home environments using multiple microphone arrays;
         http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)

    Args:
        lower (float): lower bound of the random threshold.
        upper (float): upper bound of the random threshold.
        seed: seed for the internal ``numpy.random.RandomState``.
        axes (tuple): axes along which the mask varies; negative values
            count from the end.
    """

    def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1, )):
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)
        # x_stft: (Time, Channel, Freq)
        self.axes = axes

    def __repr__(self):
        return "{}(lower={}, upper={})".format(self.__class__.__name__,
                                               self.lower, self.upper)

    def __call__(self, x_stft, uttid=None, train=True):
        """Mask ``x_stft`` in place and return it; identity when not training.

        Raises:
            RuntimeError: if ``x_stft`` is one-dimensional.
        """
        if not train:
            return x_stft

        if x_stft.ndim == 1:
            raise RuntimeError("Input in time-freq domain: "
                               "(Time, Channel, Freq) or (Time, Freq)")

        ratio = self.state.uniform(self.lower, self.upper)
        # Normalise negative axes: -1 must map to ndim - 1 (ndim + i).
        # The previous ``ndim - i`` mapped it past the last axis, so no
        # axis matched and the mask degenerated to a single scalar.
        axes = [i if i >= 0 else x_stft.ndim + i for i in self.axes]
        # The mask has full size on the selected axes and broadcasts
        # (size 1) over all others.
        shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]

        # NOTE(review): the mask compares standard-normal draws against a
        # uniform threshold; confirm this distribution is intended.
        mask = self.state.randn(*shape) > ratio
        x_stft *= mask
        return x_stft
|
||||||
|
|
||||||
|
|
||||||
|
class VolumePerturbation():
    """Scale the signal by a random or per-utterance gain.

    Args:
        lower (float): lower bound of the random ratio.
        upper (float): upper bound of the random ratio.
        utt2ratio (str): optional path of an ``utt ratio`` file; when
            given, the ratio is looked up per utterance and
            ``lower`` / ``upper`` are set to ``None``.
        dbunit (bool): interpret the ratio as decibels, i.e. apply
            ``10 ** (ratio / 20)``.
        seed: seed for the internal ``numpy.random.RandomState``.
    """

    def __init__(self,
                 lower=-1.6,
                 upper=1.6,
                 utt2ratio=None,
                 dbunit=True,
                 seed=None):
        self.dbunit = dbunit
        self.utt2ratio_file = utt2ratio
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)

        if utt2ratio is not None:
            # Use the scheduled ratio for each utterances
            self.utt2ratio = {}
            self.lower = None
            self.upper = None
            self.accept_uttid = True

            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, ratio = line.rstrip().split(None, 1)
                    ratio = float(ratio)
                    self.utt2ratio[utt] = ratio
        else:
            # The ratio is given on runtime randomly
            self.utt2ratio = None

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, dbunit={})".format(
                self.__class__.__name__, self.lower, self.upper, self.dbunit)
        else:
            return '{}("{}", dbunit={})'.format(
                self.__class__.__name__, self.utt2ratio_file, self.dbunit)

    def __call__(self, x, uttid=None, train=True):
        """Return ``x`` (as float32) times the gain; identity when not training."""
        if not train:
            return x

        x = x.astype(numpy.float32)

        # ``accept_uttid`` only exists when a ratio table was loaded, so
        # decide on the table itself to avoid an AttributeError in the
        # default (random ratio) configuration.
        if self.utt2ratio is not None:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)
        if self.dbunit:
            ratio = 10**(ratio / 20)
        return x * ratio
|
||||||
|
|
||||||
|
|
||||||
|
class NoiseInjection():
    """Add isotropic noise

    Args:
        utt2noise (str): optional noise source. With ``filetype="list"``
            it is a ``utt path`` file whose audio is loaded eagerly; with
            ``filetype="sound.hdf5"`` it is opened as a ``SoundHDF5File``.
            Without it, white noise is generated.
        lower (float): lower bound of the random noise-to-signal ratio.
        upper (float): upper bound of the random noise-to-signal ratio.
        utt2ratio (str): optional path of an ``utt ratio`` file.
        filetype (str): ``"list"`` or ``"sound.hdf5"``.
        dbunit (bool): interpret the ratio as decibels, i.e. apply
            ``10 ** (ratio / 20)``.
        seed: seed for the internal ``numpy.random.RandomState``.

    Raises:
        ValueError: for an unsupported ``filetype``.
        RuntimeError: if both tables are given and their keys differ.
    """

    def __init__(
            self,
            utt2noise=None,
            lower=-20,
            upper=-5,
            utt2ratio=None,
            filetype="list",
            dbunit=True,
            seed=None, ):
        self.utt2noise_file = utt2noise
        self.utt2ratio_file = utt2ratio
        self.filetype = filetype
        self.dbunit = dbunit
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)

        if utt2ratio is not None:
            # Use the scheduled ratio for each utterances
            self.utt2ratio = {}
            # Read the ratio table from ``utt2ratio``; it was previously
            # (wrongly) read from the ``utt2noise`` path.
            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, snr = line.rstrip().split(None, 1)
                    snr = float(snr)
                    self.utt2ratio[utt] = snr
        else:
            # The ratio is given on runtime randomly
            self.utt2ratio = None

        if utt2noise is not None:
            self.utt2noise = {}
            if filetype == "list":
                with open(utt2noise, "r") as f:
                    for line in f:
                        utt, filename = line.rstrip().split(None, 1)
                        signal, rate = soundfile.read(filename, dtype="int16")
                        # Load all files in memory
                        self.utt2noise[utt] = (signal, rate)

            elif filetype == "sound.hdf5":
                self.utt2noise = SoundHDF5File(utt2noise, "r")
            else:
                raise ValueError(filetype)
        else:
            self.utt2noise = None

        if utt2noise is not None and utt2ratio is not None:
            if set(self.utt2ratio) != set(self.utt2noise):
                raise RuntimeError("The uttids mismatch between {} and {}".
                                   format(utt2ratio, utt2noise))

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, dbunit={})".format(
                self.__class__.__name__, self.lower, self.upper, self.dbunit)
        else:
            return '{}("{}", dbunit={})'.format(
                self.__class__.__name__, self.utt2ratio_file, self.dbunit)

    def __call__(self, x, uttid=None, train=True):
        """Return ``x`` (as float32) plus scaled noise; identity when not training."""
        if not train:
            return x
        x = x.astype(numpy.float32)

        # 1. Get ratio of noise to signal in sound pressure level
        if uttid is not None and self.utt2ratio is not None:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        if self.dbunit:
            ratio = 10**(ratio / 20)
        scale = ratio * numpy.sqrt((x**2).mean())

        # 2. Get noise
        if self.utt2noise is not None:
            # Get noise from the external source
            if uttid is not None:
                noise, rate = self.utt2noise[uttid]
            else:
                # Randomly select the noise source. An index is drawn
                # because the entries are (signal, rate) pairs, which
                # ``RandomState.choice`` cannot sample from directly.
                candidates = list(self.utt2noise.values())
                noise, rate = candidates[self.state.randint(
                    0, len(candidates))]
            # Normalize the level on a float copy: the stored noise may be
            # int16 (which cannot be divided in place and overflows when
            # squared) and must not be modified between calls.
            noise = noise.astype(numpy.float32)
            noise /= numpy.sqrt((noise**2).mean())

            # Adjust the noise length
            diff = abs(len(x) - len(noise))
            # ``randint(0, 0)`` is invalid, so equal lengths need no change.
            if diff > 0:
                offset = self.state.randint(0, diff)
                if len(noise) > len(x):
                    # Truncate noise
                    noise = noise[offset:-(diff - offset)]
                else:
                    noise = numpy.pad(
                        noise, pad_width=[offset, diff - offset], mode="wrap")

        else:
            # Generate white noise
            noise = self.state.normal(0, 1, x.shape)

        # 3. Add noise to signal
        return x + noise * scale
|
||||||
|
|
||||||
|
|
||||||
|
class RIRConvolve():
    """Convolve a single-channel signal with a per-utterance impulse response.

    Args:
        utt2rir (str): with ``filetype="list"`` a ``utt path`` file whose
            audio is loaded eagerly; with ``filetype="sound.hdf5"`` a
            file opened as ``SoundHDF5File``.
        filetype (str): ``"list"`` or ``"sound.hdf5"``.

    Raises:
        NotImplementedError: for an unsupported ``filetype``.
    """

    def __init__(self, utt2rir, filetype="list"):
        self.utt2rir_file = utt2rir
        self.filetype = filetype

        self.utt2rir = {}
        if filetype == "list":
            with open(utt2rir, "r") as f:
                for line in f:
                    utt, filename = line.rstrip().split(None, 1)
                    signal, rate = soundfile.read(filename, dtype="int16")
                    self.utt2rir[utt] = (signal, rate)

        elif filetype == "sound.hdf5":
            self.utt2rir = SoundHDF5File(utt2rir, "r")
        else:
            raise NotImplementedError(filetype)

    def __repr__(self):
        return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)

    def __call__(self, x, uttid=None, train=True):
        """Return ``x`` convolved with the RIR of ``uttid``.

        ``x`` is returned unchanged when ``train`` is false.

        Raises:
            RuntimeError: if ``x`` is not one-dimensional.
        """
        if not train:
            return x

        x = x.astype(numpy.float32)

        if x.ndim != 1:
            # Must be single channel
            raise RuntimeError(
                "Input x must be one dimensional array, but got {}".format(
                    x.shape))

        rir, rate = self.utt2rir[uttid]
        # ``numpy.convolve`` replaces the ``scipy.convolve`` alias, which
        # newer SciPy releases no longer provide.
        if rir.ndim == 2:
            # FIXME(kamo): Use chainer.convolution_1d?
            # return [Time, Channel]
            # NOTE(review): the RIR is taken to be (Time, Channel), as
            # returned by soundfile, so the transpose iterates channels --
            # confirm for SoundHDF5File sources.
            return numpy.stack(
                [numpy.convolve(x, r, mode="same") for r in rir.T], axis=-1)
        else:
            return numpy.convolve(x, rir, mode="same")
|
@ -0,0 +1,210 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Spec Augment module for preprocessing i.e., data augmentation"""
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
from PIL import Image
|
||||||
|
from PIL.Image import BICUBIC
|
||||||
|
|
||||||
|
from deepspeech.transform.functional import FuncTrans
|
||||||
|
|
||||||
|
|
||||||
|
def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
    """Warp a spectrogram along the time axis (SpecAugment time warping).

    A random center frame is picked at least ``max_time_warp`` frames away
    from both ends and moved to a random position within ``max_time_warp``
    frames of it; the two halves are resized to fit. Inputs that are too
    short for such a center (``time <= 2 * max_time_warp``) are returned
    unchanged.

    :param numpy.ndarray x: spectrogram (time, freq)
    :param int max_time_warp: maximum time frames to warp
    :param bool inplace: overwrite x with the result
    :param str mode: "PIL" (default, fast, not differentiable) or
        "sparse_image_warp" (slow, differentiable)
    :returns numpy.ndarray: time warped spectrogram (time, freq)
    :raises NotImplementedError: for any other ``mode``
    """
    if mode == "sparse_image_warp":
        import paddle

        from espnet.utils import spec_augment

        # TODO(karita): make this differentiable again
        return spec_augment.time_warp(paddle.to_tensor(x),
                                      max_time_warp).numpy()

    if mode != "PIL":
        raise NotImplementedError("unknown resize mode: " + mode +
                                  ", choose one from (PIL, sparse_image_warp).")

    n_frames = x.shape[0]
    if n_frames - max_time_warp <= max_time_warp:
        # No valid center frame exists for such a short input.
        return x

    # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
    pivot = random.randrange(max_time_warp, n_frames - max_time_warp)
    # Destination of the pivot frame, always inside 1 ... n_frames - 1.
    target = random.randrange(pivot - max_time_warp,
                              pivot + max_time_warp) + 1

    n_bins = x.shape[1]
    head = Image.fromarray(x[:pivot]).resize((n_bins, target), BICUBIC)
    tail = Image.fromarray(x[pivot:]).resize((n_bins, n_frames - target),
                                             BICUBIC)
    if not inplace:
        return numpy.concatenate((head, tail), 0)

    x[:target] = head
    x[target:] = tail
    return x
|
||||||
|
|
||||||
|
|
||||||
|
class TimeWarp(FuncTrans):
    # Training-only wrapper around time_warp. The base __call__ comes from
    # FuncTrans (deepspeech.transform.functional); presumably it forwards the
    # configured keyword arguments to ``_func`` - confirm in that module.
    _func = time_warp
    __doc__ = time_warp.__doc__

    def __call__(self, x, train):
        """Transform ``x`` when ``train`` is truthy; otherwise return it as is."""
        if train:
            return super().__call__(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
    """Mask random frequency bands of a spectrogram (SpecAugment).

    For every mask two integers are drawn from ``[0, F)``: the first bounds
    the start position (and a zero draw skips the mask), the second is the
    width of the masked band.

    :param numpy.ndarray x: (time, freq)
    :param int F: exclusive upper bound of the per-mask random draws
    :param int n_mask: the number of masks
    :param bool inplace: overwrite
    :param bool replace_with_zero: pad zero on mask if true else use mean
    :returns numpy.ndarray: masked spectrogram (``x`` itself when ``inplace``)
    """
    cloned = x if inplace else x.copy()

    num_mel_channels = cloned.shape[1]
    fs = numpy.random.randint(0, F, size=(n_mask, 2))

    for f, mask_end in fs:
        # avoid randrange error: a draw that is at least as large as the
        # number of frequency bins leaves no valid start position
        # (same guard as in time_mask)
        if num_mel_channels - f <= 0:
            continue
        f_zero = random.randrange(0, num_mel_channels - f)
        mask_end += f_zero

        # a zero draw means "no mask" for this round
        if f_zero == f_zero + f:
            continue

        if replace_with_zero:
            cloned[:, f_zero:mask_end] = 0
        else:
            cloned[:, f_zero:mask_end] = cloned.mean()
    return cloned
|
||||||
|
|
||||||
|
|
||||||
|
class FreqMask(FuncTrans):
    # Training-only wrapper around freq_mask. The base __call__ comes from
    # FuncTrans (deepspeech.transform.functional); presumably it forwards the
    # configured keyword arguments to ``_func`` - confirm in that module.
    _func = freq_mask
    __doc__ = freq_mask.__doc__

    def __call__(self, x, train):
        """Transform ``x`` when ``train`` is truthy; otherwise return it as is."""
        if train:
            return super().__call__(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
    """time mask for spec augment

    For every mask two integers are drawn from ``[0, T)``: the first bounds
    the start frame (and a zero draw skips the mask), the second is the
    number of masked frames.

    :param numpy.ndarray spec: (time, freq)
    :param int T: exclusive upper bound of the per-mask random draws
    :param int n_mask: the number of masks
    :param bool inplace: overwrite
    :param bool replace_with_zero: pad zero on mask if true else use mean
    :returns numpy.ndarray: masked spectrogram (``spec`` itself when ``inplace``)
    """
    cloned = spec if inplace else spec.copy()
    n_frames = cloned.shape[0]
    draws = numpy.random.randint(0, T, size=(n_mask, 2))

    for t, mask_end in draws:
        # avoid randint range error
        if n_frames - t <= 0:
            continue
        start = random.randrange(0, n_frames - t)

        # a zero draw means "no mask" for this round
        if start == start + t:
            continue

        mask_end += start
        fill = 0 if replace_with_zero else cloned.mean()
        cloned[start:mask_end] = fill
    return cloned
|
||||||
|
|
||||||
|
|
||||||
|
class TimeMask(FuncTrans):
    # Training-only wrapper around time_mask. The base __call__ comes from
    # FuncTrans (deepspeech.transform.functional); presumably it forwards the
    # configured keyword arguments to ``_func`` - confirm in that module.
    _func = time_mask
    __doc__ = time_mask.__doc__

    def __call__(self, x, train):
        """Transform ``x`` when ``train`` is truthy; otherwise return it as is."""
        if train:
            return super().__call__(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
def spec_augment(
        x,
        resize_mode="PIL",
        max_time_warp=80,
        max_freq_width=27,
        n_freq_mask=2,
        max_time_width=100,
        n_time_mask=2,
        inplace=True,
        replace_with_zero=True, ):
    """spec augment

    apply random time warping and time/freq masking
    default setting is based on LD (Librispeech double) in Table 2
        https://arxiv.org/pdf/1904.08779.pdf

    :param numpy.ndarray x: (time, freq)
    :param str resize_mode: "PIL" (fast, nondifferentiable) or
        "sparse_image_warp" (slow, differentiable)
    :param int max_time_warp: maximum frames to warp the center frame in
        spectrogram (W)
    :param int max_freq_width: maximum width of the random freq mask (F)
    :param int n_freq_mask: the number of the random freq mask (m_F)
    :param int max_time_width: maximum width of the random time mask (T)
    :param int n_time_mask: the number of the random time mask (m_T)
    :param bool inplace: overwrite intermediate array
    :param bool replace_with_zero: pad zero on mask if true else use mean
    :returns numpy.ndarray: augmented spectrogram (time, freq)
    """
    assert isinstance(x, numpy.ndarray)
    assert x.ndim == 2

    # The three stages run in a fixed order: warp, then frequency masks,
    # then time masks.
    warped = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
    freq_masked = freq_mask(
        warped,
        max_freq_width,
        n_freq_mask,
        inplace=inplace,
        replace_with_zero=replace_with_zero, )
    return time_mask(
        freq_masked,
        max_time_width,
        n_time_mask,
        inplace=inplace,
        replace_with_zero=replace_with_zero, )
|
||||||
|
|
||||||
|
|
||||||
|
class SpecAugment(FuncTrans):
    # Training-only wrapper around spec_augment. The base __call__ comes from
    # FuncTrans (deepspeech.transform.functional); presumably it forwards the
    # configured keyword arguments to ``_func`` - confirm in that module.
    _func = spec_augment
    __doc__ = spec_augment.__doc__

    def __call__(self, x, train):
        """Transform ``x`` when ``train`` is truthy; otherwise return it as is."""
        if train:
            return super().__call__(x)
        return x
|
@ -0,0 +1,305 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def stft(x,
         n_fft,
         n_shift,
         win_length=None,
         window="hann",
         center=True,
         pad_mode="reflect"):
    """Short-time Fourier transform, computed channel by channel with librosa.

    :param numpy.ndarray x: signal, [Time] or [Time, Channel]
    :param int n_fft: FFT size passed to ``librosa.stft``
    :param int n_shift: hop length passed to ``librosa.stft``
    :param win_length: window length passed to ``librosa.stft``
    :param window: window specification passed to ``librosa.stft``
    :param bool center: passed to ``librosa.stft``
    :param str pad_mode: passed to ``librosa.stft``
    :returns numpy.ndarray: [Time, Freq] for 1-D input, otherwise
        [Time, Channel, Freq]
    """
    single_channel = x.ndim == 1
    if single_channel:
        # x: [Time] -> [Time, Channel]
        x = x[:, None]
    x = x.astype(np.float32)

    # FIXME(kamo): librosa.stft can't use multi-channel?
    # librosa returns [Freq, Time] per channel, hence the transpose.
    per_channel = [
        librosa.stft(
            x[:, ch],
            n_fft=n_fft,
            hop_length=n_shift,
            win_length=win_length,
            window=window,
            center=center,
            pad_mode=pad_mode, ).T for ch in range(x.shape[1])
    ]
    # x: [Time, Channel, Freq]
    x = np.stack(per_channel, axis=1)

    # x: [Time, Channel, Freq] -> [Time, Freq] for single-channel input
    return x[:, 0] if single_channel else x
|
||||||
|
|
||||||
|
|
||||||
|
def istft(x, n_shift, win_length=None, window="hann", center=True):
    """Inverse short-time Fourier transform, channel by channel with librosa.

    :param numpy.ndarray x: STFT, [Time, Freq] or [Time, Channel, Freq]
    :param int n_shift: hop length passed to ``librosa.istft``
    :param win_length: window length passed to ``librosa.istft``
    :param window: window specification passed to ``librosa.istft``
    :param bool center: passed to ``librosa.istft``
    :returns numpy.ndarray: [Time] for 2-D input, otherwise [Time, Channel]
    """
    single_channel = x.ndim == 2
    if single_channel:
        # x: [Time, Freq] -> [Time, Channel, Freq]
        x = x[:, None, :]

    per_channel = [
        librosa.istft(
            x[:, ch].T,  # [Time, Freq] -> [Freq, Time]
            hop_length=n_shift,
            win_length=win_length,
            window=window,
            center=center, ) for ch in range(x.shape[1])
    ]
    # x: [Time, Channel]
    x = np.stack(per_channel, axis=1)

    # x: [Time, Channel] -> [Time] for single-channel input
    return x[:, 0] if single_channel else x
|
||||||
|
|
||||||
|
|
||||||
|
def stft2logmelspectrogram(x_stft,
                           fs,
                           n_mels,
                           n_fft,
                           fmin=None,
                           fmax=None,
                           eps=1e-10):
    """Convert an STFT into a log10 mel spectrogram.

    :param numpy.ndarray x_stft: (Time, Channel, Freq) or (Time, Freq)
    :param fs: sampling rate used to build the mel filter bank
    :param int n_mels: number of mel bands
    :param int n_fft: FFT size the STFT was computed with
    :param fmin: lowest filter frequency; ``None`` means 0
    :param fmax: highest filter frequency; ``None`` means ``fs / 2``
    :param float eps: floor applied before the logarithm
    :returns numpy.ndarray: (Time, Channel, Mel_freq) or (Time, Mel_freq)
    """
    fmin = 0 if fmin is None else fmin
    fmax = fs / 2 if fmax is None else fmax

    # spc: (Time, Channel, Freq) or (Time, Freq)
    spc = np.abs(x_stft)
    # mel_basis: (Mel_freq, Freq)
    # Keyword arguments are required here: the parameters of
    # librosa.filters.mel are keyword-only from librosa 0.10 on, and the
    # keyword form is accepted by older releases as well.
    mel_basis = librosa.filters.mel(
        sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
    # lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq)
    lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))

    return lmspc
|
||||||
|
|
||||||
|
|
||||||
|
def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"):
    """Return the magnitude of the STFT of ``x``.

    x: (Time, Channel) -> result: (Time, Channel, Freq)
    """
    return np.abs(stft(x, n_fft, n_shift, win_length, window=window))
|
||||||
|
|
||||||
|
|
||||||
|
def logmelspectrogram(
        x,
        fs,
        n_mels,
        n_fft,
        n_shift,
        win_length=None,
        window="hann",
        fmin=None,
        fmax=None,
        eps=1e-10,
        pad_mode="reflect", ):
    """Compute a log mel spectrogram from a waveform.

    Runs ``stft`` on ``x`` and feeds the result to
    ``stft2logmelspectrogram``; see those functions for the meaning of the
    individual parameters.
    """
    # The intermediate STFT is (Time, Channel, Freq) or (Time, Freq).
    return stft2logmelspectrogram(
        stft(
            x,
            n_fft=n_fft,
            n_shift=n_shift,
            win_length=win_length,
            window=window,
            pad_mode=pad_mode, ),
        fs=fs,
        n_mels=n_mels,
        n_fft=n_fft,
        fmin=fmin,
        fmax=fmax,
        eps=eps)
|
||||||
|
|
||||||
|
|
||||||
|
class Spectrogram():
    """Callable that stores STFT settings and applies ``spectrogram``."""

    def __init__(self, n_fft, n_shift, win_length=None, window="hann"):
        self.n_fft, self.n_shift = n_fft, n_shift
        self.win_length, self.window = win_length, window

    def __repr__(self):
        template = ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
                    "win_length={win_length}, window={window})")
        return template.format(
            name=type(self).__name__,
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window, )

    def __call__(self, x):
        """Return ``spectrogram(x, ...)`` with the stored settings."""
        settings = dict(
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window, )
        return spectrogram(x, **settings)
|
||||||
|
|
||||||
|
|
||||||
|
class LogMelSpectrogram():
    """Callable that stores feature settings and applies ``logmelspectrogram``.

    All constructor arguments are kept as attributes of the same name and
    forwarded on every call.
    """

    def __init__(
            self,
            fs,
            n_mels,
            n_fft,
            n_shift,
            win_length=None,
            window="hann",
            fmin=None,
            fmax=None,
            eps=1e-10, ):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window
        self.fmin = fmin
        self.fmax = fmax
        self.eps = eps

    def __repr__(self):
        # The closing parenthesis is balanced (the format string used to end
        # in "))").
        return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
                "n_shift={n_shift}, win_length={win_length}, window={window}, "
                "fmin={fmin}, fmax={fmax}, eps={eps})".format(
                    name=self.__class__.__name__,
                    fs=self.fs,
                    n_mels=self.n_mels,
                    n_fft=self.n_fft,
                    n_shift=self.n_shift,
                    win_length=self.win_length,
                    window=self.window,
                    fmin=self.fmin,
                    fmax=self.fmax,
                    eps=self.eps, ))

    def __call__(self, x):
        """Return ``logmelspectrogram(x, ...)`` with the stored settings."""
        # fmin, fmax and eps must be forwarded too; otherwise the values
        # given to the constructor are silently replaced by the defaults.
        return logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window,
            fmin=self.fmin,
            fmax=self.fmax,
            eps=self.eps, )
|
||||||
|
|
||||||
|
|
||||||
|
class Stft2LogMelSpectrogram():
    """Callable that stores mel settings and applies ``stft2logmelspectrogram``.

    All constructor arguments are kept as attributes of the same name and
    forwarded on every call.
    """

    def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.fmin = fmin
        self.fmax = fmax
        self.eps = eps

    def __repr__(self):
        # The closing parenthesis is balanced (the format string used to end
        # in "))").
        return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
                "fmin={fmin}, fmax={fmax}, eps={eps})".format(
                    name=self.__class__.__name__,
                    fs=self.fs,
                    n_mels=self.n_mels,
                    n_fft=self.n_fft,
                    fmin=self.fmin,
                    fmax=self.fmax,
                    eps=self.eps, ))

    def __call__(self, x):
        """Return ``stft2logmelspectrogram(x, ...)`` with the stored settings."""
        # eps must be forwarded too; otherwise the value given to the
        # constructor is silently replaced by the default.
        return stft2logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            fmin=self.fmin,
            fmax=self.fmax,
            eps=self.eps, )
|
||||||
|
|
||||||
|
|
||||||
|
class Stft():
    """Callable that stores STFT settings and applies ``stft``."""

    def __init__(
            self,
            n_fft,
            n_shift,
            win_length=None,
            window="hann",
            center=True,
            pad_mode="reflect", ):
        self.n_fft, self.n_shift = n_fft, n_shift
        self.win_length, self.window = win_length, window
        self.center, self.pad_mode = center, pad_mode

    def __repr__(self):
        template = ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
                    "win_length={win_length}, window={window},"
                    "center={center}, pad_mode={pad_mode})")
        return template.format(
            name=type(self).__name__,
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode=self.pad_mode, )

    def __call__(self, x):
        """Return ``stft(x, ...)`` with the stored settings."""
        options = dict(
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode=self.pad_mode, )
        return stft(x, self.n_fft, self.n_shift, **options)
|
||||||
|
|
||||||
|
|
||||||
|
class IStft():
    """Callable that stores inverse-STFT settings and applies ``istft``."""

    def __init__(self, n_shift, win_length=None, window="hann", center=True):
        self.n_shift, self.win_length = n_shift, win_length
        self.window, self.center = window, center

    def __repr__(self):
        template = ("{name}(n_shift={n_shift}, "
                    "win_length={win_length}, window={window},"
                    "center={center})")
        return template.format(
            name=type(self).__name__,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window,
            center=self.center, )

    def __call__(self, x):
        """Return ``istft(x, ...)`` with the stored settings."""
        options = dict(
            win_length=self.win_length,
            window=self.window,
            center=self.center, )
        return istft(x, self.n_shift, **options)
|
@ -0,0 +1,33 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# TODO(karita): add this to all the transform impl.
|
||||||
|
class TransformInterface:
    """Transform Interface"""

    def __call__(self, x):
        """Apply the transform to ``x``; subclasses must override this."""
        raise NotImplementedError("__call__ method is not implemented")

    @classmethod
    def add_arguments(cls, parser):
        """Hook for registering command-line options; returns ``parser`` as given."""
        return parser

    def __repr__(self):
        return "{}()".format(type(self).__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Identity(TransformInterface):
    """Identity Function"""

    def __call__(self, x):
        """Return ``x`` itself, unmodified."""
        return x
|
@ -0,0 +1,156 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Transformation module."""
|
||||||
|
import copy
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from inspect import signature
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from deepspeech.utils.dynamic_import import dynamic_import
|
||||||
|
|
||||||
|
# TODO(karita): inherit TransformInterface
|
||||||
|
# TODO(karita): register cmd arguments in asr_train.py
|
||||||
|
# Short process names accepted in a transform config, mapped to the
# "module:ClassName" path that dynamic_import resolves.
import_alias = {
    "identity": "deepspeech.transform.transform_interface:Identity",
    "time_warp": "deepspeech.transform.spec_augment:TimeWarp",
    "time_mask": "deepspeech.transform.spec_augment:TimeMask",
    "freq_mask": "deepspeech.transform.spec_augment:FreqMask",
    "spec_augment": "deepspeech.transform.spec_augment:SpecAugment",
    "speed_perturbation": "deepspeech.transform.perturb:SpeedPerturbation",
    "volume_perturbation": "deepspeech.transform.perturb:VolumePerturbation",
    "noise_injection": "deepspeech.transform.perturb:NoiseInjection",
    "bandpass_perturbation":
    "deepspeech.transform.perturb:BandpassPerturbation",
    "rir_convolve": "deepspeech.transform.perturb:RIRConvolve",
    "delta": "deepspeech.transform.add_deltas:AddDeltas",
    "cmvn": "deepspeech.transform.cmvn:CMVN",
    "utterance_cmvn": "deepspeech.transform.cmvn:UtteranceCMVN",
    "fbank": "deepspeech.transform.spectrogram:LogMelSpectrogram",
    "spectrogram": "deepspeech.transform.spectrogram:Spectrogram",
    "stft": "deepspeech.transform.spectrogram:Stft",
    "istft": "deepspeech.transform.spectrogram:IStft",
    "stft2fbank": "deepspeech.transform.spectrogram:Stft2LogMelSpectrogram",
    "wpe": "deepspeech.transform.wpe:WPE",
    "channel_selector":
    "deepspeech.transform.channel_selector:ChannelSelector",
}
|
||||||
|
|
||||||
|
|
||||||
|
class Transformation():
    """Apply some functions to the mini-batch

    The configuration is a dict (or the path of a YAML file holding one)
    with a "process" list; each entry names a transform through its "type"
    key and passes the remaining keys to that transform's constructor.
    Only ``mode="sequential"`` is supported.

    Examples:
        >>> kwargs = {"process": [{"type": "fbank",
        ...                        "n_mels": 80,
        ...                        "fs": 16000},
        ...                       {"type": "cmvn",
        ...                        "stats": "data/train/cmvn.ark",
        ...                        "norm_vars": True},
        ...                       {"type": "delta", "window": 2, "order": 2}]}
        >>> transform = Transformation(kwargs)
        >>> bs = 10
        >>> xs = [np.random.randn(100, 80).astype(np.float32)
        ...       for _ in range(bs)]
        >>> xs = transform(xs)
    """

    def __init__(self, conffile=None):
        if conffile is None:
            self.conf = {"mode": "sequential", "process": []}
        elif isinstance(conffile, dict):
            # Copy so that later changes by the caller cannot leak in.
            self.conf = copy.deepcopy(conffile)
        else:
            with io.open(conffile, encoding="utf-8") as f:
                self.conf = yaml.safe_load(f)
            assert isinstance(self.conf, dict), type(self.conf)

        self.functions = OrderedDict()
        if self.conf.get("mode", "sequential") != "sequential":
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"]))

        for idx, process in enumerate(self.conf["process"]):
            assert isinstance(process, dict), type(process)
            # Work on a copy: "type" is popped, the rest are constructor
            # arguments.
            ctor_args = dict(process)
            process_type = ctor_args.pop("type")
            class_obj = dynamic_import(process_type, import_alias)
            # TODO(karita): assert issubclass(class_obj, TransformInterface)
            try:
                self.functions[idx] = class_obj(**ctor_args)
            except TypeError:
                # Log the expected signature, when it can be inspected, to
                # help diagnose a bad config; the error itself propagates.
                try:
                    signa = signature(class_obj)
                except ValueError:
                    # Some function, e.g. built-in function, are failed
                    pass
                else:
                    logging.error("Expected signature: {}({})".format(
                        class_obj.__name__, signa))
                raise

    def __repr__(self):
        lines = (" {}: {}".format(k, v) for k, v in self.functions.items())
        rep = "\n" + "\n".join(lines)
        return "{}({})".format(self.__class__.__name__, rep)

    def __call__(self, xs, uttid_list=None, **kwargs):
        """Return new mini-batch

        Extra keyword arguments are passed only to the transforms whose
        call signature declares them.

        :param Union[Sequence[np.ndarray], np.ndarray] xs:
        :param Union[Sequence[str], str] uttid_list:
        :return: batch:
        :rtype: List[np.ndarray]
        """
        is_batch = isinstance(xs, Sequence)
        if not is_batch:
            # Wrap a single item so that the loop below is uniform.
            xs = [xs]

        if isinstance(uttid_list, str):
            uttid_list = [uttid_list] * len(xs)

        if self.conf.get("mode", "sequential") != "sequential":
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"]))

        for idx in range(len(self.conf["process"])):
            func = self.functions[idx]
            # TODO(karita): use TrainingTrans and UttTrans to check __call__ args
            # Derive only the args which the func has
            try:
                param = signature(func).parameters
            except ValueError:
                # Some function, e.g. built-in function, are failed
                param = {}
            _kwargs = {k: v for k, v in kwargs.items() if k in param}
            try:
                if uttid_list is not None and "uttid" in param:
                    xs = [
                        func(x, u, **_kwargs) for x, u in zip(xs, uttid_list)
                    ]
                else:
                    xs = [func(x, **_kwargs) for x in xs]
            except Exception:
                logging.fatal("Catch a exception from {}th func: {}".format(
                    idx, func))
                raise

        return xs if is_batch else xs[0]
|
@ -0,0 +1,57 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from nara_wpe.wpe import wpe
|
||||||
|
|
||||||
|
|
||||||
|
class WPE(object):
    """Callable that stores the settings passed to ``nara_wpe.wpe.wpe``.

    The constructor arguments are kept as attributes of the same name and
    forwarded unchanged on every call.
    """

    def __init__(self,
                 taps=10,
                 delay=3,
                 iterations=3,
                 psd_context=0,
                 statistics_mode="full"):
        self.taps = taps
        self.delay = delay
        self.iterations = iterations
        self.psd_context = psd_context
        self.statistics_mode = statistics_mode

    def __repr__(self):
        # A ", " separator follows delay; without it the output read
        # "delay=3iterations=3".
        return ("{name}(taps={taps}, delay={delay}, "
                "iterations={iterations}, psd_context={psd_context}, "
                "statistics_mode={statistics_mode})".format(
                    name=self.__class__.__name__,
                    taps=self.taps,
                    delay=self.delay,
                    iterations=self.iterations,
                    psd_context=self.psd_context,
                    statistics_mode=self.statistics_mode, ))

    def __call__(self, xs):
        """Return enhanced

        :param np.ndarray xs: (Time, Channel, Frequency)
        :return: enhanced_xs
        :rtype: np.ndarray

        """
        # nara_wpe.wpe: (F, C, T)
        xs = wpe(
            xs.transpose((2, 1, 0)),
            taps=self.taps,
            delay=self.delay,
            iterations=self.iterations,
            psd_context=self.psd_context,
            statistics_mode=self.statistics_mode, )
        # Back to the caller's axis order.
        return xs.transpose(2, 1, 0)
|
@ -0,0 +1,52 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import json
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
__all__ = ["label_smoothing_dist"]
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(takaaki-hori): add different smoothing methods
|
||||||
|
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
    """Obtain label distribution for loss smoothing.

    :param int odim: number of output labels; index ``odim - 1`` is counted
        as <eos>
    :param str lsm_type: smoothing type; only "unigram" is supported
    :param str transcript: path of a JSON file with a top-level "utts"
        mapping whose entries carry ``output[0]["tokenid"]`` as a
        space-separated id string; required for "unigram"
    :param int blank: label index whose count is forced to zero
    :return: label distribution of length ``odim`` (float32, sums to 1)
    :rtype: numpy.ndarray
    :raises SystemExit: if ``lsm_type`` is not supported
    """
    # Imported here because this module does not import them at file level;
    # the error branch below used to fail with NameError.
    import logging
    import sys

    if transcript is not None:
        with open(transcript, "rb") as f:
            trans_json = json.load(f)["utts"]

    if lsm_type == "unigram":
        assert transcript is not None, (
            "transcript is required for %s label smoothing" % lsm_type)
        labelcount = np.zeros(odim)
        for v in trans_json.values():
            ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
            # to avoid an error when there is no text in an utterance
            if len(ids) > 0:
                # NOTE(review): a fancy-index increment adds 1 per distinct
                # id, so a label repeated inside one utterance is counted
                # once - confirm that this is the intended statistic.
                labelcount[ids] += 1
        # One <eos> per utterance. `transcript` is the file path, so its
        # length must not be used here.
        labelcount[odim - 1] = len(trans_json)  # count <eos>
        labelcount[labelcount == 0] = 1  # flooring
        labelcount[blank] = 0  # remove counts for blank
        labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
    else:
        logging.error("Error: unexpected label smoothing type: %s" % lsm_type)
        # Exit with a failure status; a bare sys.exit() reports success.
        sys.exit(1)

    return labeldist
|
@ -0,0 +1,34 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
|
def check_kwargs(func, kwargs, name=None):
    """check kwargs are valid for func

    If kwargs are invalid, raise TypeError as same as python default.
    Nothing is checked when the signature of ``func`` cannot be inspected,
    or when ``func`` accepts arbitrary keywords through ``**kwargs``.

    :param function func: function to be validated
    :param dict kwargs: keyword arguments for func
    :param str name: name used in TypeError (default is func name)
    :raises TypeError: if a key of ``kwargs`` is not a parameter of ``func``
    """
    try:
        params = inspect.signature(func).parameters
    except ValueError:
        # Some callables, e.g. built-ins, expose no signature.
        return

    # A **kwargs parameter accepts every keyword, exactly as a real call
    # would, so there is nothing to reject.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return

    if name is None:
        # Not every callable has __name__ (e.g. functools.partial objects).
        name = getattr(func, "__name__", repr(func))

    for k in kwargs:
        if k not in params:
            raise TypeError(
                f"{name}() got an unexpected keyword argument '{k}'")
|
@ -0,0 +1,241 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import kaldiio
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
from deepspeech.io.reader import SoundHDF5File
|
||||||
|
|
||||||
|
|
||||||
|
def file_reader_helper(
        rspecifier: str,
        filetype: str="mat",
        return_shape: bool=False,
        segments: str=None, ):
    """Build a reader yielding (uttid, array) pairs in kaldi style.

    The "ark" prefix is also used for HDF5 input, purely to imitate the
    "kaldi-rspecifier" notation, which can look confusing at first.

    Args:
        rspecifier: Give as "ark:feats.ark" or "scp:feats.scp"
        filetype: "mat" is kaldi-matrix, "hdf5": HDF5;
            "sound.hdf5" and "sound" select the sound readers.
        return_shape: Return the shape of the matrix,
            instead of the matrix. This can reduce IO cost for HDF5.
        segments (str): The file format is
            "<segment-id> <recording-id> <start-time> <end-time>\n"
            "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5\n"
            Only forwarded to the "mat" reader.
    Returns:
        Generator[Tuple[str, np.ndarray], None, None]:

    Raises:
        NotImplementedError: if filetype is not one of the supported names.

    Examples:
        Read from kaldi-matrix ark file:

        >>> for u, array in file_reader_helper('ark:feats.ark', 'mat'):
        ...     array

        Read from HDF5 file:

        >>> for u, array in file_reader_helper('ark:feats.h5', 'hdf5'):
        ...     array

    """
    if filetype == "mat":
        return KaldiReader(
            rspecifier, return_shape=return_shape, segments=segments)
    if filetype == "hdf5":
        return HDF5Reader(rspecifier, return_shape=return_shape)
    if filetype == "sound.hdf5":
        return SoundHDF5Reader(rspecifier, return_shape=return_shape)
    if filetype == "sound":
        return SoundReader(rspecifier, return_shape=return_shape)
    raise NotImplementedError(f"filetype={filetype}")
|
||||||
|
|
||||||
|
|
||||||
|
class KaldiReader:
    """Iterable over the (key, array) pairs produced by kaldiio.ReadHelper.

    When ``return_shape`` is true, ``array.shape`` is yielded in place of
    the array itself.
    """

    def __init__(self, rspecifier, return_shape=False, segments=None):
        self.segments = segments
        self.return_shape = return_shape
        self.rspecifier = rspecifier

    def __iter__(self):
        helper = kaldiio.ReadHelper(self.rspecifier, segments=self.segments)
        with helper as reader:
            for key, array in reader:
                yield key, (array.shape if self.return_shape else array)
|
||||||
|
|
||||||
|
|
||||||
|
class HDF5Reader:
    """Iterable over (key, array) pairs stored in HDF5, rspecifier style.

    ``rspecifier`` is either ``"ark:<file.h5>"``, which walks every key of
    one HDF5 file (``"ark:-"`` reads the file content from stdin), or
    ``"scp:<file.scp>"``, where each line of the scp file looks like
    ``"uttid filepath.h5:key"``.

    Args:
        rspecifier: "ark:..." or "scp:..." specifier as described above.
        return_shape: yield the ``.shape`` of each entry instead of its data.

    Raises:
        ValueError: if rspecifier has no ":" or its prefix is neither
            "ark" nor "scp".
    """

    def __init__(self, rspecifier, return_shape=False):
        if ":" not in rspecifier:
            # Format the argument itself: self.rspecifier is not set yet, so
            # reading it here would raise AttributeError instead.
            raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'.
                             format(rspecifier))
        self.rspecifier = rspecifier
        self.ark_or_scp, self.filepath = self.rspecifier.split(":", 1)
        if self.ark_or_scp not in ["ark", "scp"]:
            raise ValueError(f"Must be scp or ark: {self.ark_or_scp}")

        self.return_shape = return_shape

    def __iter__(self):
        if self.ark_or_scp == "scp":
            # Cache of already opened HDF5 files, keyed by path.
            hdf5_dict = {}
            try:
                with open(self.filepath, "r", encoding="utf-8") as f:
                    for line in f:
                        key, value = line.rstrip().split(None, 1)

                        if ":" not in value:
                            raise RuntimeError(
                                "scp file for hdf5 should be like: "
                                '"uttid filepath.h5:key": {}({})'.format(
                                    line, self.filepath))
                        path, h5_key = value.split(":", 1)

                        hdf5_file = hdf5_dict.get(path)
                        if hdf5_file is None:
                            try:
                                hdf5_file = h5py.File(path, "r")
                            except Exception:
                                logging.error(
                                    "Error when loading {}".format(path))
                                raise
                            hdf5_dict[path] = hdf5_file

                        try:
                            data = hdf5_file[h5_key]
                        except Exception:
                            logging.error("Error when loading {} with key={}".
                                          format(path, h5_key))
                            raise

                        if self.return_shape:
                            yield key, data.shape
                        else:
                            yield key, data[()]
            finally:
                # Closing all files. The finally clause also covers errors
                # and a consumer that stops iterating early; closing stays
                # best-effort.
                for hdf5_file in hdf5_dict.values():
                    try:
                        hdf5_file.close()
                    except Exception:
                        pass

        else:
            if self.filepath == "-":
                # Required h5py>=2.9
                filepath = io.BytesIO(sys.stdin.buffer.read())
            else:
                filepath = self.filepath
            with h5py.File(filepath, "r") as f:
                for key in f:
                    if self.return_shape:
                        yield key, f[key].shape
                    else:
                        yield key, f[key][()]
|
||||||
|
|
||||||
|
|
||||||
|
class SoundHDF5Reader:
    """Iterable over (key, (rate, array)) pairs read through SoundHDF5File.

    ``rspecifier`` is either ``"ark:<file.h5>"`` (``"ark:-"`` reads the file
    content from stdin) or ``"scp:<file.scp>"``, where each line of the scp
    file looks like ``"uttid filepath.h5:key"``.

    Args:
        rspecifier: "ark:..." or "scp:..." specifier as described above.
        return_shape: yield ``array.shape`` instead of the array.

    Raises:
        ValueError: if rspecifier has no ":" or its prefix is neither
            "ark" nor "scp".
    """

    def __init__(self, rspecifier, return_shape=False):
        if ":" not in rspecifier:
            raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'.
                             format(rspecifier))
        self.ark_or_scp, self.filepath = rspecifier.split(":", 1)
        if self.ark_or_scp not in ["ark", "scp"]:
            raise ValueError(f"Must be scp or ark: {self.ark_or_scp}")
        self.return_shape = return_shape

    def __iter__(self):
        if self.ark_or_scp == "scp":
            # Cache of already opened sound HDF5 files, keyed by path.
            hdf5_dict = {}
            try:
                with open(self.filepath, "r", encoding="utf-8") as f:
                    for line in f:
                        key, value = line.rstrip().split(None, 1)

                        if ":" not in value:
                            raise RuntimeError(
                                "scp file for hdf5 should be like: "
                                '"uttid filepath.h5:key": {}({})'.format(
                                    line, self.filepath))
                        path, h5_key = value.split(":", 1)

                        hdf5_file = hdf5_dict.get(path)
                        if hdf5_file is None:
                            try:
                                hdf5_file = SoundHDF5File(path, "r")
                            except Exception:
                                logging.error(
                                    "Error when loading {}".format(path))
                                raise
                            hdf5_dict[path] = hdf5_file

                        try:
                            data = hdf5_file[h5_key]
                        except Exception:
                            logging.error("Error when loading {} with key={}".
                                          format(path, h5_key))
                            raise

                        # Change Tuple[ndarray, int] -> Tuple[int, ndarray]
                        # (soundfile style -> scipy style)
                        array, rate = data
                        if self.return_shape:
                            array = array.shape
                        yield key, (rate, array)
            finally:
                # Closing all files. The finally clause also covers errors
                # and a consumer that stops iterating early; closing stays
                # best-effort.
                for hdf5_file in hdf5_dict.values():
                    try:
                        hdf5_file.close()
                    except Exception:
                        pass

        else:
            if self.filepath == "-":
                # Required h5py>=2.9
                filepath = io.BytesIO(sys.stdin.buffer.read())
            else:
                filepath = self.filepath
            sound_file = SoundHDF5File(filepath, "r")
            try:
                for key, (a, r) in sound_file.items():
                    if self.return_shape:
                        a = a.shape
                    yield key, (r, a)
            finally:
                # The file used to be left open here; release it the same
                # best-effort way the scp branch does.
                try:
                    sound_file.close()
                except Exception:
                    pass
|
||||||
|
|
||||||
|
|
||||||
|
class SoundReader:
    """Iterable over (key, (rate, array)) pairs for sound files listed in an scp.

    Each line of the scp file is "<key> <sound file path>"; every file is
    read with ``soundfile.read(..., dtype="int16")``.

    Args:
        rspecifier: must have the form "scp:<file.scp>".
        return_shape: yield ``array.shape`` instead of the array.

    Raises:
        ValueError: if rspecifier has no ":" or its prefix is not "scp".
    """

    def __init__(self, rspecifier, return_shape=False):
        if ":" not in rspecifier:
            raise ValueError('Give "rspecifier" such as "scp:some.scp: {}"'.
                             format(rspecifier))
        prefix, path = rspecifier.split(":", 1)
        self.ark_or_scp = prefix
        self.filepath = path
        if prefix != "scp":
            raise ValueError('Only supporting "scp" for sound file: {}'.format(
                prefix))
        self.return_shape = return_shape

    def __iter__(self):
        with open(self.filepath, "r", encoding="utf-8") as scp:
            for entry in scp:
                key, sound_file_path = entry.rstrip().split(None, 1)
                # Assume PCM16
                samples, rate = soundfile.read(sound_file_path, dtype="int16")
                # soundfile returns (ndarray, int); emit (int, ndarray),
                # i.e. soundfile style -> scipy style.
                payload = samples.shape if self.return_shape else samples
                yield key, (rate, payload)
|
@ -0,0 +1,70 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from distutils.util import strtobool as dist_strtobool
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
|
||||||
|
def strtobool(x):
    """Convert a truth-value string to a real ``bool``.

    distutils.util.strtobool returns an integer, which is confusing, so
    its result is converted to bool here.
    """
    as_int = dist_strtobool(x)
    return bool(as_int)
|
||||||
|
|
||||||
|
|
||||||
|
def get_commandline_args():
    """Return the current command line as a shell-escaped string.

    The result is ``sys.executable`` followed by every ``sys.argv`` entry.
    Single quotes inside an argument are escaped, and an argument that
    contains any shell-special character is wrapped in single quotes.
    """
    special_chars = (
        " ", ";", "&", "(", ")", "|", "^", "<", ">", "?",
        "*", "[", "]", "$", "`", '"', "\\", "!", "{", "}",
    )

    # Escape the extra characters for shell
    escaped_args = []
    for arg in sys.argv:
        escaped = arg.replace("'", "'\\''")
        if any(char in arg for char in special_chars):
            escaped = "'" + escaped + "'"
        escaped_args.append(escaped)

    return sys.executable + " " + " ".join(escaped_args)
|
||||||
|
|
||||||
|
|
||||||
|
def is_scipy_wav_style(value):
    """Return True if ``value`` looks like Tuple[int, numpy.ndarray].

    That means a two-element Sequence whose first item is an ``int`` and
    whose second item is a ``numpy.ndarray``.
    """
    if not isinstance(value, Sequence):
        return False
    if len(value) != 2:
        return False
    first, second = value[0], value[1]
    return isinstance(first, int) and isinstance(second, numpy.ndarray)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_scipy_wav_style(value):
    """Assert that ``value`` is Tuple[int, numpy.ndarray].

    On failure the AssertionError message reports the type of ``value``
    and, for a Sequence, the type of each of its elements.
    """

    def _describe():
        # Only called when the assertion fails.
        if not isinstance(value, Sequence):
            return type(value)
        element_types = ", ".join(str(type(v)) for v in value)
        return "{}[{}]".format(type(value), element_types)

    assert is_scipy_wav_style(
        value), "Must be Tuple[int, numpy.ndarray], but got {}".format(
            _describe())
|
@ -0,0 +1,293 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import kaldiio
|
||||||
|
import numpy
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
from deepspeech.io.reader import SoundHDF5File
|
||||||
|
from deepspeech.utils.cli_utils import assert_scipy_wav_style
|
||||||
|
|
||||||
|
|
||||||
|
def file_writer_helper(
        wspecifier: str,
        filetype: str="mat",
        write_num_frames: str=None,
        compress: bool=False,
        compression_method: int=2,
        pcm_format: str="wav", ):
    """Build a writer that stores matrices in kaldi style.

    Args:
        wspecifier: e.g. ark,scp:out.ark,out.scp
        filetype: "mat" is kaldi-matrix, "hdf5": HDF5;
            "sound.hdf5" and "sound" select the sound writers.
        write_num_frames: e.g. 'ark,t:num_frames.txt'
        compress: Compress or not
        compression_method: Specify compression level
        pcm_format: forwarded to the sound writers only.

    Raises:
        NotImplementedError: if filetype is not one of the supported names.

    Write in kaldi-matrix-ark with "kaldi-scp" file:

    >>> with file_writer_helper('ark,scp:out.ark,out.scp') as f:
    >>>     f['uttid'] = array

    This "scp" has the following format:

        uttidA out.ark:1234
        uttidB out.ark:2222

    where 1234 and 2222 point to the starting byte address of the matrix.
    (For detail, see official documentation of Kaldi)

    Write in HDF5 with "scp" file:

    >>> with file_writer_helper('ark,scp:out.h5,out.scp', 'hdf5') as f:
    >>>     f['uttid'] = array

    This "scp" file is created as:

        uttidA out.h5:uttidA
        uttidB out.h5:uttidB

    Unlike "kaldi-ark", HDF5 allows access to any key,
    so "scp" is not strictly required for random reading.
    Nevertheless we create "scp" for HDF5 because it is useful
    for some use cases, e.g. concatenation and splitting.

    """
    if filetype == "mat":
        return KaldiWriter(
            wspecifier,
            write_num_frames=write_num_frames,
            compress=compress,
            compression_method=compression_method, )
    if filetype == "hdf5":
        return HDF5Writer(
            wspecifier, write_num_frames=write_num_frames, compress=compress)
    if filetype == "sound.hdf5":
        return SoundHDF5Writer(
            wspecifier,
            write_num_frames=write_num_frames,
            pcm_format=pcm_format)
    if filetype == "sound":
        return SoundWriter(
            wspecifier,
            write_num_frames=write_num_frames,
            pcm_format=pcm_format)
    raise NotImplementedError(f"filetype={filetype}")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseWriter:
    """Common base of the writers: context-manager support and ``close``.

    Subclasses implement ``__setitem__`` and are expected to set the
    attributes ``writer``, ``writer_scp`` and ``writer_nframe``, which
    ``close`` releases.
    """

    def __setitem__(self, key, value):
        raise NotImplementedError

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the main writer and the optional scp / num-frames files.

        Errors raised while closing are ignored (best-effort cleanup).
        """
        try:
            self.writer.close()
        except Exception:
            pass

        # The auxiliary handles are optional and may be None.
        for attr in ("writer_scp", "writer_nframe"):
            handle = getattr(self, attr)
            if handle is None:
                continue
            try:
                handle.close()
            except Exception:
                pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_frames_writer(write_num_frames: str):
    """Open the text file named by a num-frames wspecifier for writing.

    Returns None when ``write_num_frames`` is None; otherwise the specifier
    must be of the form "ark,t:<file>" and the opened file is returned.

    Raises:
        ValueError: if the specifier has no ":" or its prefix is not "ark,t".

    Examples:
        >>> get_num_frames_writer('ark,t:num_frames.txt')
    """
    if write_num_frames is None:
        return None

    if ":" not in write_num_frames:
        raise ValueError('Must include ":", write_num_frames={}'.format(
            write_num_frames))

    nframes_type, nframes_file = write_num_frames.split(":", 1)
    if nframes_type == "ark,t":
        return open(nframes_file, "w", encoding="utf-8")

    raise ValueError("Only supporting text mode. "
                     "e.g. --write-num-frames=ark,t:foo.txt :"
                     "{}".format(nframes_type))
|
||||||
|
|
||||||
|
|
||||||
|
class KaldiWriter(BaseWriter):
    """Writer backed by ``kaldiio.WriteHelper``.

    Args:
        wspecifier: passed to kaldiio.WriteHelper unchanged.
        write_num_frames: optional "ark,t:<file>" specifier; when given,
            a "<key> <len(value)>" line is written for every item.
        compress: when true, ``compression_method`` is forwarded to
            kaldiio.WriteHelper.
        compression_method: only used when ``compress`` is true.
    """

    def __init__(self,
                 wspecifier,
                 write_num_frames=None,
                 compress=False,
                 compression_method=2):
        helper_kwargs = {}
        if compress:
            helper_kwargs["compression_method"] = compression_method
        self.writer = kaldiio.WriteHelper(wspecifier, **helper_kwargs)
        self.writer_scp = None
        self.writer_nframe = (get_num_frames_writer(write_num_frames)
                              if write_num_frames is not None else None)

    def __setitem__(self, key, value):
        self.writer[key] = value
        if self.writer_nframe is None:
            return
        self.writer_nframe.write(f"{key} {len(value)}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_wspecifier(wspecifier: str) -> Dict[str, str]:
    """Parse wspecifier to dict

    The part before the first ":" must be "ark", "scp,ark" or "ark,scp";
    the part after it is a comma-separated list with one path per type.

    Raises:
        ValueError: if ":" is missing, the type prefix is not allowed, or
            the number of paths does not match the number of types.

    Examples:
        >>> parse_wspecifier('ark,scp:out.ark,out.scp')
        {'ark': 'out.ark', 'scp': 'out.scp'}

    """
    ark_scp, sep, filepath = wspecifier.partition(":")
    if not sep:
        # Without this check the failure would surface as an opaque
        # tuple-unpacking error that does not mention the input.
        raise ValueError(
            'Must include ":", wspecifier={}'.format(wspecifier))
    if ark_scp not in ["ark", "scp,ark", "ark,scp"]:
        raise ValueError("{} is not allowed: {}".format(ark_scp, wspecifier))
    ark_scps = ark_scp.split(",")
    filepaths = filepath.split(",")
    if len(ark_scps) != len(filepaths):
        raise ValueError("Mismatch: {} and {}".format(ark_scp, filepath))
    return dict(zip(ark_scps, filepaths))
|
||||||
|
|
||||||
|
|
||||||
|
class HDF5Writer(BaseWriter):
    """Writer storing each item as a dataset of one HDF5 file.

    The "ark" entry of the wspecifier names the HDF5 file; an optional
    "scp" entry names a text file that receives "<key> <h5 file>:<key>"
    lines. With ``compress`` the datasets are created with
    ``compression="gzip"``.

    Examples:
        >>> with HDF5Writer('ark:out.h5', compress=True) as f:
        ...     f['key'] = array
    """

    def __init__(self, wspecifier, write_num_frames=None, compress=False):
        spec_dict = parse_wspecifier(wspecifier)
        self.filename = spec_dict["ark"]

        # Extra keyword arguments handed to create_dataset for every item.
        self.kwargs = {"compression": "gzip"} if compress else {}
        self.writer = h5py.File(spec_dict["ark"], "w")
        self.writer_scp = (open(spec_dict["scp"], "w", encoding="utf-8")
                           if "scp" in spec_dict else None)
        self.writer_nframe = (get_num_frames_writer(write_num_frames)
                              if write_num_frames is not None else None)

    def __setitem__(self, key, value):
        self.writer.create_dataset(key, data=value, **self.kwargs)

        scp, nframe = self.writer_scp, self.writer_nframe
        if scp is not None:
            scp.write(f"{key} {self.filename}:{key}\n")
        if nframe is not None:
            nframe.write(f"{key} {len(value)}\n")
|
||||||
|
|
||||||
|
|
||||||
|
class SoundHDF5Writer(BaseWriter):
    """Writer storing (rate, signal) items through SoundHDF5File.

    The "ark" entry of the wspecifier names the output file; an optional
    "scp" entry names a text file that receives "<key> <h5 file>:<key>"
    lines.

    Examples:
        >>> fs = 16000
        >>> with SoundHDF5Writer('ark:out.h5') as f:
        ...     f['key'] = fs, array
    """

    def __init__(self, wspecifier, write_num_frames=None, pcm_format="wav"):
        self.pcm_format = pcm_format
        spec_dict = parse_wspecifier(wspecifier)
        self.filename = spec_dict["ark"]
        self.writer = SoundHDF5File(
            spec_dict["ark"], "w", format=self.pcm_format)
        self.writer_scp = (open(spec_dict["scp"], "w", encoding="utf-8")
                           if "scp" in spec_dict else None)
        self.writer_nframe = (get_num_frames_writer(write_num_frames)
                              if write_num_frames is not None else None)

    def __setitem__(self, key, value):
        assert_scipy_wav_style(value)
        # Swap Tuple[int, ndarray] into Tuple[ndarray, int]
        # (scipy style -> soundfile style)
        value = (value[1], value[0])
        self.writer.create_dataset(key, data=value)

        scp, nframe = self.writer_scp, self.writer_nframe
        if scp is not None:
            scp.write(f"{key} {self.filename}:{key}\n")
        if nframe is not None:
            nframe.write(f"{key} {len(value[0])}\n")
|
||||||
|
|
||||||
|
|
||||||
|
class SoundWriter(BaseWriter):
    """Writer storing each (rate, signal) item as its own sound file.

    The "ark" entry of the wspecifier is used as an output directory, which
    is created if needed; every item is written to
    "<dirname>/<key>.<pcm_format>" as int16. An optional "scp" entry names
    a text file that receives "<key> <sound file path>" lines.

    Examples:
        >>> fs = 16000
        >>> with SoundWriter('ark,scp:outdir,out.scp') as f:
        ...     f['key'] = fs, array
    """

    def __init__(self, wspecifier, write_num_frames=None, pcm_format="wav"):
        self.pcm_format = pcm_format
        spec_dict = parse_wspecifier(wspecifier)
        # e.g. ark,scp:dirname,wav.scp
        # -> The wave files are found in dirname/*.wav
        self.dirname = spec_dict["ark"]
        Path(self.dirname).mkdir(parents=True, exist_ok=True)
        # There is no single underlying writer object for this class.
        self.writer = None

        self.writer_scp = (open(spec_dict["scp"], "w", encoding="utf-8")
                           if "scp" in spec_dict else None)
        self.writer_nframe = (get_num_frames_writer(write_num_frames)
                              if write_num_frames is not None else None)

    def __setitem__(self, key, value):
        assert_scipy_wav_style(value)
        rate, signal = value
        file_name = key + "." + self.pcm_format
        wavfile = Path(self.dirname) / file_name
        soundfile.write(wavfile, signal.astype(numpy.int16), rate)

        scp, nframe = self.writer_scp, self.writer_nframe
        if scp is not None:
            scp.write(f"{key} {wavfile}\n")
        if nframe is not None:
            nframe.write(f"{key} {len(signal)}\n")
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
After Width: | Height: | Size: 72 KiB |
@ -0,0 +1,4 @@
|
|||||||
|
dump
|
||||||
|
fbank
|
||||||
|
exp
|
||||||
|
data
|
@ -1,122 +0,0 @@
|
|||||||
# https://yaml.org/type/float.html
|
|
||||||
data:
|
|
||||||
train_manifest: data/manifest.train
|
|
||||||
dev_manifest: data/manifest.dev
|
|
||||||
test_manifest: data/manifest.test
|
|
||||||
min_input_len: 0.5
|
|
||||||
max_input_len: 20.0
|
|
||||||
min_output_len: 0.0
|
|
||||||
max_output_len: 400.0
|
|
||||||
min_output_input_ratio: 0.05
|
|
||||||
max_output_input_ratio: 10.0
|
|
||||||
|
|
||||||
collator:
|
|
||||||
vocab_filepath: data/vocab.txt
|
|
||||||
unit_type: 'spm'
|
|
||||||
spm_model_prefix: 'data/bpe_unigram_5000'
|
|
||||||
mean_std_filepath: ""
|
|
||||||
augmentation_config: conf/augmentation.json
|
|
||||||
batch_size: 16
|
|
||||||
raw_wav: True # use raw_wav or kaldi feature
|
|
||||||
spectrum_type: fbank #linear, mfcc, fbank
|
|
||||||
feat_dim: 80
|
|
||||||
delta_delta: False
|
|
||||||
dither: 1.0
|
|
||||||
target_sample_rate: 16000
|
|
||||||
max_freq: None
|
|
||||||
n_fft: None
|
|
||||||
stride_ms: 10.0
|
|
||||||
window_ms: 25.0
|
|
||||||
use_dB_normalization: True
|
|
||||||
target_dB: -20
|
|
||||||
random_seed: 0
|
|
||||||
keep_transcription_text: False
|
|
||||||
sortagrad: True
|
|
||||||
shuffle_method: batch_shuffle
|
|
||||||
num_workers: 2
|
|
||||||
|
|
||||||
|
|
||||||
# network architecture
|
|
||||||
model:
|
|
||||||
cmvn_file: "data/mean_std.json"
|
|
||||||
cmvn_file_type: "json"
|
|
||||||
# encoder related
|
|
||||||
encoder: conformer
|
|
||||||
encoder_conf:
|
|
||||||
output_size: 256 # dimension of attention
|
|
||||||
attention_heads: 4
|
|
||||||
linear_units: 2048 # the number of units of position-wise feed forward
|
|
||||||
num_blocks: 12 # the number of encoder blocks
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
attention_dropout_rate: 0.0
|
|
||||||
input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
|
|
||||||
normalize_before: True
|
|
||||||
use_cnn_module: True
|
|
||||||
cnn_module_kernel: 15
|
|
||||||
activation_type: 'swish'
|
|
||||||
pos_enc_layer_type: 'rel_pos'
|
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
|
||||||
causal: True
|
|
||||||
use_dynamic_chunk: true
|
|
||||||
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
|
|
||||||
use_dynamic_left_chunk: false
|
|
||||||
|
|
||||||
# decoder related
|
|
||||||
decoder: transformer
|
|
||||||
decoder_conf:
|
|
||||||
attention_heads: 4
|
|
||||||
linear_units: 2048
|
|
||||||
num_blocks: 6
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
self_attention_dropout_rate: 0.0
|
|
||||||
src_attention_dropout_rate: 0.0
|
|
||||||
|
|
||||||
# hybrid CTC/attention
|
|
||||||
model_conf:
|
|
||||||
ctc_weight: 0.3
|
|
||||||
ctc_dropoutrate: 0.0
|
|
||||||
ctc_grad_norm_type: null
|
|
||||||
lsm_weight: 0.1 # label smoothing option
|
|
||||||
length_normalized_loss: false
|
|
||||||
|
|
||||||
|
|
||||||
training:
|
|
||||||
n_epoch: 240
|
|
||||||
accum_grad: 8
|
|
||||||
global_grad_clip: 5.0
|
|
||||||
optim: adam
|
|
||||||
optim_conf:
|
|
||||||
lr: 0.001
|
|
||||||
weight_decay: 1e-06
|
|
||||||
scheduler: warmuplr # pytorch v1.1.0+ required
|
|
||||||
scheduler_conf:
|
|
||||||
warmup_steps: 25000
|
|
||||||
lr_decay: 1.0
|
|
||||||
log_interval: 100
|
|
||||||
checkpoint:
|
|
||||||
kbest_n: 50
|
|
||||||
latest_n: 5
|
|
||||||
|
|
||||||
|
|
||||||
decoding:
|
|
||||||
batch_size: 128
|
|
||||||
error_rate_type: wer
|
|
||||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
|
||||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
|
||||||
alpha: 2.5
|
|
||||||
beta: 0.3
|
|
||||||
beam_size: 10
|
|
||||||
cutoff_prob: 1.0
|
|
||||||
cutoff_top_n: 0
|
|
||||||
num_proc_bsearch: 8
|
|
||||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
|
||||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
|
||||||
# <0: for decoding, use full chunk.
|
|
||||||
# >0: for decoding, use fixed chunk size as set.
|
|
||||||
# 0: used for training, it's prohibited here.
|
|
||||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
|
||||||
simulate_streaming: true # simulate streaming inference. Defaults to False.
|
|
||||||
|
|
||||||
|
|
@ -1,115 +0,0 @@
|
|||||||
# https://yaml.org/type/float.html
|
|
||||||
data:
|
|
||||||
train_manifest: data/manifest.train
|
|
||||||
dev_manifest: data/manifest.dev
|
|
||||||
test_manifest: data/manifest.test
|
|
||||||
min_input_len: 0.5 # second
|
|
||||||
max_input_len: 20.0 # second
|
|
||||||
min_output_len: 0.0 # tokens
|
|
||||||
max_output_len: 400.0 # tokens
|
|
||||||
min_output_input_ratio: 0.05
|
|
||||||
max_output_input_ratio: 10.0
|
|
||||||
|
|
||||||
collator:
|
|
||||||
vocab_filepath: data/vocab.txt
|
|
||||||
unit_type: 'spm'
|
|
||||||
spm_model_prefix: 'data/bpe_unigram_5000'
|
|
||||||
mean_std_filepath: ""
|
|
||||||
augmentation_config: conf/augmentation.json
|
|
||||||
batch_size: 64
|
|
||||||
raw_wav: True # use raw_wav or kaldi feature
|
|
||||||
spectrum_type: fbank #linear, mfcc, fbank
|
|
||||||
feat_dim: 80
|
|
||||||
delta_delta: False
|
|
||||||
dither: 1.0
|
|
||||||
target_sample_rate: 16000
|
|
||||||
max_freq: None
|
|
||||||
n_fft: None
|
|
||||||
stride_ms: 10.0
|
|
||||||
window_ms: 25.0
|
|
||||||
use_dB_normalization: True
|
|
||||||
target_dB: -20
|
|
||||||
random_seed: 0
|
|
||||||
keep_transcription_text: False
|
|
||||||
sortagrad: True
|
|
||||||
shuffle_method: batch_shuffle
|
|
||||||
num_workers: 2
|
|
||||||
|
|
||||||
|
|
||||||
# network architecture
|
|
||||||
model:
|
|
||||||
cmvn_file: "data/mean_std.json"
|
|
||||||
cmvn_file_type: "json"
|
|
||||||
# encoder related
|
|
||||||
encoder: transformer
|
|
||||||
encoder_conf:
|
|
||||||
output_size: 256 # dimension of attention
|
|
||||||
attention_heads: 4
|
|
||||||
linear_units: 2048 # the number of units of position-wise feed forward
|
|
||||||
num_blocks: 12 # the number of encoder blocks
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
attention_dropout_rate: 0.0
|
|
||||||
input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
|
|
||||||
normalize_before: true
|
|
||||||
use_dynamic_chunk: true
|
|
||||||
use_dynamic_left_chunk: false
|
|
||||||
|
|
||||||
# decoder related
|
|
||||||
decoder: transformer
|
|
||||||
decoder_conf:
|
|
||||||
attention_heads: 4
|
|
||||||
linear_units: 2048
|
|
||||||
num_blocks: 6
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
self_attention_dropout_rate: 0.0
|
|
||||||
src_attention_dropout_rate: 0.0
|
|
||||||
|
|
||||||
# hybrid CTC/attention
|
|
||||||
model_conf:
|
|
||||||
ctc_weight: 0.3
|
|
||||||
ctc_dropoutrate: 0.0
|
|
||||||
ctc_grad_norm_type: null
|
|
||||||
lsm_weight: 0.1 # label smoothing option
|
|
||||||
length_normalized_loss: false
|
|
||||||
|
|
||||||
|
|
||||||
training:
|
|
||||||
n_epoch: 120
|
|
||||||
accum_grad: 1
|
|
||||||
global_grad_clip: 5.0
|
|
||||||
optim: adam
|
|
||||||
optim_conf:
|
|
||||||
lr: 0.001
|
|
||||||
weight_decay: 1e-06
|
|
||||||
scheduler: warmuplr # pytorch v1.1.0+ required
|
|
||||||
scheduler_conf:
|
|
||||||
warmup_steps: 25000
|
|
||||||
lr_decay: 1.0
|
|
||||||
log_interval: 100
|
|
||||||
checkpoint:
|
|
||||||
kbest_n: 50
|
|
||||||
latest_n: 5
|
|
||||||
|
|
||||||
|
|
||||||
decoding:
|
|
||||||
batch_size: 64
|
|
||||||
error_rate_type: wer
|
|
||||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
|
||||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
|
||||||
alpha: 2.5
|
|
||||||
beta: 0.3
|
|
||||||
beam_size: 10
|
|
||||||
cutoff_prob: 1.0
|
|
||||||
cutoff_top_n: 0
|
|
||||||
num_proc_bsearch: 8
|
|
||||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
|
||||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
|
||||||
# <0: for decoding, use full chunk.
|
|
||||||
# >0: for decoding, use fixed chunk size as set.
|
|
||||||
# 0: used for training, it's prohibited here.
|
|
||||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
|
||||||
simulate_streaming: true # simulate streaming inference. Defaults to False.
|
|
||||||
|
|
||||||
|
|
@ -1,118 +0,0 @@
|
|||||||
# https://yaml.org/type/float.html
|
|
||||||
data:
|
|
||||||
train_manifest: data/manifest.train
|
|
||||||
dev_manifest: data/manifest.dev
|
|
||||||
test_manifest: data/manifest.test-clean
|
|
||||||
min_input_len: 0.5 # seconds
|
|
||||||
max_input_len: 20.0 # seconds
|
|
||||||
min_output_len: 0.0 # tokens
|
|
||||||
max_output_len: 400.0 # tokens
|
|
||||||
min_output_input_ratio: 0.05
|
|
||||||
max_output_input_ratio: 10.0
|
|
||||||
|
|
||||||
collator:
|
|
||||||
vocab_filepath: data/vocab.txt
|
|
||||||
unit_type: 'spm'
|
|
||||||
spm_model_prefix: 'data/bpe_unigram_5000'
|
|
||||||
mean_std_filepath: ""
|
|
||||||
augmentation_config: conf/augmentation.json
|
|
||||||
batch_size: 16
|
|
||||||
raw_wav: True # use raw_wav or kaldi feature
|
|
||||||
spectrum_type: fbank #linear, mfcc, fbank
|
|
||||||
feat_dim: 80
|
|
||||||
delta_delta: False
|
|
||||||
dither: 1.0
|
|
||||||
target_sample_rate: 16000
|
|
||||||
max_freq: None
|
|
||||||
n_fft: None
|
|
||||||
stride_ms: 10.0
|
|
||||||
window_ms: 25.0
|
|
||||||
use_dB_normalization: True
|
|
||||||
target_dB: -20
|
|
||||||
random_seed: 0
|
|
||||||
keep_transcription_text: False
|
|
||||||
sortagrad: True
|
|
||||||
shuffle_method: batch_shuffle
|
|
||||||
num_workers: 2
|
|
||||||
|
|
||||||
|
|
||||||
# network architecture
|
|
||||||
model:
|
|
||||||
cmvn_file: "data/mean_std.json"
|
|
||||||
cmvn_file_type: "json"
|
|
||||||
# encoder related
|
|
||||||
encoder: conformer
|
|
||||||
encoder_conf:
|
|
||||||
output_size: 256 # dimension of attention
|
|
||||||
attention_heads: 4
|
|
||||||
linear_units: 2048 # the number of units of position-wise feed forward
|
|
||||||
num_blocks: 12 # the number of encoder blocks
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
attention_dropout_rate: 0.0
|
|
||||||
input_layer: conv2d # encoder input type; you can choose conv2d, conv2d6, or conv2d8
|
|
||||||
normalize_before: True
|
|
||||||
use_cnn_module: True
|
|
||||||
cnn_module_kernel: 15
|
|
||||||
activation_type: 'swish'
|
|
||||||
pos_enc_layer_type: 'rel_pos'
|
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
|
||||||
|
|
||||||
# decoder related
|
|
||||||
decoder: transformer
|
|
||||||
decoder_conf:
|
|
||||||
attention_heads: 4
|
|
||||||
linear_units: 2048
|
|
||||||
num_blocks: 6
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
self_attention_dropout_rate: 0.0
|
|
||||||
src_attention_dropout_rate: 0.0
|
|
||||||
|
|
||||||
# hybrid CTC/attention
|
|
||||||
model_conf:
|
|
||||||
ctc_weight: 0.3
|
|
||||||
ctc_dropoutrate: 0.0
|
|
||||||
ctc_grad_norm_type: null
|
|
||||||
lsm_weight: 0.1 # label smoothing option
|
|
||||||
length_normalized_loss: false
|
|
||||||
|
|
||||||
|
|
||||||
training:
|
|
||||||
n_epoch: 120
|
|
||||||
accum_grad: 8
|
|
||||||
global_grad_clip: 3.0
|
|
||||||
optim: adam
|
|
||||||
optim_conf:
|
|
||||||
lr: 0.004
|
|
||||||
weight_decay: 1e-06
|
|
||||||
scheduler: warmuplr # pytorch v1.1.0+ required
|
|
||||||
scheduler_conf:
|
|
||||||
warmup_steps: 25000
|
|
||||||
lr_decay: 1.0
|
|
||||||
log_interval: 100
|
|
||||||
checkpoint:
|
|
||||||
kbest_n: 50
|
|
||||||
latest_n: 5
|
|
||||||
|
|
||||||
|
|
||||||
decoding:
|
|
||||||
batch_size: 64
|
|
||||||
error_rate_type: wer
|
|
||||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
|
||||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
|
||||||
alpha: 2.5
|
|
||||||
beta: 0.3
|
|
||||||
beam_size: 10
|
|
||||||
cutoff_prob: 1.0
|
|
||||||
cutoff_top_n: 0
|
|
||||||
num_proc_bsearch: 8
|
|
||||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
|
||||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
|
||||||
# <0: for decoding, use full chunk.
|
|
||||||
# >0: for decoding, use fixed chunk size as set.
|
|
||||||
# 0: used for training, it's prohibited here.
|
|
||||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
|
||||||
simulate_streaming: False # simulate streaming inference. Defaults to False.
|
|
||||||
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
|||||||
batchsize: 0
|
batchsize: 0
|
||||||
beam-size: 60
|
beam-size: 60
|
||||||
ctc-weight: 0.0
|
ctc-weight: 0.4
|
||||||
lm-weight: 0.0
|
lm-weight: 0.6
|
||||||
maxlenratio: 0.0
|
maxlenratio: 0.0
|
||||||
minlenratio: 0.0
|
minlenratio: 0.0
|
||||||
penalty: 0.0
|
penalty: 0.0
|
||||||
|
@ -0,0 +1,7 @@
|
|||||||
|
batchsize: 0
|
||||||
|
beam-size: 60
|
||||||
|
ctc-weight: 0.0
|
||||||
|
lm-weight: 0.0
|
||||||
|
maxlenratio: 0.0
|
||||||
|
minlenratio: 0.0
|
||||||
|
penalty: 0.0
|
@ -1,7 +1,7 @@
|
|||||||
batchsize: 0
|
batchsize: 0
|
||||||
beam-size: 60
|
beam-size: 60
|
||||||
ctc-weight: 0.4
|
ctc-weight: 0.4
|
||||||
lm-weight: 0.6
|
lm-weight: 0.0
|
||||||
maxlenratio: 0.0
|
maxlenratio: 0.0
|
||||||
minlenratio: 0.0
|
minlenratio: 0.0
|
||||||
penalty: 0.0
|
penalty: 0.0
|
@ -0,0 +1,2 @@
|
|||||||
|
--sample-frequency=16000
|
||||||
|
--num-mel-bins=80
|
@ -0,0 +1,13 @@
|
|||||||
|
model_module: transformer
|
||||||
|
model:
|
||||||
|
n_vocab: 5002
|
||||||
|
pos_enc: null
|
||||||
|
embed_unit: 128
|
||||||
|
att_unit: 512
|
||||||
|
head: 8
|
||||||
|
unit: 2048
|
||||||
|
layer: 16
|
||||||
|
dropout_rate: 0.5
|
||||||
|
emb_dropout_rate: 0.0
|
||||||
|
att_dropout_rate: 0.0
|
||||||
|
tie_weights: False
|
@ -0,0 +1 @@
|
|||||||
|
--sample-frequency=16000
|
@ -0,0 +1,85 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# Copyright 2014 Vassil Panayotov
|
||||||
|
# 2014 Johns Hopkins University (author: Daniel Povey)
|
||||||
|
# Apache 2.0
|
||||||
|
|
||||||
|
if [ "$#" -ne 2 ]; then
|
||||||
|
echo "Usage: $0 <src-dir> <dst-dir>"
|
||||||
|
echo "e.g.: $0 /export/a15/vpanayotov/data/LibriSpeech/dev-clean data/dev-clean"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
src=$1
|
||||||
|
dst=$2
|
||||||
|
|
||||||
|
# all utterances are FLAC compressed
|
||||||
|
if ! which flac >&/dev/null; then
|
||||||
|
echo "Please install 'flac' on ALL worker nodes!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
spk_file=$src/../SPEAKERS.TXT
|
||||||
|
|
||||||
|
mkdir -p $dst || exit 1
|
||||||
|
|
||||||
|
[ ! -d $src ] && echo "$0: no such directory $src" && exit 1
|
||||||
|
[ ! -f $spk_file ] && echo "$0: expected file $spk_file to exist" && exit 1
|
||||||
|
|
||||||
|
|
||||||
|
wav_scp=$dst/wav.scp; [[ -f "$wav_scp" ]] && rm $wav_scp
|
||||||
|
trans=$dst/text; [[ -f "$trans" ]] && rm $trans
|
||||||
|
utt2spk=$dst/utt2spk; [[ -f "$utt2spk" ]] && rm $utt2spk
|
||||||
|
spk2gender=$dst/spk2gender; [[ -f $spk2gender ]] && rm $spk2gender
|
||||||
|
|
||||||
|
for reader_dir in $(find -L $src -mindepth 1 -maxdepth 1 -type d | sort); do
|
||||||
|
reader=$(basename $reader_dir)
|
||||||
|
if ! [ $reader -eq $reader ]; then # not integer.
|
||||||
|
echo "$0: unexpected subdirectory name $reader"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
reader_gender=$(egrep "^$reader[ ]+\|" $spk_file | awk -F'|' '{gsub(/[ ]+/, ""); print tolower($2)}')
|
||||||
|
if [ "$reader_gender" != 'm' ] && [ "$reader_gender" != 'f' ]; then
|
||||||
|
echo "Unexpected gender: '$reader_gender'"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
for chapter_dir in $(find -L $reader_dir/ -mindepth 1 -maxdepth 1 -type d | sort); do
|
||||||
|
chapter=$(basename $chapter_dir)
|
||||||
|
if ! [ "$chapter" -eq "$chapter" ]; then
|
||||||
|
echo "$0: unexpected chapter-subdirectory name $chapter"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
find -L $chapter_dir/ -iname "*.flac" | sort | xargs -I% basename % .flac | \
|
||||||
|
awk -v "dir=$chapter_dir" '{printf "%s flac -c -d -s %s/%s.flac |\n", $0, dir, $0}' >>$wav_scp|| exit 1
|
||||||
|
|
||||||
|
chapter_trans=$chapter_dir/${reader}-${chapter}.trans.txt
|
||||||
|
[ ! -f $chapter_trans ] && echo "$0: expected file $chapter_trans to exist" && exit 1
|
||||||
|
cat $chapter_trans >>$trans
|
||||||
|
|
||||||
|
# NOTE: For now we are using per-chapter utt2spk. That is each chapter is considered
|
||||||
|
# to be a different speaker. This is done for simplicity and because we want
|
||||||
|
# e.g. the CMVN to be calculated per-chapter
|
||||||
|
awk -v "reader=$reader" -v "chapter=$chapter" '{printf "%s %s-%s\n", $1, reader, chapter}' \
|
||||||
|
<$chapter_trans >>$utt2spk || exit 1
|
||||||
|
|
||||||
|
# reader -> gender map (again using per-chapter granularity)
|
||||||
|
echo "${reader}-${chapter} $reader_gender" >>$spk2gender
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
spk2utt=$dst/spk2utt
|
||||||
|
utils/utt2spk_to_spk2utt.pl <$utt2spk >$spk2utt || exit 1
|
||||||
|
|
||||||
|
ntrans=$(wc -l <$trans)
|
||||||
|
nutt2spk=$(wc -l <$utt2spk)
|
||||||
|
! [ "$ntrans" -eq "$nutt2spk" ] && \
|
||||||
|
echo "Inconsistent #transcripts($ntrans) and #utt2spk($nutt2spk)" && exit 1
|
||||||
|
|
||||||
|
utils/validate_data_dir.sh --no-feats $dst || exit 1
|
||||||
|
|
||||||
|
echo "$0: successfully prepared data in $dst"
|
||||||
|
|
||||||
|
exit 0
|
@ -0,0 +1 @@
|
|||||||
|
../../../tools/kaldi/egs/wsj/s5/steps/
|
@ -1 +1 @@
|
|||||||
../../../utils/
|
../../../tools/kaldi/egs/wsj/s5/utils
|
@ -0,0 +1,7 @@
|
|||||||
|
.ipynb_checkpoints/**
|
||||||
|
*.ipynb
|
||||||
|
nohup.out
|
||||||
|
__pycache__/
|
||||||
|
*.wav
|
||||||
|
*.m4a
|
||||||
|
obsolete/**
|
@ -0,0 +1,45 @@
|
|||||||
|
repos:
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: yapf
|
||||||
|
name: yapf
|
||||||
|
entry: yapf
|
||||||
|
language: system
|
||||||
|
args: [-i, --style .style.yapf]
|
||||||
|
files: \.py$
|
||||||
|
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: a11d9314b22d8f8c7556443875b731ef05965464
|
||||||
|
hooks:
|
||||||
|
- id: check-merge-conflict
|
||||||
|
- id: check-symlinks
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: detect-private-key
|
||||||
|
- id: check-symlinks
|
||||||
|
- id: check-added-large-files
|
||||||
|
|
||||||
|
- repo: https://github.com/pycqa/isort
|
||||||
|
rev: 5.8.0
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
name: isort (python)
|
||||||
|
- id: isort
|
||||||
|
name: isort (cython)
|
||||||
|
types: [cython]
|
||||||
|
- id: isort
|
||||||
|
name: isort (pyi)
|
||||||
|
types: [pyi]
|
||||||
|
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: flake8
|
||||||
|
name: flake8
|
||||||
|
entry: flake8
|
||||||
|
language: system
|
||||||
|
args:
|
||||||
|
- --count
|
||||||
|
- --select=E9,F63,F7,F82
|
||||||
|
- --show-source
|
||||||
|
- --statistics
|
||||||
|
files: \.py$
|
@ -0,0 +1,3 @@
|
|||||||
|
[style]
|
||||||
|
based_on_style = pep8
|
||||||
|
column_limit = 80
|
@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
@ -0,0 +1,37 @@
|
|||||||
|
# PaddleAudio: The audio library for PaddlePaddle
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
PaddleAudio is the audio toolkit to speed up your audio research and development loop in PaddlePaddle. It currently provides a collection of audio datasets, feature-extraction functions, audio transforms, and state-of-the-art pre-trained models for sound tagging/classification and anomaly sound detection. More models and features are on the roadmap.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Features
|
||||||
|
- Spectrogram and related features are compatible with librosa.
|
||||||
|
- State-of-the-art models for sound tagging on Audioset and sound classification on ESC-50, with more to come.
|
||||||
|
- Ready-to-use audio embeddings with a single line of code, including sound embeddings, with more on the roadmap.
|
||||||
|
- Data loading support for common open-source audio datasets in multiple languages, including English and Mandarin.
|
||||||
|
|
||||||
|
|
||||||
|
## Install
|
||||||
|
```
|
||||||
|
git clone https://github.com/PaddlePaddle/models
|
||||||
|
cd models/PaddleAudio
|
||||||
|
pip install .
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
### Audio loading and feature extraction
|
||||||
|
```
|
||||||
|
import paddleaudio as pa
|
||||||
|
s,r = pa.load(f)
|
||||||
|
mel_spect = pa.melspectrogram(s,sr=r)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
We provide a set of examples to help you get started with PaddleAudio quickly.
|
||||||
|
- [PANNs: acoustic scene and events analysis using pre-trained models](./examples/panns)
|
||||||
|
- [Environmental Sound classification on ESC-50 dataset](./examples/sound_classification)
|
||||||
|
- [Training an audio-tagging network on Audioset](./examples/audioset_training)
|
||||||
|
|
||||||
|
Please refer to [example directory](./examples) for more details.
|
@ -0,0 +1,527 @@
|
|||||||
|
Speech
|
||||||
|
Male speech, man speaking
|
||||||
|
Female speech, woman speaking
|
||||||
|
Child speech, kid speaking
|
||||||
|
Conversation
|
||||||
|
Narration, monologue
|
||||||
|
Babbling
|
||||||
|
Speech synthesizer
|
||||||
|
Shout
|
||||||
|
Bellow
|
||||||
|
Whoop
|
||||||
|
Yell
|
||||||
|
Battle cry
|
||||||
|
Children shouting
|
||||||
|
Screaming
|
||||||
|
Whispering
|
||||||
|
Laughter
|
||||||
|
Baby laughter
|
||||||
|
Giggle
|
||||||
|
Snicker
|
||||||
|
Belly laugh
|
||||||
|
Chuckle, chortle
|
||||||
|
Crying, sobbing
|
||||||
|
Baby cry, infant cry
|
||||||
|
Whimper
|
||||||
|
Wail, moan
|
||||||
|
Sigh
|
||||||
|
Singing
|
||||||
|
Choir
|
||||||
|
Yodeling
|
||||||
|
Chant
|
||||||
|
Mantra
|
||||||
|
Male singing
|
||||||
|
Female singing
|
||||||
|
Child singing
|
||||||
|
Synthetic singing
|
||||||
|
Rapping
|
||||||
|
Humming
|
||||||
|
Groan
|
||||||
|
Grunt
|
||||||
|
Whistling
|
||||||
|
Breathing
|
||||||
|
Wheeze
|
||||||
|
Snoring
|
||||||
|
Gasp
|
||||||
|
Pant
|
||||||
|
Snort
|
||||||
|
Cough
|
||||||
|
Throat clearing
|
||||||
|
Sneeze
|
||||||
|
Sniff
|
||||||
|
Run
|
||||||
|
Shuffle
|
||||||
|
Walk, footsteps
|
||||||
|
Chewing, mastication
|
||||||
|
Biting
|
||||||
|
Gargling
|
||||||
|
Stomach rumble
|
||||||
|
Burping, eructation
|
||||||
|
Hiccup
|
||||||
|
Fart
|
||||||
|
Hands
|
||||||
|
Finger snapping
|
||||||
|
Clapping
|
||||||
|
Heart sounds, heartbeat
|
||||||
|
Heart murmur
|
||||||
|
Cheering
|
||||||
|
Applause
|
||||||
|
Chatter
|
||||||
|
Crowd
|
||||||
|
Hubbub, speech noise, speech babble
|
||||||
|
Children playing
|
||||||
|
Animal
|
||||||
|
Domestic animals, pets
|
||||||
|
Dog
|
||||||
|
Bark
|
||||||
|
Yip
|
||||||
|
Howl
|
||||||
|
Bow-wow
|
||||||
|
Growling
|
||||||
|
Whimper (dog)
|
||||||
|
Cat
|
||||||
|
Purr
|
||||||
|
Meow
|
||||||
|
Hiss
|
||||||
|
Caterwaul
|
||||||
|
Livestock, farm animals, working animals
|
||||||
|
Horse
|
||||||
|
Clip-clop
|
||||||
|
Neigh, whinny
|
||||||
|
Cattle, bovinae
|
||||||
|
Moo
|
||||||
|
Cowbell
|
||||||
|
Pig
|
||||||
|
Oink
|
||||||
|
Goat
|
||||||
|
Bleat
|
||||||
|
Sheep
|
||||||
|
Fowl
|
||||||
|
Chicken, rooster
|
||||||
|
Cluck
|
||||||
|
Crowing, cock-a-doodle-doo
|
||||||
|
Turkey
|
||||||
|
Gobble
|
||||||
|
Duck
|
||||||
|
Quack
|
||||||
|
Goose
|
||||||
|
Honk
|
||||||
|
Wild animals
|
||||||
|
Roaring cats (lions, tigers)
|
||||||
|
Roar
|
||||||
|
Bird
|
||||||
|
Bird vocalization, bird call, bird song
|
||||||
|
Chirp, tweet
|
||||||
|
Squawk
|
||||||
|
Pigeon, dove
|
||||||
|
Coo
|
||||||
|
Crow
|
||||||
|
Caw
|
||||||
|
Owl
|
||||||
|
Hoot
|
||||||
|
Bird flight, flapping wings
|
||||||
|
Canidae, dogs, wolves
|
||||||
|
Rodents, rats, mice
|
||||||
|
Mouse
|
||||||
|
Patter
|
||||||
|
Insect
|
||||||
|
Cricket
|
||||||
|
Mosquito
|
||||||
|
Fly, housefly
|
||||||
|
Buzz
|
||||||
|
Bee, wasp, etc.
|
||||||
|
Frog
|
||||||
|
Croak
|
||||||
|
Snake
|
||||||
|
Rattle
|
||||||
|
Whale vocalization
|
||||||
|
Music
|
||||||
|
Musical instrument
|
||||||
|
Plucked string instrument
|
||||||
|
Guitar
|
||||||
|
Electric guitar
|
||||||
|
Bass guitar
|
||||||
|
Acoustic guitar
|
||||||
|
Steel guitar, slide guitar
|
||||||
|
Tapping (guitar technique)
|
||||||
|
Strum
|
||||||
|
Banjo
|
||||||
|
Sitar
|
||||||
|
Mandolin
|
||||||
|
Zither
|
||||||
|
Ukulele
|
||||||
|
Keyboard (musical)
|
||||||
|
Piano
|
||||||
|
Electric piano
|
||||||
|
Organ
|
||||||
|
Electronic organ
|
||||||
|
Hammond organ
|
||||||
|
Synthesizer
|
||||||
|
Sampler
|
||||||
|
Harpsichord
|
||||||
|
Percussion
|
||||||
|
Drum kit
|
||||||
|
Drum machine
|
||||||
|
Drum
|
||||||
|
Snare drum
|
||||||
|
Rimshot
|
||||||
|
Drum roll
|
||||||
|
Bass drum
|
||||||
|
Timpani
|
||||||
|
Tabla
|
||||||
|
Cymbal
|
||||||
|
Hi-hat
|
||||||
|
Wood block
|
||||||
|
Tambourine
|
||||||
|
Rattle (instrument)
|
||||||
|
Maraca
|
||||||
|
Gong
|
||||||
|
Tubular bells
|
||||||
|
Mallet percussion
|
||||||
|
Marimba, xylophone
|
||||||
|
Glockenspiel
|
||||||
|
Vibraphone
|
||||||
|
Steelpan
|
||||||
|
Orchestra
|
||||||
|
Brass instrument
|
||||||
|
French horn
|
||||||
|
Trumpet
|
||||||
|
Trombone
|
||||||
|
Bowed string instrument
|
||||||
|
String section
|
||||||
|
Violin, fiddle
|
||||||
|
Pizzicato
|
||||||
|
Cello
|
||||||
|
Double bass
|
||||||
|
Wind instrument, woodwind instrument
|
||||||
|
Flute
|
||||||
|
Saxophone
|
||||||
|
Clarinet
|
||||||
|
Harp
|
||||||
|
Bell
|
||||||
|
Church bell
|
||||||
|
Jingle bell
|
||||||
|
Bicycle bell
|
||||||
|
Tuning fork
|
||||||
|
Chime
|
||||||
|
Wind chime
|
||||||
|
Change ringing (campanology)
|
||||||
|
Harmonica
|
||||||
|
Accordion
|
||||||
|
Bagpipes
|
||||||
|
Didgeridoo
|
||||||
|
Shofar
|
||||||
|
Theremin
|
||||||
|
Singing bowl
|
||||||
|
Scratching (performance technique)
|
||||||
|
Pop music
|
||||||
|
Hip hop music
|
||||||
|
Beatboxing
|
||||||
|
Rock music
|
||||||
|
Heavy metal
|
||||||
|
Punk rock
|
||||||
|
Grunge
|
||||||
|
Progressive rock
|
||||||
|
Rock and roll
|
||||||
|
Psychedelic rock
|
||||||
|
Rhythm and blues
|
||||||
|
Soul music
|
||||||
|
Reggae
|
||||||
|
Country
|
||||||
|
Swing music
|
||||||
|
Bluegrass
|
||||||
|
Funk
|
||||||
|
Folk music
|
||||||
|
Middle Eastern music
|
||||||
|
Jazz
|
||||||
|
Disco
|
||||||
|
Classical music
|
||||||
|
Opera
|
||||||
|
Electronic music
|
||||||
|
House music
|
||||||
|
Techno
|
||||||
|
Dubstep
|
||||||
|
Drum and bass
|
||||||
|
Electronica
|
||||||
|
Electronic dance music
|
||||||
|
Ambient music
|
||||||
|
Trance music
|
||||||
|
Music of Latin America
|
||||||
|
Salsa music
|
||||||
|
Flamenco
|
||||||
|
Blues
|
||||||
|
Music for children
|
||||||
|
New-age music
|
||||||
|
Vocal music
|
||||||
|
A capella
|
||||||
|
Music of Africa
|
||||||
|
Afrobeat
|
||||||
|
Christian music
|
||||||
|
Gospel music
|
||||||
|
Music of Asia
|
||||||
|
Carnatic music
|
||||||
|
Music of Bollywood
|
||||||
|
Ska
|
||||||
|
Traditional music
|
||||||
|
Independent music
|
||||||
|
Song
|
||||||
|
Background music
|
||||||
|
Theme music
|
||||||
|
Jingle (music)
|
||||||
|
Soundtrack music
|
||||||
|
Lullaby
|
||||||
|
Video game music
|
||||||
|
Christmas music
|
||||||
|
Dance music
|
||||||
|
Wedding music
|
||||||
|
Happy music
|
||||||
|
Funny music
|
||||||
|
Sad music
|
||||||
|
Tender music
|
||||||
|
Exciting music
|
||||||
|
Angry music
|
||||||
|
Scary music
|
||||||
|
Wind
|
||||||
|
Rustling leaves
|
||||||
|
Wind noise (microphone)
|
||||||
|
Thunderstorm
|
||||||
|
Thunder
|
||||||
|
Water
|
||||||
|
Rain
|
||||||
|
Raindrop
|
||||||
|
Rain on surface
|
||||||
|
Stream
|
||||||
|
Waterfall
|
||||||
|
Ocean
|
||||||
|
Waves, surf
|
||||||
|
Steam
|
||||||
|
Gurgling
|
||||||
|
Fire
|
||||||
|
Crackle
|
||||||
|
Vehicle
|
||||||
|
Boat, Water vehicle
|
||||||
|
Sailboat, sailing ship
|
||||||
|
Rowboat, canoe, kayak
|
||||||
|
Motorboat, speedboat
|
||||||
|
Ship
|
||||||
|
Motor vehicle (road)
|
||||||
|
Car
|
||||||
|
Vehicle horn, car horn, honking
|
||||||
|
Toot
|
||||||
|
Car alarm
|
||||||
|
Power windows, electric windows
|
||||||
|
Skidding
|
||||||
|
Tire squeal
|
||||||
|
Car passing by
|
||||||
|
Race car, auto racing
|
||||||
|
Truck
|
||||||
|
Air brake
|
||||||
|
Air horn, truck horn
|
||||||
|
Reversing beeps
|
||||||
|
Ice cream truck, ice cream van
|
||||||
|
Bus
|
||||||
|
Emergency vehicle
|
||||||
|
Police car (siren)
|
||||||
|
Ambulance (siren)
|
||||||
|
Fire engine, fire truck (siren)
|
||||||
|
Motorcycle
|
||||||
|
Traffic noise, roadway noise
|
||||||
|
Rail transport
|
||||||
|
Train
|
||||||
|
Train whistle
|
||||||
|
Train horn
|
||||||
|
Railroad car, train wagon
|
||||||
|
Train wheels squealing
|
||||||
|
Subway, metro, underground
|
||||||
|
Aircraft
|
||||||
|
Aircraft engine
|
||||||
|
Jet engine
|
||||||
|
Propeller, airscrew
|
||||||
|
Helicopter
|
||||||
|
Fixed-wing aircraft, airplane
|
||||||
|
Bicycle
|
||||||
|
Skateboard
|
||||||
|
Engine
|
||||||
|
Light engine (high frequency)
|
||||||
|
Dental drill, dentist's drill
|
||||||
|
Lawn mower
|
||||||
|
Chainsaw
|
||||||
|
Medium engine (mid frequency)
|
||||||
|
Heavy engine (low frequency)
|
||||||
|
Engine knocking
|
||||||
|
Engine starting
|
||||||
|
Idling
|
||||||
|
Accelerating, revving, vroom
|
||||||
|
Door
|
||||||
|
Doorbell
|
||||||
|
Ding-dong
|
||||||
|
Sliding door
|
||||||
|
Slam
|
||||||
|
Knock
|
||||||
|
Tap
|
||||||
|
Squeak
|
||||||
|
Cupboard open or close
|
||||||
|
Drawer open or close
|
||||||
|
Dishes, pots, and pans
|
||||||
|
Cutlery, silverware
|
||||||
|
Chopping (food)
|
||||||
|
Frying (food)
|
||||||
|
Microwave oven
|
||||||
|
Blender
|
||||||
|
Water tap, faucet
|
||||||
|
Sink (filling or washing)
|
||||||
|
Bathtub (filling or washing)
|
||||||
|
Hair dryer
|
||||||
|
Toilet flush
|
||||||
|
Toothbrush
|
||||||
|
Electric toothbrush
|
||||||
|
Vacuum cleaner
|
||||||
|
Zipper (clothing)
|
||||||
|
Keys jangling
|
||||||
|
Coin (dropping)
|
||||||
|
Scissors
|
||||||
|
Electric shaver, electric razor
|
||||||
|
Shuffling cards
|
||||||
|
Typing
|
||||||
|
Typewriter
|
||||||
|
Computer keyboard
|
||||||
|
Writing
|
||||||
|
Alarm
|
||||||
|
Telephone
|
||||||
|
Telephone bell ringing
|
||||||
|
Ringtone
|
||||||
|
Telephone dialing, DTMF
|
||||||
|
Dial tone
|
||||||
|
Busy signal
|
||||||
|
Alarm clock
|
||||||
|
Siren
|
||||||
|
Civil defense siren
|
||||||
|
Buzzer
|
||||||
|
Smoke detector, smoke alarm
|
||||||
|
Fire alarm
|
||||||
|
Foghorn
|
||||||
|
Whistle
|
||||||
|
Steam whistle
|
||||||
|
Mechanisms
|
||||||
|
Ratchet, pawl
|
||||||
|
Clock
|
||||||
|
Tick
|
||||||
|
Tick-tock
|
||||||
|
Gears
|
||||||
|
Pulleys
|
||||||
|
Sewing machine
|
||||||
|
Mechanical fan
|
||||||
|
Air conditioning
|
||||||
|
Cash register
|
||||||
|
Printer
|
||||||
|
Camera
|
||||||
|
Single-lens reflex camera
|
||||||
|
Tools
|
||||||
|
Hammer
|
||||||
|
Jackhammer
|
||||||
|
Sawing
|
||||||
|
Filing (rasp)
|
||||||
|
Sanding
|
||||||
|
Power tool
|
||||||
|
Drill
|
||||||
|
Explosion
|
||||||
|
Gunshot, gunfire
|
||||||
|
Machine gun
|
||||||
|
Fusillade
|
||||||
|
Artillery fire
|
||||||
|
Cap gun
|
||||||
|
Fireworks
|
||||||
|
Firecracker
|
||||||
|
Burst, pop
|
||||||
|
Eruption
|
||||||
|
Boom
|
||||||
|
Wood
|
||||||
|
Chop
|
||||||
|
Splinter
|
||||||
|
Crack
|
||||||
|
Glass
|
||||||
|
Chink, clink
|
||||||
|
Shatter
|
||||||
|
Liquid
|
||||||
|
Splash, splatter
|
||||||
|
Slosh
|
||||||
|
Squish
|
||||||
|
Drip
|
||||||
|
Pour
|
||||||
|
Trickle, dribble
|
||||||
|
Gush
|
||||||
|
Fill (with liquid)
|
||||||
|
Spray
|
||||||
|
Pump (liquid)
|
||||||
|
Stir
|
||||||
|
Boiling
|
||||||
|
Sonar
|
||||||
|
Arrow
|
||||||
|
Whoosh, swoosh, swish
|
||||||
|
Thump, thud
|
||||||
|
Thunk
|
||||||
|
Electronic tuner
|
||||||
|
Effects unit
|
||||||
|
Chorus effect
|
||||||
|
Basketball bounce
|
||||||
|
Bang
|
||||||
|
Slap, smack
|
||||||
|
Whack, thwack
|
||||||
|
Smash, crash
|
||||||
|
Breaking
|
||||||
|
Bouncing
|
||||||
|
Whip
|
||||||
|
Flap
|
||||||
|
Scratch
|
||||||
|
Scrape
|
||||||
|
Rub
|
||||||
|
Roll
|
||||||
|
Crushing
|
||||||
|
Crumpling, crinkling
|
||||||
|
Tearing
|
||||||
|
Beep, bleep
|
||||||
|
Ping
|
||||||
|
Ding
|
||||||
|
Clang
|
||||||
|
Squeal
|
||||||
|
Creak
|
||||||
|
Rustle
|
||||||
|
Whir
|
||||||
|
Clatter
|
||||||
|
Sizzle
|
||||||
|
Clicking
|
||||||
|
Clickety-clack
|
||||||
|
Rumble
|
||||||
|
Plop
|
||||||
|
Jingle, tinkle
|
||||||
|
Hum
|
||||||
|
Zing
|
||||||
|
Boing
|
||||||
|
Crunch
|
||||||
|
Silence
|
||||||
|
Sine wave
|
||||||
|
Harmonic
|
||||||
|
Chirp tone
|
||||||
|
Sound effect
|
||||||
|
Pulse
|
||||||
|
Inside, small room
|
||||||
|
Inside, large room or hall
|
||||||
|
Inside, public space
|
||||||
|
Outside, urban or manmade
|
||||||
|
Outside, rural or natural
|
||||||
|
Reverberation
|
||||||
|
Echo
|
||||||
|
Noise
|
||||||
|
Environmental noise
|
||||||
|
Static
|
||||||
|
Mains hum
|
||||||
|
Distortion
|
||||||
|
Sidetone
|
||||||
|
Cacophony
|
||||||
|
White noise
|
||||||
|
Pink noise
|
||||||
|
Throbbing
|
||||||
|
Vibration
|
||||||
|
Television
|
||||||
|
Radio
|
||||||
|
Field recording
|
@ -0,0 +1,112 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from paddleaudio.backends import load as load_audio
|
||||||
|
from paddleaudio.features import melspectrogram
|
||||||
|
from paddleaudio.models.panns import cnn14
|
||||||
|
from paddleaudio.utils import logger
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
|
# Command-line options of the AudioSet tagging script.
# NOTE: the first positional parameter of ArgumentParser is `prog`, so the
# module docstring has to be passed explicitly as `description`.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--device', choices=['cpu', 'gpu'], default='gpu', help='Select which device to predict, defaults to gpu.')
parser.add_argument('--wav', type=str, required=True, help='Audio file to infer.')
parser.add_argument('--sample_duration', type=float, default=2.0, help='Duration(in seconds) of tagging samples to predict.')
parser.add_argument('--hop_duration', type=float, default=0.3, help='Duration(in seconds) between two samples.')
parser.add_argument('--output_dir', type=str, default='./output_dir', help='Directory to save tagging result.')
args = parser.parse_args()
|
||||||
|
# yapf: enable
|
||||||
|
|
||||||
|
|
||||||
|
def split(waveform: np.ndarray, win_size: int, hop_size: int):
    """
    Split a waveform into N fixed-length, possibly overlapping segments.

    A new segment starts every `hop_size` samples; the trailing segments are
    right-padded with zeros so that every segment holds `win_size` samples.
    N is therefore decided by len(waveform) and hop_size.

    Args:
        waveform: Samples to split (padding assumes a 1-D array).
        win_size: Number of samples per segment, must be positive.
        hop_size: Number of samples between two segment starts, must be positive.

    Returns:
        A tuple (time, data): `time` lists each segment's start offset as a
        fraction of the waveform length, `data` lists the segments.

    Raises:
        TypeError: If `waveform` is not a numpy array.
        ValueError: If `win_size` or `hop_size` is not positive.
    """
    if not isinstance(waveform, np.ndarray):
        raise TypeError(
            f'waveform must be a numpy.ndarray, got {type(waveform).__name__}')
    # A zero hop makes range() fail with an obscure message and a negative hop
    # silently produces no segments at all, so reject both explicitly.
    if win_size <= 0 or hop_size <= 0:
        raise ValueError(
            f'win_size and hop_size must be positive, got {win_size} and {hop_size}')

    total = len(waveform)
    time = []
    data = []
    for start in range(0, total, hop_size):
        segment = waveform[start:start + win_size]
        missing = win_size - len(segment)
        if missing > 0:
            # Zero-pad the tail so that all segments share the same length.
            segment = np.pad(segment, (0, missing))
        data.append(segment)
        time.append(start / total)
    return time, data
|
||||||
|
|
||||||
|
|
||||||
|
def batchify(data: List[List[float]],
             sample_rate: int,
             batch_size: int,
             **kwargs):
    """
    Turn waveforms into mel-spectrogram features and yield them in batches.

    Features of all waveforms are extracted first (extra keyword arguments
    are forwarded to `melspectrogram`); afterwards lists of `batch_size`
    transposed feature arrays are yielded, the last one possibly shorter.
    """
    examples = [
        melspectrogram(waveform, sample_rate, **kwargs).transpose()
        for waveform in data
    ]

    # Separate the examples into batches.
    pending = []
    for example in examples:
        pending.append(example)
        if len(pending) != batch_size:
            continue
        yield pending
        pending = []
    # Hand out whatever is left over as a final, smaller batch.
    if pending:
        yield pending
|
||||||
|
|
||||||
|
|
||||||
|
def predict(model, data: List[List[float]], sample_rate: int,
            batch_size: int=1):
    """
    Use pretrained model to make predictions.

    Args:
        model: Network that is called on a batch of features.
        data: Waveform segments, forwarded to `batchify`.
        sample_rate: Sample rate of the segments.
        batch_size: Number of segments per forward pass.

    Returns:
        The scores of all batches concatenated along the first axis, or
        None when `data` yields no batch.
    """
    model.eval()
    scores = []
    # Inference only: no gradients are needed for the forward passes.
    with paddle.no_grad():
        for batch in batchify(data, sample_rate, batch_size):
            # (batch_size, num_frames, num_melbins) -> (batch_size, 1, num_frames, num_melbins)
            feats = paddle.to_tensor(batch).unsqueeze(1)
            scores.append(model(feats).numpy())

    if not scores:
        return None
    # Concatenate once instead of re-copying the growing result per batch.
    return np.concatenate(scores)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Tag an audio file: split it into windows, score every window with the
    # pretrained cnn14 model and store the scores next to their time offsets.
    paddle.set_device(args.device)
    model = cnn14(pretrained=True, extract_embedding=False)
    waveform, sr = load_audio(args.wav, sr=None)
    time, data = split(waveform,
                       int(args.sample_duration * sr),
                       int(args.hop_duration * sr))
    results = predict(model, data, sr, batch_size=8)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.output_dir, exist_ok=True)
    # Keep the offsets returned by split(): they hold exactly one entry per
    # segment, while a float-step np.arange may come out one element off.
    time = np.asarray(time)
    output_file = os.path.join(args.output_dir, f'audioset_tagging_sr_{sr}.npz')
    np.savez(output_file, time=time, scores=results)
    logger.info(f'Saved tagging results to {output_file}')
|
@ -0,0 +1,84 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import ast
|
||||||
|
import os
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from paddleaudio.utils import logger
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
|
# Command-line options of the tagging result parser.
# NOTE: the first positional parameter of ArgumentParser is `prog`, so the
# module docstring has to be passed explicitly as `description`.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--tagging_file', type=str, required=True, help='Path of the .npz tagging result file to parse.')
parser.add_argument('--top_k', type=int, default=10, help='Get top k predicted results of audioset labels.')
parser.add_argument('--smooth', type=ast.literal_eval, default=True, help='Set "True" to apply posterior smoothing.')
parser.add_argument('--smooth_size', type=int, default=5, help='Window size of posterior smoothing.')
parser.add_argument('--label_file', type=str, default='./assets/audioset_labels.txt', help='File of audioset labels.')
parser.add_argument('--output_dir', type=str, default='./output_dir', help='Directory to save tagging labels.')
args = parser.parse_args()
|
||||||
|
# yapf: enable
|
||||||
|
|
||||||
|
|
||||||
|
def smooth(results: np.ndarray, win_size: int):
    """
    Execute posterior smoothing in-place.

    Every row i is replaced by the mean of rows max(0, i - win_size + 1) .. i.
    Rows are processed from last to first so that each mean is computed from
    rows that have not been smoothed yet.

    Args:
        results: Array whose first axis is smoothed; modified in place.
        win_size: Number of rows averaged per output row, must be >= 1.

    Raises:
        ValueError: If `win_size` is smaller than 1 (it would average an
            empty window and divide by zero).
    """
    if win_size < 1:
        raise ValueError(f'win_size must be at least 1, got {win_size}')

    for i in reversed(range(len(results))):
        # Clamp the window at the start of the array.
        left = max(0, i + 1 - win_size)
        results[i] = np.sum(results[left:i + 1], axis=0) / (i - left + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_topk_label(k: int, label_map: Dict, result: np.ndarray):
    """
    Format the k highest scores of `result` as text.

    Returns one "<label>: <score>" line per selected index, ordered by
    descending score; labels are looked up in `label_map` by index.
    """
    scores = np.asarray(result)
    # Sorting the negated scores ascending puts the largest scores first.
    best = np.argsort(-scores)[:k]
    return ''.join(f'{label_map[idx]}: {scores[idx]}\n' for idx in best)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Map the zero-based line number of the label file to its label text.
    with open(args.label_file, 'r') as f:
        label_map = {i: line.strip() for i, line in enumerate(f)}

    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects
    # from the file; only use tagging files from a trusted source. Consider
    # turning it off if the stored arrays are always plain numeric.
    results = np.load(args.tagging_file, allow_pickle=True)
    times, scores = results['time'], results['scores']

    if args.smooth:
        logger.info('Posterior smoothing...')
        smooth(scores, win_size=args.smooth_size)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.output_dir, exist_ok=True)
    # The output file takes the input file name up to its first dot.
    output_file = os.path.join(
        args.output_dir,
        os.path.basename(args.tagging_file).split('.')[0] + '.txt')
    with open(output_file, 'w') as f:
        for time, score in zip(times, scores):
            f.write(f'{time}\n')
            f.write(generate_topk_label(args.top_k, label_map, score) + '\n')

    logger.info(f'Saved tagging labels to {output_file}')
|
@ -0,0 +1,147 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from paddle import inference
|
||||||
|
from scipy.special import softmax
|
||||||
|
|
||||||
|
from paddleaudio.backends import load as load_audio
|
||||||
|
from paddleaudio.datasets import ESC50
|
||||||
|
from paddleaudio.features import melspectrogram
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
|
def _str2bool(value):
    """Convert a command-line flag value such as "True" or "False" to a bool."""
    # Replaces `type=eval`, which would execute arbitrary expressions that are
    # typed on the command line.
    lowered = value.strip().lower()
    if lowered in ('true', '1', 'yes'):
        return True
    if lowered in ('false', '0', 'no'):
        return False
    raise argparse.ArgumentTypeError(f'expected True or False, got {value!r}')


# Command-line options of the static-graph inference script.
parser = argparse.ArgumentParser()
# NOTE(review): the default is never used because the option is required.
parser.add_argument("--model_dir", type=str, required=True, default="./export", help="The directory to static model.")
parser.add_argument("--batch_size", type=int, default=2, help="Batch size per GPU/CPU for training.")
parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
parser.add_argument('--use_tensorrt', type=_str2bool, default=False, choices=[True, False], help='Enable to use tensorrt to speed up.')
parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16"], help='The tensorrt precision.')
parser.add_argument('--cpu_threads', type=int, default=10, help='Number of threads to predict when using cpu.')
parser.add_argument('--enable_mkldnn', type=_str2bool, default=False, choices=[True, False], help='Enable to use mkldnn to speed up when using cpu.')
parser.add_argument("--log_dir", type=str, default="./log", help="The path to save log.")
args = parser.parse_args()
|
||||||
|
# yapf: enable
|
||||||
|
|
||||||
|
|
||||||
|
def extract_features(files: list, **kwargs):
    """
    Load audio files and stack their mel-spectrogram features into one batch.

    Shorter waveforms are right-padded with zeros to the length of the longest
    one, so that all feature arrays can be stacked.

    Args:
        files: Paths of the audio files to load.
        **kwargs: Forwarded to `melspectrogram`.

    Returns:
        The transposed feature arrays of all files stacked along a new first axis.

    Raises:
        ValueError: If `files` is empty.
    """
    if not files:
        raise ValueError('At least one audio file is required.')

    waveforms = []
    srs = []
    for file in files:
        waveform, sr = load_audio(file, sr=None)
        waveforms.append(waveform)
        srs.append(sr)
    max_length = max(len(waveform) for waveform in waveforms)

    feats = []
    # Pair each waveform with its own sample rate instead of reusing the
    # rate of the file that happened to be loaded last.
    for waveform, sr in zip(waveforms, srs):
        pad_width = max_length - len(waveform)
        if pad_width > 0:
            waveform = np.pad(waveform, pad_width=(0, pad_width))
        feats.append(melspectrogram(waveform, sr, **kwargs).transpose())

    return np.stack(feats, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
class Predictor(object):
    """
    Runs an exported static model through the Paddle Inference API.
    """

    def __init__(self,
                 model_dir,
                 device="gpu",
                 batch_size=1,
                 use_tensorrt=False,
                 precision="fp32",
                 cpu_threads=10,
                 enable_mkldnn=False):
        """
        Build the inference config for `device` and create the predictor.

        `model_dir` must contain "inference.pdmodel" and "inference.pdiparams".
        """
        self.batch_size = batch_size

        model_file, params_file = (
            os.path.join(model_dir, file_name)
            for file_name in ("inference.pdmodel", "inference.pdiparams"))
        assert os.path.isfile(model_file) and os.path.isfile(
            params_file), 'Please check model and parameter files.'

        config = inference.Config(model_file, params_file)
        if device == "gpu":
            # GPU settings, optionally with TensorRT.
            # NOTE(review): the arguments are presumably the initial GPU
            # memory pool size and the device id - confirm against the
            # Paddle Inference documentation.
            config.enable_use_gpu(100, 0)
            # The lookup happens even without TensorRT, so an unknown
            # precision raises KeyError on the GPU path.
            precision_mode = {
                "fp16": inference.PrecisionType.Half,
                "fp32": inference.PrecisionType.Float32,
            }[precision]

            if use_tensorrt:
                config.enable_tensorrt_engine(
                    max_batch_size=batch_size,
                    min_subgraph_size=30,
                    precision_mode=precision_mode)
        elif device == "cpu":
            # CPU settings: optional mkldnn and the math library thread count.
            config.disable_gpu()
            if enable_mkldnn:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
            config.set_cpu_math_library_num_threads(cpu_threads)
        elif device == "xpu":
            # XPU settings.
            config.enable_xpu(100)

        config.switch_use_feed_fetch_ops(False)
        self.predictor = inference.create_predictor(config)

        input_names = self.predictor.get_input_names()
        self.input_handles = [
            self.predictor.get_input_handle(input_name)
            for input_name in input_names
        ]
        # Only the first output of the model is read back.
        first_output = self.predictor.get_output_names()[0]
        self.output_handle = self.predictor.get_output_handle(first_output)

    def predict(self, wavs):
        """
        Return the index of the most probable class for every file in `wavs`.
        """
        self.input_handles[0].copy_from_cpu(extract_features(wavs))
        self.predictor.run()

        probs = softmax(self.output_handle.copy_to_cpu(), axis=1)
        return np.argmax(probs, axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Build the predictor from the command-line options.
    predictor = Predictor(args.model_dir, args.device, args.batch_size,
                          args.use_tensorrt, args.precision, args.cpu_threads,
                          args.enable_mkldnn)

    # Demo inputs; each path is made absolute and checked before predicting.
    wavs = []
    for raw_path in ('~/audio_demo_resource/cat.wav',
                     '~/audio_demo_resource/dog.wav'):
        wav_path = os.path.abspath(os.path.expanduser(raw_path))
        assert os.path.isfile(
            wav_path), f'Please check input wave file: {wav_path}'
        wavs.append(wav_path)

    results = predictor.predict(wavs)
    for position, wav_path in enumerate(wavs):
        label = ESC50.label_list[results[position]]
        print(f'Wav: {wav_path} \t Label: {label}')
|
@ -0,0 +1,45 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from model import SoundClassifier
|
||||||
|
|
||||||
|
from paddleaudio.datasets import ESC50
|
||||||
|
from paddleaudio.models.panns import cnn14
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
|
# Command-line options of the model export script.
# NOTE: the first positional parameter of ArgumentParser is `prog`, so the
# module docstring has to be passed explicitly as `description`.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--checkpoint", type=str, required=True, help="Checkpoint of model.")
parser.add_argument("--output_dir", type=str, default='./export', help="Path to save static model and its parameters.")
args = parser.parse_args()
|
||||||
|
# yapf: enable
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Rebuild the classifier and restore the trained weights.
    backbone = cnn14(pretrained=False, extract_embedding=True)
    model = SoundClassifier(
        backbone=backbone, num_class=len(ESC50.label_list))
    state_dict = paddle.load(args.checkpoint)
    model.set_state_dict(state_dict)
    model.eval()

    # Convert to a static graph; the first two input dimensions are left
    # unspecified and the last one is fixed to 64.
    feats_spec = paddle.static.InputSpec(
        shape=[None, None, 64], dtype=paddle.float32)
    model = paddle.jit.to_static(model, input_spec=[feats_spec])

    # Save in static graph model.
    export_prefix = os.path.join(args.output_dir, "inference")
    paddle.jit.save(model, export_prefix)
|
@ -0,0 +1,36 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import paddle.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class SoundClassifier(nn.Layer):
    """
    Sound classification model: a panns pretrained backbone extracts an
    embedding from the audio features, followed by dropout and a linear
    classification layer.
    """

    def __init__(self, backbone, num_class, dropout=0.1):
        """
        Args:
            backbone: Embedding extractor; must expose `emb_size`.
            num_class: Number of output classes.
            dropout: Dropout probability applied to the embedding.
        """
        super().__init__()
        self.backbone = backbone
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(self.backbone.emb_size, num_class)

    def forward(self, x):
        """Return the class logits for a batch of features."""
        # x: (batch_size, num_frames, num_melbins) -> (batch_size, 1, num_frames, num_melbins)
        embedding = self.backbone(x.unsqueeze(1))
        return self.fc(self.dropout(embedding))
|
@ -0,0 +1,61 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from model import SoundClassifier
|
||||||
|
|
||||||
|
from paddleaudio.backends import load as load_audio
|
||||||
|
from paddleaudio.datasets import ESC50
|
||||||
|
from paddleaudio.features import melspectrogram
|
||||||
|
from paddleaudio.models.panns import cnn14
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
|
# Command-line options of the single-file prediction script.
# NOTE: the first positional parameter of ArgumentParser is `prog`, so the
# module docstring has to be passed explicitly as `description`.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to predict, defaults to gpu.")
parser.add_argument("--wav", type=str, required=True, help="Audio file to infer.")
parser.add_argument("--top_k", type=int, default=1, help="Show top k predicted results")
parser.add_argument("--checkpoint", type=str, required=True, help="Checkpoint of model.")
args = parser.parse_args()
|
||||||
|
# yapf: enable
|
||||||
|
|
||||||
|
|
||||||
|
def extract_features(file: str, **kwargs):
    """
    Load one audio file and return its transposed mel-spectrogram.

    The file is loaded with `sr=None`; extra keyword arguments are forwarded
    to `melspectrogram`.
    """
    samples, sample_rate = load_audio(file, sr=None)
    return melspectrogram(samples, sample_rate, **kwargs).transpose()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    paddle.set_device(args.device)

    # Rebuild the classifier and restore the trained weights.
    backbone = cnn14(pretrained=False, extract_embedding=True)
    model = SoundClassifier(
        backbone=backbone, num_class=len(ESC50.label_list))
    model.set_state_dict(paddle.load(args.checkpoint))
    model.eval()

    # Add a leading batch axis of size one before running the model.
    feat = paddle.to_tensor(np.expand_dims(extract_features(args.wav), 0))
    probs = F.softmax(model(feat), axis=1).numpy()

    # Indices of the single example, ordered by descending probability.
    sorted_indices = (-probs[0]).argsort()

    lines = [f'[{args.wav}]\n']
    lines.extend(f'{ESC50.label_list[idx]}: {probs[0][idx]}\n'
                 for idx in sorted_indices[:args.top_k])
    msg = ''.join(lines)
    print(msg)
|
@ -0,0 +1,149 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from model import SoundClassifier
|
||||||
|
|
||||||
|
from paddleaudio.datasets import ESC50
|
||||||
|
from paddleaudio.models.panns import cnn14
|
||||||
|
from paddleaudio.utils import logger
|
||||||
|
from paddleaudio.utils import Timer
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
|
# Command-line options of the fine-tuning script.
# NOTE: the first positional parameter of ArgumentParser is `prog`, so the
# module docstring has to be passed explicitly as `description`.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
parser.add_argument("--epochs", type=int, default=50, help="Number of epoches for fine-tuning.")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.")
parser.add_argument("--num_workers", type=int, default=0, help="Number of workers in dataloader.")
parser.add_argument("--checkpoint_dir", type=str, default='./checkpoint', help="Directory to save model checkpoints.")
parser.add_argument("--save_freq", type=int, default=10, help="Save checkpoint every n epoch.")
parser.add_argument("--log_freq", type=int, default=10, help="Log the training infomation every n steps.")
args = parser.parse_args()
|
||||||
|
# yapf: enable
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Fine-tune the cnn14-based SoundClassifier on ESC50; rank 0 logs every
    # `log_freq` steps and evaluates/saves every `save_freq` epochs.
    paddle.set_device(args.device)
    nranks = paddle.distributed.get_world_size()
    # Only initialise the parallel environment for multi-process runs.
    if nranks > 1:
        paddle.distributed.init_parallel_env()
    local_rank = paddle.distributed.get_rank()

    backbone = cnn14(pretrained=True, extract_embedding=True)
    model = SoundClassifier(backbone, num_class=len(ESC50.label_list))
    model = paddle.DataParallel(model)
    optimizer = paddle.optimizer.Adam(
        learning_rate=args.learning_rate, parameters=model.parameters())
    criterion = paddle.nn.loss.CrossEntropyLoss()

    train_ds = ESC50(mode='train', feat_type='melspectrogram')
    dev_ds = ESC50(mode='dev', feat_type='melspectrogram')

    train_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=args.batch_size, shuffle=True, drop_last=False)
    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_sampler,
        num_workers=args.num_workers,
        return_list=True,
        use_buffer_reader=True, )

    steps_per_epoch = len(train_sampler)
    timer = Timer(steps_per_epoch * args.epochs)
    timer.start()

    for epoch in range(1, args.epochs + 1):
        model.train()

        # Running statistics since the last log line.
        avg_loss = 0
        num_corrects = 0
        num_samples = 0
        for batch_idx, batch in enumerate(train_loader):
            feats, labels = batch
            logits = model(feats)

            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            optimizer.clear_grad()

            # Calculate loss
            # float() works for a 0-D as well as a one-element loss tensor,
            # whereas `loss.numpy()[0]` requires at least one dimension.
            avg_loss += float(loss)

            # Calculate metrics
            preds = paddle.argmax(logits, axis=1)
            num_corrects += (preds == labels).numpy().sum()
            num_samples += feats.shape[0]

            timer.count()

            if (batch_idx + 1) % args.log_freq == 0 and local_rank == 0:
                lr = optimizer.get_lr()
                avg_loss /= args.log_freq
                avg_acc = num_corrects / num_samples

                print_msg = 'Epoch={}/{}, Step={}/{}'.format(
                    epoch, args.epochs, batch_idx + 1, steps_per_epoch)
                print_msg += ' loss={:.4f}'.format(avg_loss)
                print_msg += ' acc={:.4f}'.format(avg_acc)
                print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(
                    lr, timer.timing, timer.eta)
                logger.train(print_msg)

                avg_loss = 0
                num_corrects = 0
                num_samples = 0

            # After the last step of every `save_freq`-th epoch, rank 0
            # evaluates on the dev set and writes a checkpoint.
            if epoch % args.save_freq == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
                dev_sampler = paddle.io.BatchSampler(
                    dev_ds,
                    batch_size=args.batch_size,
                    shuffle=False,
                    drop_last=False)
                dev_loader = paddle.io.DataLoader(
                    dev_ds,
                    batch_sampler=dev_sampler,
                    num_workers=args.num_workers,
                    return_list=True, )

                model.eval()
                num_corrects = 0
                num_samples = 0
                # Evaluation needs no gradients; dedicated names keep the
                # training loop's batch_idx/feats/labels/logits untouched.
                with logger.processing('Evaluation on validation dataset'), \
                        paddle.no_grad():
                    for dev_feats, dev_labels in dev_loader:
                        dev_logits = model(dev_feats)

                        dev_preds = paddle.argmax(dev_logits, axis=1)
                        num_corrects += (dev_preds == dev_labels).numpy().sum()
                        num_samples += dev_feats.shape[0]

                print_msg = '[Evaluation result]'
                print_msg += ' dev_acc={:.4f}'.format(num_corrects / num_samples)

                logger.eval(print_msg)

                # Save model
                save_dir = os.path.join(args.checkpoint_dir,
                                        'epoch_{}'.format(epoch))
                logger.info('Saving model checkpoint to {}'.format(save_dir))
                paddle.save(model.state_dict(),
                            os.path.join(save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(save_dir, 'model.pdopt'))
|
@ -0,0 +1,15 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .backends import *
|
||||||
|
from .features import *
|
@ -0,0 +1,14 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .audio import *
|
@ -0,0 +1,303 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import warnings
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import resampy
|
||||||
|
import soundfile as sf
|
||||||
|
from numpy import ndarray as array
|
||||||
|
from scipy.io import wavfile
|
||||||
|
|
||||||
|
from ..utils import ParameterError
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'resample',
|
||||||
|
'to_mono',
|
||||||
|
'depth_convert',
|
||||||
|
'normalize',
|
||||||
|
'save_wav',
|
||||||
|
'load',
|
||||||
|
]
|
||||||
|
# Normalization strategies named in the error message of `normalize`.
NORMALMIZE_TYPES = [
    'linear',
    'gaussian',
]

# Channel merge strategies accepted by `to_mono`.
MERGE_TYPES = [
    'ch0',
    'ch1',
    'random',
    'average',
]

# Filter names accepted by `resample`.
RESAMPLE_MODES = [
    'kaiser_best',
    'kaiser_fast',
]

# Small constant used to keep divisions and scale factors away from zero.
EPS = 1e-8
|
||||||
|
|
||||||
|
|
||||||
|
def resample(y: array, src_sr: int, target_sr: int,
             mode: str='kaiser_fast') -> array:
    """Resample an audio array from `src_sr` to `target_sr`.

    This function validates its arguments and then delegates to
    resampy.resample().

    Args:
        y: Audio samples as a numpy array.
        src_sr: Sample rate of `y`.
        target_sr: Sample rate of the returned array.
        mode: Name of the resampy filter; must be one of RESAMPLE_MODES.

    Returns:
        The array returned by resampy.resample().

    Raises:
        ParameterError: If `y` is not a numpy array or `mode` is not supported.

    Notes:
        The default mode is kaiser_fast. For better audio quality, use
        mode = 'kaiser_best'.
    """
    if not isinstance(y, np.ndarray):
        # The f-prefix matters here: without it the message would contain the
        # literal text "{type(y)}" instead of the offending type.
        raise ParameterError(
            f'Only support numpy array, but received y in {type(y)}')

    if mode not in RESAMPLE_MODES:
        raise ParameterError(f'resample mode must in {RESAMPLE_MODES}')

    if mode == 'kaiser_best':
        warnings.warn(
            f'Using resampy in kaiser_best to {src_sr}=>{target_sr}. '
            'This function is pretty slow, we recommend the mode '
            'kaiser_fast in large scale audio training')

    return resampy.resample(y, src_sr, target_sr, filter=mode)
|
||||||
|
|
||||||
|
|
||||||
|
def to_mono(y: array, merge_type: str='average') -> array:
    """Convert a stereo audio array to mono.

    Args:
        y: Audio array. A 1-D array is returned unchanged; a 2-D array is
            indexed as y[0] / y[1] for the two channels.
        merge_type: One of MERGE_TYPES. 'ch0' / 'ch1' select a channel,
            'random' picks one of the two channels at random and 'average'
            returns the mean of both channels.

    Returns:
        A 1-D array holding the mono signal, with the same dtype as `y`.

    Raises:
        ParameterError: If `merge_type` is unknown, `y` has more than two
            dimensions, or `y.dtype` cannot be averaged.
    """
    if merge_type not in MERGE_TYPES:
        raise ParameterError(
            f'Unsupported merge type {merge_type}, available types are {MERGE_TYPES}'
        )
    if y.ndim > 2:
        raise ParameterError(
            f'Unsupported audio array, y.ndim > 2, the shape is {y.shape}')
    if y.ndim == 1:  # nothing to merge
        return y

    if merge_type == 'ch0':
        return y[0]
    if merge_type == 'ch1':
        return y[1]
    if merge_type == 'random':
        return y[np.random.randint(0, 2)]

    # merge_type == 'average': how the mean is taken depends on the dtype.
    if y.dtype in ('float32', 'float64'):
        return (y[0] + y[1]) * 0.5

    if y.dtype in ('int16', 'int8'):
        # Sum in a wider integer type so the addition cannot overflow. The
        # floored mean of two in-range values is always in range, so the
        # cast back to the original dtype is lossless.
        wide = y.astype('int32')
        return ((wide[0] + wide[1]) // 2).astype(y.dtype)

    raise ParameterError(f'Unsupported dtype: {y.dtype}')
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_cast(y: array, dtype: Union[type, str]) -> array:
|
||||||
|
""" data type casting in a safe way, i.e., prevent overflow or underflow
|
||||||
|
|
||||||
|
This function is used internally.
|
||||||
|
"""
|
||||||
|
return np.clip(y, np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def depth_convert(y: array, dtype: Union[type, str],
                  dithering: bool=True) -> array:
    """Convert an audio array to a target dtype safely.

    Supported source and target dtypes are int16, int8, float32 and float64.
    The conversion rescales the samples between the integer full-scale range
    and the float representation, clipping where the result would otherwise
    overflow the target dtype.

    Args:
        y: Audio array to convert.
        dtype: Target dtype, given as a dtype name or a numpy type.
        dithering: Currently unused by the implementation.

    Returns:
        `y` itself when it already has the target dtype, otherwise a new
        array of the target dtype.

    Raises:
        ParameterError: If the source or target dtype is not supported.
    """
    SUPPORT_DTYPE = ['int16', 'int8', 'float32', 'float64']
    FLOAT_DTYPE = ('float32', 'float64')

    source = y.dtype.name
    if source not in SUPPORT_DTYPE:
        raise ParameterError(
            'Unsupported audio dtype, '
            f'y.dtype is {y.dtype}, supported dtypes are {SUPPORT_DTYPE}')

    # Normalize the target to a dtype name so that both 'int16' and np.int16
    # are accepted. None is rejected explicitly because np.dtype(None) would
    # silently resolve to float64.
    try:
        target = None if dtype is None else np.dtype(dtype).name
    except TypeError:
        target = None
    if target not in SUPPORT_DTYPE:
        raise ParameterError(
            'Unsupported audio dtype, '
            f'target dtype is {dtype}, supported dtypes are {SUPPORT_DTYPE}')

    if target == source:
        return y

    if source in FLOAT_DTYPE and target in FLOAT_DTYPE:
        if target == 'float32':
            # Narrowing: keep values inside the finite float32 range.
            limits = np.finfo('float32')
            y = np.clip(y, limits.min, limits.max)
        return y.astype(target)

    if source in FLOAT_DTYPE:
        # float -> int: scale by the integer maximum, then clip so values
        # outside [-1, 1] saturate instead of wrapping around.
        info = np.iinfo(target)
        return np.clip(y * info.max, info.min, info.max).astype(target)

    if target in FLOAT_DTYPE:
        # int -> float: divide by the integer maximum of the source dtype.
        return y.astype(target) / np.iinfo(source).max

    # int -> int: rescale between the two full-scale ranges. Multiplying
    # before dividing keeps the full-scale value exact (127 <-> 32767), and
    # clipping prevents the most negative int8 value from overflowing int16.
    info = np.iinfo(target)
    scaled = y.astype('float64') * info.max / np.iinfo(source).max
    return np.clip(scaled, info.min, info.max).astype(target)
|
||||||
|
|
||||||
|
|
||||||
|
def sound_file_load(file: str,
                    offset: Optional[float]=None,
                    dtype: str='int16',
                    duration: Optional[int]=None) -> Tuple[array, int]:
    """Read an audio file through the soundfile (libsndfile) backend.

    `offset` and `duration` are multiplied by the file's sample rate to get
    the start frame and the number of frames to read. Without a duration the
    file is read to the end.

    Returns:
        A tuple of the transposed sample array and the file's sample rate.

    Reference:
        http://www.mega-nerd.com/libsndfile/#Features
    """
    with sf.SoundFile(file) as handle:
        native_sr = handle.samplerate

        # A falsy offset (None or 0) means "start at the beginning".
        if offset:
            handle.seek(int(offset * native_sr))

        # frames=-1 asks soundfile to read everything that is left.
        n_frames = -1 if duration is None else int(duration * native_sr)
        samples = handle.read(
            frames=n_frames, dtype=dtype, always_2d=False).T

    return samples, native_sr
|
||||||
|
|
||||||
|
|
||||||
|
def audio_file_load():
    """Placeholder for a loader based on the audiofile library.

    Not implemented yet; calling it always raises NotImplementedError.

    Reference:
        https://audiofile.68k.org/
    """
    raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def sox_file_load():
    """Placeholder for a loader based on the sox library.

    Not implemented yet; calling it always raises NotImplementedError.

    Reference:
        http://sox.sourceforge.net/
    """
    raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(y: array, norm_type: str='linear',
              mul_factor: float=1.0) -> array:
    """Normalize an input audio array, then apply an additional multiplier.

    Args:
        y: Audio array to normalize.
        norm_type: 'linear' divides by the peak absolute value (plus EPS);
            'gaussian' subtracts the mean and divides by the standard
            deviation (floored at EPS).
        mul_factor: Multiplier applied to the normalized signal.

    Returns:
        The normalized array.

    Raises:
        NotImplementedError: If `norm_type` is not a known normalization type.
    """
    if norm_type == 'linear':
        # Take the peak from the extrema converted to Python floats rather
        # than from np.abs(y): for signed integer input, np.abs wraps on the
        # most negative value (e.g. -32768 in int16) and would underestimate
        # the peak.
        amax = max(abs(float(np.max(y))), abs(float(np.min(y))))
        factor = 1.0 / (amax + EPS)
        y = y * factor * mul_factor
    elif norm_type == 'gaussian':
        amean = np.mean(y)
        # Floor the deviation so a constant signal does not divide by zero.
        astd = max(np.std(y), EPS)
        y = mul_factor * (y - amean) / astd
    else:
        raise NotImplementedError(f'norm_type should be in {NORMALMIZE_TYPES}')

    return y
|
||||||
|
|
||||||
|
|
||||||
|
def save_wav(y: array, sr: int, file: str) -> None:
    """Write an audio array to a .wav file on disk.

    The file is written with scipy.io.wavfile. Arrays whose dtype is neither
    int16 nor int8 are converted to int16 first (with a warning); int16 and
    int8 arrays are written as they are.

    Args:
        y: Audio array to save.
        sr: Sample rate, must be positive.
        file: Destination path, must end with '.wav'.

    Raises:
        ParameterError: If `file` does not end with '.wav' or `sr` is not
            larger than 0.

    Notes:
        It only support raw wav format.
    """
    if not file.endswith('.wav'):
        raise ParameterError(
            f'only .wav file supported, but dst file name is: {file}')

    if sr <= 0:
        raise ParameterError(
            f'Sample rate should be larger than 0, recieved sr = {sr}')

    y_out = y
    if y.dtype not in ['int16', 'int8']:
        warnings.warn(
            f'input data type is {y.dtype}, will convert data to int16 format before saving'
        )
        y_out = depth_convert(y, 'int16')

    wavfile.write(file, sr, y_out)
|
||||||
|
|
||||||
|
|
||||||
|
def load(
        file: str,
        sr: Optional[int]=None,
        mono: bool=True,
        merge_type: str='average',  # ch0,ch1,random,average
        normal: bool=True,
        norm_type: str='linear',
        norm_mul_factor: float=1.0,
        offset: float=0.0,
        duration: Optional[int]=None,
        dtype: str='float32',
        resample_mode: str='kaiser_fast') -> Tuple[array, int]:
    """Load an audio file from disk and post-process the waveform.

    The file is read with sound_file_load() and then, in this order,
    optionally merged to mono, resampled, normalized, and finally converted
    to `dtype` with depth_convert().

    Parameters:
        file: Path of the audio file.
        sr: Target sample rate. When given and different from the file's
            own rate, the waveform is resampled to it.
        mono: Whether to merge the channels with to_mono().
        merge_type: Merge strategy passed to to_mono().
        normal: Whether to normalize with `norm_type` and `norm_mul_factor`.
        norm_type: Normalization type passed to normalize().
        norm_mul_factor: Multiplier passed to normalize().
        offset: Start position, passed to sound_file_load().
        duration: Length to read, passed to sound_file_load().
        dtype: dtype used for reading and for the returned waveform.
        resample_mode: Filter name passed to resample().

    Returns:
        A tuple of the waveform and its sample rate.

    Raises:
        ParameterError: If the file yields no samples.

    Notes:
        When `normal` is False and `dtype` is 'int8' or 'int16', the waveform
        is still linearly normalized before the depth conversion.
    """
    waveform, native_sr = sound_file_load(
        file, offset=offset, dtype=dtype, duration=duration)

    has_samples = (waveform.ndim == 1 and len(waveform) > 0) or \
        (waveform.ndim == 2 and len(waveform[0]) > 0)
    if not has_samples:
        raise ParameterError(f'audio file {file} looks empty')

    if mono:
        waveform = to_mono(waveform, merge_type)

    out_sr = native_sr
    if sr is not None and sr != native_sr:
        waveform = resample(waveform, native_sr, sr, mode=resample_mode)
        out_sr = sr

    if normal:
        waveform = normalize(waveform, norm_type, norm_mul_factor)
    elif dtype in ['int8', 'int16']:
        # still need to do normalization, before depth conversion
        waveform = normalize(waveform, 'linear', 1.0)

    return depth_convert(waveform, dtype), out_sr
|
@ -0,0 +1,34 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .aishell import AISHELL1
|
||||||
|
from .dcase import UrbanAcousticScenes
|
||||||
|
from .dcase import UrbanAudioVisualScenes
|
||||||
|
from .esc50 import ESC50
|
||||||
|
from .gtzan import GTZAN
|
||||||
|
from .librispeech import LIBRISPEECH
|
||||||
|
from .ravdess import RAVDESS
|
||||||
|
from .tess import TESS
|
||||||
|
from .urban_sound import UrbanSound8K
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'AISHELL1',
|
||||||
|
'LIBRISPEECH',
|
||||||
|
'ESC50',
|
||||||
|
'UrbanSound8K',
|
||||||
|
'GTZAN',
|
||||||
|
'UrbanAcousticScenes',
|
||||||
|
'UrbanAudioVisualScenes',
|
||||||
|
'RAVDESS',
|
||||||
|
'TESS',
|
||||||
|
]
|
@ -0,0 +1,154 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import codecs
|
||||||
|
import collections
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from paddle.io import Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ..backends import load as load_audio
|
||||||
|
from ..utils.download import decompress
|
||||||
|
from ..utils.download import download_and_decompress
|
||||||
|
from ..utils.env import DATA_HOME
|
||||||
|
from ..utils.log import logger
|
||||||
|
from .dataset import feat_funcs
|
||||||
|
|
||||||
|
__all__ = ['AISHELL1']
|
||||||
|
|
||||||
|
|
||||||
|
class AISHELL1(Dataset):
    """
    This Open Source Mandarin Speech Corpus, AISHELL-ASR0009-OS1, is 178 hours long.
    It is a part of AISHELL-ASR0009, of which utterance contains 11 domains, including
    smart home, autonomous driving, and industrial production. The whole recording was
    put in quiet indoor environment, using 3 different devices at the same time: high
    fidelity microphone (44.1kHz, 16-bit); Android-system mobile phone (16kHz, 16-bit),
    iOS-system mobile phone (16kHz, 16-bit). Audios in high fidelity were re-sampled
    to 16kHz to build AISHELL-ASR0009-OS1. 400 speakers from different accent areas
    in China were invited to participate in the recording. The manual transcription
    accuracy rate is above 95%, through professional speech annotation and strict
    quality inspection. The corpus is divided into training, development and testing
    sets.

    Reference:
        AISHELL-1: An Open-Source Mandarin Speech Corpus and A Speech Recognition Baseline
        https://arxiv.org/abs/1709.05522
    """

    # Archives handed to download_and_decompress when the data is missing.
    archieves = [
        {
            'url': 'http://www.openslr.org/resources/33/data_aishell.tgz',
            'md5': '2f494334227864a8a8fec932999db9d8',
        },
    ]
    # Paths below are relative to DATA_HOME.
    text_meta = os.path.join('data_aishell', 'transcript',
                             'aishell_transcript_v0.8.txt')
    utt_info = collections.namedtuple('META_INFO',
                                      ('file_path', 'utt_id', 'text'))
    audio_path = os.path.join('data_aishell', 'wav')
    manifest_path = os.path.join('data_aishell', 'manifest')
    # Valid subset names; shadowed on instances by the selected subset.
    subset = ['train', 'dev', 'test']

    def __init__(self, subset: str='train', feat_type: str='raw', **kwargs):
        """
        Args:
            subset: Which split to load; one of 'train', 'dev' or 'test'.
            feat_type: Key into `feat_funcs` selecting the feature extractor.
            **kwargs: Extra keyword arguments forwarded to the extractor.
        """
        assert subset in self.subset, 'Dataset subset must be one in {}, but got {}'.format(
            self.subset, subset)
        self.subset = subset
        self.feat_type = feat_type
        self.feat_config = kwargs
        self._data = self._get_data()
        super(AISHELL1, self).__init__()

    def _get_text_info(self) -> Dict[str, str]:
        """Read the transcript file into a mapping of utt_id -> text.

        Whitespace inside each transcript is removed.
        """
        ret = {}
        # The transcript holds Chinese text, so decode it explicitly as UTF-8
        # instead of relying on the platform's default encoding.
        with open(
                os.path.join(DATA_HOME, self.text_meta), 'r',
                encoding='utf-8') as rf:
            # NOTE(review): the first line is skipped as if it were a header;
            # confirm the transcript file really starts with one.
            for line in rf.readlines()[1:]:
                if not line.strip():
                    # A blank line cannot be split into (utt_id, text).
                    continue
                utt_id, text = map(str.strip, line.split(' ', 1))
                ret[utt_id] = ''.join(text.split())
        return ret

    def _get_data(self):
        """Download the corpus if needed and list the labelled utterances.

        Returns:
            A list of `utt_info` records for the selected subset.
        """
        if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)) or \
                not os.path.isfile(os.path.join(DATA_HOME, self.text_meta)):
            download_and_decompress(self.archieves, DATA_HOME)
            # Extract *wav from *.tar.gz.
            for root, _, files in os.walk(
                    os.path.join(DATA_HOME, self.audio_path)):
                for file in files:
                    if file.endswith('.tar.gz'):
                        decompress(os.path.join(root, file))
                        os.remove(os.path.join(root, file))

        text_info = self._get_text_info()

        data = []
        for root, _, files in os.walk(
                os.path.join(DATA_HOME, self.audio_path, self.subset)):
            for file in files:
                if not file.endswith('.wav'):
                    continue
                utt_id = os.path.splitext(file)[0]
                if utt_id not in text_info:  # Some utterances have no label
                    continue
                data.append(
                    self.utt_info(
                        os.path.join(root, file), utt_id, text_info[utt_id]))

        return data

    def _convert_to_record(self, idx: int):
        """Build the record dict for sample `idx`.

        The record holds every field of the `utt_info` tuple plus 'feat'
        (the extracted feature, or the waveform when the extractor is None)
        and 'duration' (number of samples divided by the sample rate).
        """
        sample = self._data[idx]
        record = dict(zip(type(sample)._fields, sample))

        waveform, sr = load_audio(sample.file_path)
        feat_func = feat_funcs[self.feat_type]
        feat = feat_func(
            waveform, sample_rate=sr,
            **self.feat_config) if feat_func else waveform
        record.update({'feat': feat, 'duration': len(waveform) / sr})
        return record

    def create_manifest(self, prefix='manifest'):
        """Write one JSON line per utterance to `<prefix>.<subset>`.

        The file is created under DATA_HOME/manifest_path.
        """
        manifest_dir = os.path.join(DATA_HOME, self.manifest_path)
        # exist_ok avoids the check-then-create race of a separate isdir test.
        os.makedirs(manifest_dir, exist_ok=True)

        manifest_file = os.path.join(manifest_dir, f'{prefix}.{self.subset}')
        with codecs.open(manifest_file, 'w', 'utf-8') as f:
            for idx in tqdm(range(len(self))):
                record = self._convert_to_record(idx)
                record_line = json.dumps(
                    {
                        'utt': record['utt_id'],
                        'feat': record['file_path'],
                        'feat_shape': (record['duration'], ),
                        'text': record['text']
                    },
                    ensure_ascii=False)
                f.write(record_line + '\n')
        logger.info(f'Manifest file {manifest_file} created.')

    def __getitem__(self, idx):
        """Return the record of sample `idx` as a tuple of its values."""
        record = self._convert_to_record(idx)
        return tuple(record.values())

    def __len__(self):
        """Return the number of labelled utterances in the subset."""
        return len(self._data)
|
@ -0,0 +1,82 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from ..backends import load as load_audio
|
||||||
|
from ..features import melspectrogram
|
||||||
|
from ..features import mfcc
|
||||||
|
|
||||||
|
# Maps a `feat_type` name to its feature extractor. A value of None means no
# extractor is applied and the raw waveform is used as the feature.
feat_funcs = dict(
    raw=None,
    melspectrogram=melspectrogram,
    mfcc=mfcc, )
|
||||||
|
|
||||||
|
|
||||||
|
class AudioClassificationDataset(paddle.io.Dataset):
    """
    Base class of audio classification dataset.
    """

    def __init__(self,
                 files: List[str],
                 labels: List[int],
                 feat_type: str='raw',
                 **kwargs):
        """
        Args:
            files (:obj:`List[str]`): A list of absolute path of audio files.
            labels (:obj:`List[int]`): Labels of audio files.
            feat_type (:obj:`str`, `optional`, defaults to `raw`):
                It identifies the feature type that user wants to extract of an audio file.
            **kwargs: Keyword arguments forwarded to the feature extractor.
        """
        super(AudioClassificationDataset, self).__init__()

        if feat_type not in feat_funcs:
            supported = list(feat_funcs.keys())
            raise RuntimeError(
                f"Unknown feat_type: {feat_type}, it must be one in {supported}"
            )

        self.files = files
        self.labels = labels

        self.feat_type = feat_type
        # Keyword arguments customizing the feature configuration.
        self.feat_config = kwargs

    def _get_data(self, input_file: str):
        """Hook for subclasses; the base class does not implement it."""
        raise NotImplementedError

    def _convert_to_record(self, idx):
        """Load sample `idx` and return a dict with 'feat' and 'label'."""
        path, label = self.files[idx], self.labels[idx]

        waveform, sample_rate = load_audio(path)
        extractor = feat_funcs[self.feat_type]

        # A falsy extractor (the 'raw' entry) keeps the waveform untouched.
        if extractor:
            feat = extractor(waveform, sample_rate, **self.feat_config)
        else:
            feat = waveform

        return {'feat': feat, 'label': label}

    def __getitem__(self, idx):
        """Return (transposed feature array, int64 label array) for `idx`."""
        record = self._convert_to_record(idx)
        feat = np.array(record['feat']).transpose()
        label = np.array(record['label'], dtype=np.int64)
        return feat, label

    def __len__(self):
        """Return the number of audio files."""
        return len(self.files)
|
@ -0,0 +1,298 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import collections
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from ..utils.download import download_and_decompress
|
||||||
|
from ..utils.env import DATA_HOME
|
||||||
|
from .dataset import AudioClassificationDataset
|
||||||
|
|
||||||
|
__all__ = ['UrbanAcousticScenes', 'UrbanAudioVisualScenes']
|
||||||
|
|
||||||
|
|
||||||
|
class UrbanAcousticScenes(AudioClassificationDataset):
    """
    TAU Urban Acoustic Scenes 2020 Mobile Development dataset contains recordings from
    12 European cities in 10 different acoustic scenes using 4 different devices.
    Additionally, synthetic data for 11 mobile devices was created based on the original
    recordings. Of the 12 cities, two are present only in the evaluation set.

    Reference:
        A multi-device dataset for urban acoustic scene classification
        https://arxiv.org/abs/1807.09840
    """

    source_url = 'https://zenodo.org/record/3819968/files/'
    base_name = 'TAU-urban-acoustic-scenes-2020-mobile-development'
    # One metadata archive followed by the 16 audio archives.
    archieves = [
        {'url': source_url + base_name + '.meta.zip',
         'md5': '6eae9db553ce48e4ea246e34e50a3cf5'},
        {'url': source_url + base_name + '.audio.1.zip',
         'md5': 'b1e85b8a908d3d6a6ab73268f385d5c8'},
        {'url': source_url + base_name + '.audio.2.zip',
         'md5': '4310a13cc2943d6ce3f70eba7ba4c784'},
        {'url': source_url + base_name + '.audio.3.zip',
         'md5': 'ed38956c4246abb56190c1e9b602b7b8'},
        {'url': source_url + base_name + '.audio.4.zip',
         'md5': '97ab8560056b6816808dedc044dcc023'},
        {'url': source_url + base_name + '.audio.5.zip',
         'md5': 'b50f5e0bfed33cd8e52cb3e7f815c6cb'},
        {'url': source_url + base_name + '.audio.6.zip',
         'md5': 'fbf856a3a86fff7520549c899dc94372'},
        {'url': source_url + base_name + '.audio.7.zip',
         'md5': '0dbffe7b6e45564da649378723284062'},
        {'url': source_url + base_name + '.audio.8.zip',
         'md5': 'bb6f77832bf0bd9f786f965beb251b2e'},
        {'url': source_url + base_name + '.audio.9.zip',
         'md5': 'a65596a5372eab10c78e08a0de797c9e'},
        {'url': source_url + base_name + '.audio.10.zip',
         'md5': '2ad595819ffa1d56d2de4c7ed43205a6'},
        {'url': source_url + base_name + '.audio.11.zip',
         'md5': '0ad29f7040a4e6a22cfd639b3a6738e5'},
        {'url': source_url + base_name + '.audio.12.zip',
         'md5': 'e5f4400c6b9697295fab4cf507155a2f'},
        {'url': source_url + base_name + '.audio.13.zip',
         'md5': '8855ab9f9896422746ab4c5d89d8da2f'},
        {'url': source_url + base_name + '.audio.14.zip',
         'md5': '092ad744452cd3e7de78f988a3d13020'},
        {'url': source_url + base_name + '.audio.15.zip',
         'md5': '4b5eb85f6592aebf846088d9df76b420'},
        {'url': source_url + base_name + '.audio.16.zip',
         'md5': '2e0a89723e58a3836be019e6996ae460'},
    ]
    # A scene label's position in this list is its integer class id.
    label_list = [
        'airport', 'shopping_mall', 'metro_station', 'street_pedestrian',
        'public_square', 'street_traffic', 'tram', 'bus', 'metro', 'park'
    ]

    # Paths below are relative to DATA_HOME.
    meta = os.path.join(base_name, 'meta.csv')
    meta_info = collections.namedtuple('META_INFO', (
        'filename', 'scene_label', 'identifier', 'source_label'))
    subset_meta = {
        'train': os.path.join(base_name, 'evaluation_setup', 'fold1_train.csv'),
        'dev':
        os.path.join(base_name, 'evaluation_setup', 'fold1_evaluate.csv'),
        'test': os.path.join(base_name, 'evaluation_setup', 'fold1_test.csv'),
    }
    subset_meta_info = collections.namedtuple('SUBSET_META_INFO',
                                              ('filename', 'scene_label'))
    audio_path = os.path.join(base_name, 'audio')

    def __init__(self, mode: str='train', feat_type: str='raw', **kwargs):
        """
        Args:
            mode (:obj:`str`, `optional`, defaults to `train`):
                It identifies the dataset subset; must be a key of `subset_meta`.
            feat_type (:obj:`str`, `optional`, defaults to `raw`):
                It identifies the feature type that user wants to extract of an audio file.
        """
        files, labels = self._get_data(mode)
        super(UrbanAcousticScenes, self).__init__(
            files=files, labels=labels, feat_type=feat_type, **kwargs)

    def _get_meta_info(self, subset: str=None,
                       skip_header: bool=True) -> List[collections.namedtuple]:
        """Parse a tab-separated metadata file into a list of named tuples.

        With `subset` None the global meta file is read into `meta_info`
        records; otherwise the subset's file is read into `subset_meta_info`
        records. The first line is dropped when `skip_header` is true.
        """
        if subset is None:
            meta_file, record_cls = self.meta, self.meta_info
        else:
            assert subset in self.subset_meta, f'Subset must be one in {list(self.subset_meta.keys())}, but got {subset}.'
            meta_file = self.subset_meta[subset]
            record_cls = self.subset_meta_info

        with open(os.path.join(DATA_HOME, meta_file), 'r') as rf:
            lines = rf.readlines()
            if skip_header:
                lines = lines[1:]
            return [record_cls(*line.strip().split('\t')) for line in lines]

    def _get_data(self, mode: str) -> Tuple[List[str], List[int]]:
        """Download the dataset if needed and list the files of `mode`.

        Returns:
            A tuple of audio file paths and their integer labels.
        """
        audio_dir = os.path.join(DATA_HOME, self.audio_path)
        meta_path = os.path.join(DATA_HOME, self.meta)
        if not (os.path.isdir(audio_dir) and os.path.isfile(meta_path)):
            download_and_decompress(self.archieves, DATA_HOME)

        files, labels = [], []
        for sample in self._get_meta_info(subset=mode, skip_header=True):
            filename, label = sample[:2]
            # Only the base name is kept; audio files are looked up directly
            # under `audio_path`.
            filename = os.path.basename(filename)
            target = self.label_list.index(label)

            files.append(os.path.join(DATA_HOME, self.audio_path, filename))
            labels.append(int(target))

        return files, labels
|
||||||
|
|
||||||
|
|
||||||
|
class UrbanAudioVisualScenes(AudioClassificationDataset):
    """
    TAU Urban Audio Visual Scenes 2021 Development dataset.

    It contains synchronized audio and video recordings from 12 European
    cities in 10 different scenes, cut into 10-second segments. The total
    amount of audio in the development set is 34 hours.

    Reference:
        A Curated Dataset of Urban Scenes for Audio-Visual Scene Analysis
        https://arxiv.org/abs/2011.00030
    """

    source_url = 'https://zenodo.org/record/4477542/files/'
    base_name = 'TAU-urban-audio-visual-scenes-2021-development'

    # Archives to fetch: the meta package followed by the eight audio parts.
    archieves = [
        {
            'url': source_url + base_name + '.meta.zip',
            'md5': '76e3d7ed5291b118372e06379cb2b490',
        },
        {
            'url': source_url + base_name + '.audio.1.zip',
            'md5': '186f6273f8f69ed9dbdc18ad65ac234f',
        },
        {
            'url': source_url + base_name + '.audio.2.zip',
            'md5': '7fd6bb63127f5785874a55aba4e77aa5',
        },
        {
            'url': source_url + base_name + '.audio.3.zip',
            'md5': '61396bede29d7c8c89729a01a6f6b2e2',
        },
        {
            'url': source_url + base_name + '.audio.4.zip',
            'md5': '6ddac89717fcf9c92c451868eed77fe1',
        },
        {
            'url': source_url + base_name + '.audio.5.zip',
            'md5': 'af4820756cdf1a7d4bd6037dc034d384',
        },
        {
            'url': source_url + base_name + '.audio.6.zip',
            'md5': 'ebd11ec24411f2a17a64723bd4aa7fff',
        },
        {
            'url': source_url + base_name + '.audio.7.zip',
            'md5': '2be39a76aeed704d5929d020a2909efd',
        },
        {
            'url': source_url + base_name + '.audio.8.zip',
            'md5': '972d8afe0874720fc2f28086e7cb22a9',
        },
    ]

    # Scene names; a sample's integer label is its position in this list.
    label_list = [
        'airport', 'shopping_mall', 'metro_station', 'street_pedestrian',
        'public_square', 'street_traffic', 'tram', 'bus', 'metro', 'park'
    ]

    # All paths below are relative to DATA_HOME.
    meta_base_path = os.path.join(base_name, base_name + '.meta')
    meta = os.path.join(meta_base_path, 'meta.csv')
    # Record layout used for rows of the full meta file.
    meta_info = collections.namedtuple('META_INFO', (
        'filename_audio', 'filename_video', 'scene_label', 'identifier'))
    # Split name -> meta file of that split.
    subset_meta = {
        'train':
        os.path.join(meta_base_path, 'evaluation_setup', 'fold1_train.csv'),
        'dev':
        os.path.join(meta_base_path, 'evaluation_setup', 'fold1_evaluate.csv'),
        'test':
        os.path.join(meta_base_path, 'evaluation_setup', 'fold1_test.csv'),
    }
    # Record layout used for rows of the split meta files.
    subset_meta_info = collections.namedtuple('SUBSET_META_INFO', (
        'filename_audio', 'filename_video', 'scene_label'))
    audio_path = os.path.join(base_name, 'audio')

    def __init__(self, mode: str='train', feat_type: str='raw', **kwargs):
        """
        Args:
            mode (:obj:`str`, `optional`, defaults to `train`):
                It identifies the dataset mode; it must be a key of
                `subset_meta` (train, dev or test).
            feat_type (:obj:`str`, `optional`, defaults to `raw`):
                It identifies the feature type that user wants to extract
                from an audio file.
            **kwargs:
                Extra keyword arguments forwarded to the base class.
        """
        audio_files, targets = self._get_data(mode)
        super().__init__(
            files=audio_files, labels=targets, feat_type=feat_type, **kwargs)

    def _get_meta_info(self, subset: str=None,
                       skip_header: bool=True) -> List[collections.namedtuple]:
        """
        Parse a tab-separated meta file into a list of records.

        Args:
            subset (:obj:`str`, `optional`, defaults to `None`):
                Key of `subset_meta` selecting the split file to read.
                When it is `None`, the file at `meta` is read instead.
            skip_header (:obj:`bool`, `optional`, defaults to `True`):
                Whether to drop the first line of the file before parsing.

        Returns:
            One record per remaining line: `meta_info` instances when
            `subset` is `None`, `subset_meta_info` instances otherwise.
        """
        if subset is not None:
            assert subset in self.subset_meta, f'Subset must be one in {list(self.subset_meta.keys())}, but got {subset}.'
            relative_path, record_type = self.subset_meta[subset], self.subset_meta_info
        else:
            relative_path, record_type = self.meta, self.meta_info

        with open(os.path.join(DATA_HOME, relative_path), 'r') as rf:
            rows = rf.readlines()
            if skip_header:
                rows = rows[1:]
            # Every tab-separated field becomes one positional field.
            return [record_type(*row.strip().split('\t')) for row in rows]

    def _get_data(self, mode: str) -> Tuple[List[str], List[int]]:
        """
        Collect audio file paths and integer labels for one subset.

        The archives in `archieves` are downloaded and decompressed into
        `DATA_HOME/base_name` when either the audio directory or the meta
        file is missing.

        Args:
            mode (:obj:`str`):
                Subset name, passed to `_get_meta_info` as `subset`.

        Returns:
            A tuple `(files, labels)`: `files` holds paths under
            `DATA_HOME/audio_path`, and `labels` holds the index of each
            sample's scene label in `label_list`.
        """
        # Same check as "audio dir missing or meta file missing"; the meta
        # file is only probed when the audio directory exists.
        already_extracted = os.path.isdir(
            os.path.join(DATA_HOME, self.audio_path)) and os.path.isfile(
                os.path.join(DATA_HOME, self.meta))
        if not already_extracted:
            download_and_decompress(self.archieves,
                                    os.path.join(DATA_HOME, self.base_name))

        files, labels = [], []
        for record in self._get_meta_info(subset=mode, skip_header=True):
            # Columns: audio file name, video file name (unused), scene label.
            audio_name, _, scene = record[:3]
            audio_name = os.path.basename(audio_name)
            class_index = self.label_list.index(scene)

            files.append(os.path.join(DATA_HOME, self.audio_path, audio_name))
            labels.append(int(class_index))

        return files, labels
|
@ -0,0 +1,152 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import collections
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from ..utils.download import download_and_decompress
|
||||||
|
from ..utils.env import DATA_HOME
|
||||||
|
from .dataset import AudioClassificationDataset
|
||||||
|
|
||||||
|
# Names exported by a star-import of this module.
__all__ = ['ESC50']
|
||||||
|
|
||||||
|
|
||||||
|
class ESC50(AudioClassificationDataset):
    """
    The ESC-50 dataset is a labeled collection of 2000 environmental audio
    recordings suitable for benchmarking methods of environmental sound
    classification. It consists of 5-second-long recordings organized into
    50 semantic classes (with 40 examples per class).

    Reference:
        ESC: Dataset for Environmental Sound Classification
        http://dx.doi.org/10.1145/2733373.2806390
    """

    # Archive to download, with the md5 listed for it.
    archieves = [
        {
            'url':
            'https://paddleaudio.bj.bcebos.com/datasets/ESC-50-master.zip',
            'md5': '7771e4b9d86d0945acce719c7a59305a',
        },
    ]

    # Human-readable class names, grouped by category.
    label_list = [
        # Animals
        'Dog',
        'Rooster',
        'Pig',
        'Cow',
        'Frog',
        'Cat',
        'Hen',
        'Insects (flying)',
        'Sheep',
        'Crow',
        # Natural soundscapes & water sounds
        'Rain',
        'Sea waves',
        'Crackling fire',
        'Crickets',
        'Chirping birds',
        'Water drops',
        'Wind',
        'Pouring water',
        'Toilet flush',
        'Thunderstorm',
        # Human, non-speech sounds
        'Crying baby',
        'Sneezing',
        'Clapping',
        'Breathing',
        'Coughing',
        'Footsteps',
        'Laughing',
        'Brushing teeth',
        'Snoring',
        'Drinking, sipping',
        # Interior/domestic sounds
        'Door knock',
        'Mouse click',
        'Keyboard typing',
        'Door, wood creaks',
        'Can opening',
        'Washing machine',
        'Vacuum cleaner',
        'Clock alarm',
        'Clock tick',
        'Glass breaking',
        # Exterior/urban noises
        'Helicopter',
        'Chainsaw',
        'Siren',
        'Car horn',
        'Engine',
        'Train',
        'Church bells',
        'Airplane',
        'Fireworks',
        'Hand saw',
    ]

    # Paths below are relative to DATA_HOME.
    meta = os.path.join('ESC-50-master', 'meta', 'esc50.csv')
    # Record layout used for rows of the meta file.
    meta_info = collections.namedtuple(
        'META_INFO',
        ('filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take'))
    audio_path = os.path.join('ESC-50-master', 'audio')

    def __init__(self,
                 mode: str='train',
                 split: int=1,
                 feat_type: str='raw',
                 **kwargs):
        """
        Args:
            mode (:obj:`str`, `optional`, defaults to `train`):
                It identifies the dataset mode. `train` selects every fold
                except `split`; any other value selects fold `split` only.
            split (:obj:`int`, `optional`, defaults to 1):
                It specifies the fold held out from training.
            feat_type (:obj:`str`, `optional`, defaults to `raw`):
                It identifies the feature type that user wants to extract
                from an audio file.
            **kwargs:
                Extra keyword arguments forwarded to the base class.
        """
        audio_files, targets = self._get_data(mode, split)
        super().__init__(
            files=audio_files, labels=targets, feat_type=feat_type, **kwargs)

    def _get_meta_info(self) -> List[collections.namedtuple]:
        """
        Parse the comma-separated meta file, skipping its first line.

        Returns:
            One `meta_info` record per remaining line of the file.
        """
        with open(os.path.join(DATA_HOME, self.meta), 'r') as rf:
            rows = rf.readlines()[1:]
            return [self.meta_info(*row.strip().split(',')) for row in rows]

    def _get_data(self, mode: str, split: int) -> Tuple[List[str], List[int]]:
        """
        Collect audio file paths and integer targets for one mode.

        The archive in `archieves` is downloaded and decompressed into
        DATA_HOME when either the audio directory or the meta file is
        missing.

        Args:
            mode (:obj:`str`):
                `train` keeps samples whose fold differs from `split`; any
                other value keeps samples whose fold equals `split`.
            split (:obj:`int`):
                Fold number compared against each sample's `fold` column.

        Returns:
            A tuple `(files, labels)`: `files` holds paths under
            `DATA_HOME/audio_path`, and `labels` holds the `target` column
            converted to `int`.
        """
        # Same check as "audio dir missing or meta file missing"; the meta
        # file is only probed when the audio directory exists.
        already_extracted = os.path.isdir(
            os.path.join(DATA_HOME, self.audio_path)) and os.path.isfile(
                os.path.join(DATA_HOME, self.meta))
        if not already_extracted:
            download_and_decompress(self.archieves, DATA_HOME)

        training = mode == 'train'
        files, labels = [], []
        for sample in self._get_meta_info():
            wav_name, fold, target, _category, _esc10, _src_file, _take = sample
            in_held_out_fold = int(fold) == split
            # Keep the sample when exactly one of the two flags is set:
            # training data outside the fold, or non-training data inside it.
            if training != in_held_out_fold:
                files.append(os.path.join(DATA_HOME, self.audio_path, wav_name))
                labels.append(int(target))

        return files, labels
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue