Merge pull request #929 from PaddlePaddle/join_ctc
[lm] transformer lm & kaldi data processpull/945/head
commit
e8bc9a2a08
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,263 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
|
||||
from deepspeech.models.lm_interface import LMInterface
|
||||
from deepspeech.modules.encoder import TransformerEncoder
|
||||
from deepspeech.modules.mask import subsequent_mask
|
||||
from deepspeech.utils.log import Log
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
|
||||
def __init__(
|
||||
self,
|
||||
n_vocab: int,
|
||||
pos_enc: str=None,
|
||||
embed_unit: int=128,
|
||||
att_unit: int=256,
|
||||
head: int=2,
|
||||
unit: int=1024,
|
||||
layer: int=4,
|
||||
dropout_rate: float=0.5,
|
||||
emb_dropout_rate: float=0.0,
|
||||
att_dropout_rate: float=0.0,
|
||||
tie_weights: bool=False,
|
||||
**kwargs):
|
||||
nn.Layer.__init__(self)
|
||||
|
||||
if pos_enc == "sinusoidal":
|
||||
pos_enc_layer_type = "abs_pos"
|
||||
elif pos_enc is None:
|
||||
pos_enc_layer_type = "no_pos"
|
||||
else:
|
||||
raise ValueError(f"unknown pos-enc option: {pos_enc}")
|
||||
|
||||
self.embed = nn.Embedding(n_vocab, embed_unit)
|
||||
|
||||
if emb_dropout_rate == 0.0:
|
||||
self.embed_drop = None
|
||||
else:
|
||||
self.embed_drop = nn.Dropout(emb_dropout_rate)
|
||||
|
||||
self.encoder = TransformerEncoder(
|
||||
input_size=embed_unit,
|
||||
output_size=att_unit,
|
||||
attention_heads=head,
|
||||
linear_units=unit,
|
||||
num_blocks=layer,
|
||||
dropout_rate=dropout_rate,
|
||||
attention_dropout_rate=att_dropout_rate,
|
||||
input_layer="linear",
|
||||
pos_enc_layer_type=pos_enc_layer_type,
|
||||
concat_after=False,
|
||||
static_chunk_size=1,
|
||||
use_dynamic_chunk=False,
|
||||
use_dynamic_left_chunk=False)
|
||||
|
||||
self.decoder = nn.Linear(att_unit, n_vocab)
|
||||
|
||||
logger.info("Tie weights set to {}".format(tie_weights))
|
||||
logger.info("Dropout set to {}".format(dropout_rate))
|
||||
logger.info("Emb Dropout set to {}".format(emb_dropout_rate))
|
||||
logger.info("Att Dropout set to {}".format(att_dropout_rate))
|
||||
|
||||
if tie_weights:
|
||||
assert (
|
||||
att_unit == embed_unit
|
||||
), "Tie Weights: True need embedding and final dimensions to match"
|
||||
self.decoder.weight = self.embed.weight
|
||||
|
||||
def _target_mask(self, ys_in_pad):
|
||||
ys_mask = ys_in_pad != 0
|
||||
m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0)
|
||||
return ys_mask.unsqueeze(-2) & m
|
||||
|
||||
def forward(self, x: paddle.Tensor, t: paddle.Tensor
|
||||
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
||||
"""Compute LM loss value from buffer sequences.
|
||||
|
||||
Args:
|
||||
x (paddle.Tensor): Input ids. (batch, len)
|
||||
t (paddle.Tensor): Target ids. (batch, len)
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: Tuple of
|
||||
loss to backward (scalar),
|
||||
negative log-likelihood of t: -log p(t) (scalar) and
|
||||
the number of elements in x (scalar)
|
||||
|
||||
Notes:
|
||||
The last two return values are used
|
||||
in perplexity: p(t)^{-n} = exp(-log p(t) / n)
|
||||
|
||||
"""
|
||||
xm = x != 0
|
||||
xlen = xm.sum(axis=1)
|
||||
if self.embed_drop is not None:
|
||||
emb = self.embed_drop(self.embed(x))
|
||||
else:
|
||||
emb = self.embed(x)
|
||||
h, _ = self.encoder(emb, xlen)
|
||||
y = self.decoder(h)
|
||||
loss = F.cross_entropy(
|
||||
y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
|
||||
mask = xm.to(dtype=loss.dtype)
|
||||
logp = loss * mask.view(-1)
|
||||
logp = logp.sum()
|
||||
count = mask.sum()
|
||||
return logp / count, logp, count
|
||||
|
||||
# beam search API (see ScorerInterface)
|
||||
def score(self, y: paddle.Tensor, state: Any,
|
||||
x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]:
|
||||
"""Score new token.
|
||||
|
||||
Args:
|
||||
y (paddle.Tensor): 1D paddle.int64 prefix tokens.
|
||||
state: Scorer state for prefix tokens
|
||||
x (paddle.Tensor): encoder feature that generates ys.
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, Any]: Tuple of
|
||||
paddle.float32 scores for next token (n_vocab)
|
||||
and next state for ys
|
||||
|
||||
"""
|
||||
y = y.unsqueeze(0)
|
||||
|
||||
if self.embed_drop is not None:
|
||||
emb = self.embed_drop(self.embed(y))
|
||||
else:
|
||||
emb = self.embed(y)
|
||||
|
||||
h, _, cache = self.encoder.forward_one_step(
|
||||
emb, self._target_mask(y), cache=state)
|
||||
h = self.decoder(h[:, -1])
|
||||
logp = F.log_softmax(h).squeeze(0)
|
||||
return logp, cache
|
||||
|
||||
# batch beam search API (see BatchScorerInterface)
|
||||
def batch_score(self,
|
||||
ys: paddle.Tensor,
|
||||
states: List[Any],
|
||||
xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
|
||||
"""Score new token batch (required).
|
||||
|
||||
Args:
|
||||
ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
|
||||
states (List[Any]): Scorer states for prefix tokens.
|
||||
xs (paddle.Tensor):
|
||||
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, List[Any]]: Tuple of
|
||||
batchfied scores for next token with shape of `(n_batch, n_vocab)`
|
||||
and next state list for ys.
|
||||
|
||||
"""
|
||||
# merge states
|
||||
n_batch = len(ys)
|
||||
n_layers = len(self.encoder.encoders)
|
||||
if states[0] is None:
|
||||
batch_state = None
|
||||
else:
|
||||
# transpose state of [batch, layer] into [layer, batch]
|
||||
batch_state = [
|
||||
paddle.stack([states[b][i] for b in range(n_batch)])
|
||||
for i in range(n_layers)
|
||||
]
|
||||
|
||||
if self.embed_drop is not None:
|
||||
emb = self.embed_drop(self.embed(ys))
|
||||
else:
|
||||
emb = self.embed(ys)
|
||||
|
||||
# batch decoding
|
||||
h, _, states = self.encoder.forward_one_step(
|
||||
emb, self._target_mask(ys), cache=batch_state)
|
||||
h = self.decoder(h[:, -1])
|
||||
logp = F.log_softmax(h)
|
||||
|
||||
# transpose state of [layer, batch] into [batch, layer]
|
||||
state_list = [[states[i][b] for i in range(n_layers)]
|
||||
for b in range(n_batch)]
|
||||
return logp, state_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tlm = TransformerLM(
|
||||
n_vocab=5002,
|
||||
pos_enc=None,
|
||||
embed_unit=128,
|
||||
att_unit=512,
|
||||
head=8,
|
||||
unit=2048,
|
||||
layer=16,
|
||||
dropout_rate=0.5, )
|
||||
|
||||
# n_vocab: int,
|
||||
# pos_enc: str=None,
|
||||
# embed_unit: int=128,
|
||||
# att_unit: int=256,
|
||||
# head: int=2,
|
||||
# unit: int=1024,
|
||||
# layer: int=4,
|
||||
# dropout_rate: float=0.5,
|
||||
# emb_dropout_rate: float = 0.0,
|
||||
# att_dropout_rate: float = 0.0,
|
||||
# tie_weights: bool = False,):
|
||||
paddle.set_device("cpu")
|
||||
model_dict = paddle.load("transformerLM.pdparams")
|
||||
tlm.set_state_dict(model_dict)
|
||||
|
||||
tlm.eval()
|
||||
#Test the score
|
||||
input2 = np.array([5])
|
||||
input2 = paddle.to_tensor(input2)
|
||||
state = None
|
||||
output, state = tlm.score(input2, state, None)
|
||||
|
||||
input3 = np.array([5, 10])
|
||||
input3 = paddle.to_tensor(input3)
|
||||
output, state = tlm.score(input3, state, None)
|
||||
|
||||
input4 = np.array([5, 10, 0])
|
||||
input4 = paddle.to_tensor(input4)
|
||||
output, state = tlm.score(input4, state, None)
|
||||
print("output", output)
|
||||
"""
|
||||
#Test the batch score
|
||||
batch_size = 2
|
||||
inp2 = np.array([[5], [10]])
|
||||
inp2 = paddle.to_tensor(inp2)
|
||||
output, states = tlm.batch_score(
|
||||
inp2, [(None,None,0)] * batch_size)
|
||||
inp3 = np.array([[100], [30]])
|
||||
inp3 = paddle.to_tensor(inp3)
|
||||
output, states = tlm.batch_score(
|
||||
inp3, states)
|
||||
print("output", output)
|
||||
#print("cache", cache)
|
||||
#np.save("output_pd.npy", output)
|
||||
"""
|
@ -0,0 +1,82 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Language model interface."""
|
||||
import argparse
|
||||
|
||||
from deepspeech.decoders.scorers.scorer_interface import ScorerInterface
|
||||
from deepspeech.utils.dynamic_import import dynamic_import
|
||||
|
||||
|
||||
class LMInterface(ScorerInterface):
|
||||
"""LM Interface model implementation."""
|
||||
|
||||
@staticmethod
|
||||
def add_arguments(parser):
|
||||
"""Add arguments to command line argument parser."""
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def build(cls, n_vocab: int, **kwargs):
|
||||
"""Initialize this class with python-level args.
|
||||
|
||||
Args:
|
||||
idim (int): The number of vocabulary.
|
||||
|
||||
Returns:
|
||||
LMinterface: A new instance of LMInterface.
|
||||
|
||||
"""
|
||||
args = argparse.Namespace(**kwargs)
|
||||
return cls(n_vocab, args)
|
||||
|
||||
def forward(self, x, t):
|
||||
"""Compute LM loss value from buffer sequences.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input ids. (batch, len)
|
||||
t (torch.Tensor): Target ids. (batch, len)
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
|
||||
loss to backward (scalar),
|
||||
negative log-likelihood of t: -log p(t) (scalar) and
|
||||
the number of elements in x (scalar)
|
||||
|
||||
Notes:
|
||||
The last two return values are used
|
||||
in perplexity: p(t)^{-n} = exp(-log p(t) / n)
|
||||
|
||||
"""
|
||||
raise NotImplementedError("forward method is not implemented")
|
||||
|
||||
|
||||
predefined_lms = {
|
||||
"transformer": "deepspeech.models.lm.transformer:TransformerLM",
|
||||
}
|
||||
|
||||
|
||||
def dynamic_import_lm(module):
|
||||
"""Import LM class dynamically.
|
||||
|
||||
Args:
|
||||
module (str): module_name:class_name or alias in `predefined_lms`
|
||||
|
||||
Returns:
|
||||
type: LM class
|
||||
|
||||
"""
|
||||
model_class = dynamic_import(module, predefined_lms)
|
||||
assert issubclass(model_class,
|
||||
LMInterface), f"{module} does not implement LMInterface"
|
||||
return model_class
|
@ -0,0 +1,75 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""ST Interface module."""
|
||||
from .asr_interface import ASRInterface
|
||||
from deepspeech.utils.dynamic_import import dynamic_import
|
||||
|
||||
|
||||
class STInterface(ASRInterface):
|
||||
"""ST Interface model implementation.
|
||||
|
||||
NOTE: This class is inherited from ASRInterface to enable joint translation
|
||||
and recognition when performing multi-task learning with the ASR task.
|
||||
|
||||
"""
|
||||
|
||||
def translate(self,
|
||||
x,
|
||||
trans_args,
|
||||
char_list=None,
|
||||
rnnlm=None,
|
||||
ensemble_models=[]):
|
||||
"""Recognize x for evaluation.
|
||||
|
||||
:param ndarray x: input acouctic feature (B, T, D) or (T, D)
|
||||
:param namespace trans_args: argment namespace contraining options
|
||||
:param list char_list: list of characters
|
||||
:param paddle.nn.Layer rnnlm: language model module
|
||||
:return: N-best decoding results
|
||||
:rtype: list
|
||||
"""
|
||||
raise NotImplementedError("translate method is not implemented")
|
||||
|
||||
def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
|
||||
"""Beam search implementation for batch.
|
||||
|
||||
:param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
|
||||
:param namespace trans_args: argument namespace containing options
|
||||
:param list char_list: list of characters
|
||||
:param paddle.nn.Layer rnnlm: language model module
|
||||
:return: N-best decoding results
|
||||
:rtype: list
|
||||
"""
|
||||
raise NotImplementedError("Batch decoding is not supported yet.")
|
||||
|
||||
|
||||
predefined_st = {
|
||||
"transformer": "deepspeech.models.u2_st:U2STModel",
|
||||
}
|
||||
|
||||
|
||||
def dynamic_import_st(module):
|
||||
"""Import ST models dynamically.
|
||||
|
||||
Args:
|
||||
module (str): module_name:class_name or alias in `predefined_st`
|
||||
|
||||
Returns:
|
||||
type: ST class
|
||||
|
||||
"""
|
||||
model_class = dynamic_import(module, predefined_st)
|
||||
assert issubclass(model_class,
|
||||
STInterface), f"{module} does not implement STInterface"
|
||||
return model_class
|
@ -0,0 +1,15 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .u2_st import U2STInferModel
|
||||
from .u2_st import U2STModel
|
@ -0,0 +1,418 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import extension
|
||||
|
||||
|
||||
class PlotAttentionReport(extension.Extension):
|
||||
"""Plot attention reporter.
|
||||
|
||||
Args:
|
||||
att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
|
||||
Function of attention visualization.
|
||||
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
|
||||
outdir (str): Directory to save figures.
|
||||
converter (espnet.asr.*_backend.asr.CustomConverter):
|
||||
Function to convert data.
|
||||
device (int | torch.device): Device.
|
||||
reverse (bool): If True, input and output length are reversed.
|
||||
ikey (str): Key to access input
|
||||
(for ASR/ST ikey="input", for MT ikey="output".)
|
||||
iaxis (int): Dimension to access input
|
||||
(for ASR/ST iaxis=0, for MT iaxis=1.)
|
||||
okey (str): Key to access output
|
||||
(for ASR/ST okey="input", MT okay="output".)
|
||||
oaxis (int): Dimension to access output
|
||||
(for ASR/ST oaxis=0, for MT oaxis=0.)
|
||||
subsampling_factor (int): subsampling factor in encoder
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
att_vis_fn,
|
||||
data,
|
||||
outdir,
|
||||
converter,
|
||||
transform,
|
||||
device,
|
||||
reverse=False,
|
||||
ikey="input",
|
||||
iaxis=0,
|
||||
okey="output",
|
||||
oaxis=0,
|
||||
subsampling_factor=1, ):
|
||||
self.att_vis_fn = att_vis_fn
|
||||
self.data = copy.deepcopy(data)
|
||||
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
|
||||
# key is utterance ID
|
||||
self.outdir = outdir
|
||||
self.converter = converter
|
||||
self.transform = transform
|
||||
self.device = device
|
||||
self.reverse = reverse
|
||||
self.ikey = ikey
|
||||
self.iaxis = iaxis
|
||||
self.okey = okey
|
||||
self.oaxis = oaxis
|
||||
self.factor = subsampling_factor
|
||||
if not os.path.exists(self.outdir):
|
||||
os.makedirs(self.outdir)
|
||||
|
||||
def __call__(self, trainer):
|
||||
"""Plot and save image file of att_ws matrix."""
|
||||
att_ws, uttid_list = self.get_attention_weights()
|
||||
if isinstance(att_ws, list): # multi-encoder case
|
||||
num_encs = len(att_ws) - 1
|
||||
# atts
|
||||
for i in range(num_encs):
|
||||
for idx, att_w in enumerate(att_ws[i]):
|
||||
filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
|
||||
self.outdir, uttid_list[idx], i + 1, )
|
||||
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
||||
np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
|
||||
self.outdir, uttid_list[idx], i + 1, )
|
||||
np.save(np_filename.format(trainer), att_w)
|
||||
self._plot_and_save_attention(att_w,
|
||||
filename.format(trainer))
|
||||
# han
|
||||
for idx, att_w in enumerate(att_ws[num_encs]):
|
||||
filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
|
||||
self.outdir, uttid_list[idx], )
|
||||
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
||||
np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
|
||||
self.outdir, uttid_list[idx], )
|
||||
np.save(np_filename.format(trainer), att_w)
|
||||
self._plot_and_save_attention(
|
||||
att_w, filename.format(trainer), han_mode=True)
|
||||
else:
|
||||
for idx, att_w in enumerate(att_ws):
|
||||
filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
|
||||
uttid_list[idx], )
|
||||
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
||||
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
|
||||
self.outdir, uttid_list[idx], )
|
||||
np.save(np_filename.format(trainer), att_w)
|
||||
self._plot_and_save_attention(att_w, filename.format(trainer))
|
||||
|
||||
def log_attentions(self, logger, step):
|
||||
"""Add image files of att_ws matrix to the tensorboard."""
|
||||
att_ws, uttid_list = self.get_attention_weights()
|
||||
if isinstance(att_ws, list): # multi-encoder case
|
||||
num_encs = len(att_ws) - 1
|
||||
# atts
|
||||
for i in range(num_encs):
|
||||
for idx, att_w in enumerate(att_ws[i]):
|
||||
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
||||
plot = self.draw_attention_plot(att_w)
|
||||
logger.add_figure(
|
||||
"%s_att%d" % (uttid_list[idx], i + 1),
|
||||
plot.gcf(),
|
||||
step, )
|
||||
# han
|
||||
for idx, att_w in enumerate(att_ws[num_encs]):
|
||||
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
||||
plot = self.draw_han_plot(att_w)
|
||||
logger.add_figure(
|
||||
"%s_han" % (uttid_list[idx]),
|
||||
plot.gcf(),
|
||||
step, )
|
||||
else:
|
||||
for idx, att_w in enumerate(att_ws):
|
||||
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
||||
plot = self.draw_attention_plot(att_w)
|
||||
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
|
||||
|
||||
def get_attention_weights(self):
|
||||
"""Return attention weights.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: attention weights. float. Its shape would be
|
||||
differ from backend.
|
||||
* pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2)
|
||||
other case => (B, Lmax, Tmax).
|
||||
* chainer-> (B, Lmax, Tmax)
|
||||
|
||||
"""
|
||||
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
|
||||
batch = self.converter([return_batch], self.device)
|
||||
if isinstance(batch, tuple):
|
||||
att_ws = self.att_vis_fn(*batch)
|
||||
else:
|
||||
att_ws = self.att_vis_fn(**batch)
|
||||
return att_ws, uttid_list
|
||||
|
||||
def trim_attention_weight(self, uttid, att_w):
|
||||
"""Transform attention matrix with regard to self.reverse."""
|
||||
if self.reverse:
|
||||
enc_key, enc_axis = self.okey, self.oaxis
|
||||
dec_key, dec_axis = self.ikey, self.iaxis
|
||||
else:
|
||||
enc_key, enc_axis = self.ikey, self.iaxis
|
||||
dec_key, dec_axis = self.okey, self.oaxis
|
||||
dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
|
||||
enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
|
||||
if self.factor > 1:
|
||||
enc_len //= self.factor
|
||||
if len(att_w.shape) == 3:
|
||||
att_w = att_w[:, :dec_len, :enc_len]
|
||||
else:
|
||||
att_w = att_w[:dec_len, :enc_len]
|
||||
return att_w
|
||||
|
||||
def draw_attention_plot(self, att_w):
|
||||
"""Plot the att_w matrix.
|
||||
|
||||
Returns:
|
||||
matplotlib.pyplot: pyplot object with attention matrix image.
|
||||
|
||||
"""
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.clf()
|
||||
att_w = att_w.astype(np.float32)
|
||||
if len(att_w.shape) == 3:
|
||||
for h, aw in enumerate(att_w, 1):
|
||||
plt.subplot(1, len(att_w), h)
|
||||
plt.imshow(aw, aspect="auto")
|
||||
plt.xlabel("Encoder Index")
|
||||
plt.ylabel("Decoder Index")
|
||||
else:
|
||||
plt.imshow(att_w, aspect="auto")
|
||||
plt.xlabel("Encoder Index")
|
||||
plt.ylabel("Decoder Index")
|
||||
plt.tight_layout()
|
||||
return plt
|
||||
|
||||
def draw_han_plot(self, att_w):
|
||||
"""Plot the att_w matrix for hierarchical attention.
|
||||
|
||||
Returns:
|
||||
matplotlib.pyplot: pyplot object with attention matrix image.
|
||||
|
||||
"""
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.clf()
|
||||
if len(att_w.shape) == 3:
|
||||
for h, aw in enumerate(att_w, 1):
|
||||
legends = []
|
||||
plt.subplot(1, len(att_w), h)
|
||||
for i in range(aw.shape[1]):
|
||||
plt.plot(aw[:, i])
|
||||
legends.append("Att{}".format(i))
|
||||
plt.ylim([0, 1.0])
|
||||
plt.xlim([0, aw.shape[0]])
|
||||
plt.grid(True)
|
||||
plt.ylabel("Attention Weight")
|
||||
plt.xlabel("Decoder Index")
|
||||
plt.legend(legends)
|
||||
else:
|
||||
legends = []
|
||||
for i in range(att_w.shape[1]):
|
||||
plt.plot(att_w[:, i])
|
||||
legends.append("Att{}".format(i))
|
||||
plt.ylim([0, 1.0])
|
||||
plt.xlim([0, att_w.shape[0]])
|
||||
plt.grid(True)
|
||||
plt.ylabel("Attention Weight")
|
||||
plt.xlabel("Decoder Index")
|
||||
plt.legend(legends)
|
||||
plt.tight_layout()
|
||||
return plt
|
||||
|
||||
def _plot_and_save_attention(self, att_w, filename, han_mode=False):
|
||||
if han_mode:
|
||||
plt = self.draw_han_plot(att_w)
|
||||
else:
|
||||
plt = self.draw_attention_plot(att_w)
|
||||
plt.savefig(filename)
|
||||
plt.close()
|
||||
|
||||
|
||||
class PlotCTCReport(extension.Extension):
|
||||
"""Plot CTC reporter.
|
||||
|
||||
Args:
|
||||
ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
|
||||
Function of CTC visualization.
|
||||
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
|
||||
outdir (str): Directory to save figures.
|
||||
converter (espnet.asr.*_backend.asr.CustomConverter):
|
||||
Function to convert data.
|
||||
device (int | torch.device): Device.
|
||||
reverse (bool): If True, input and output length are reversed.
|
||||
ikey (str): Key to access input
|
||||
(for ASR/ST ikey="input", for MT ikey="output".)
|
||||
iaxis (int): Dimension to access input
|
||||
(for ASR/ST iaxis=0, for MT iaxis=1.)
|
||||
okey (str): Key to access output
|
||||
(for ASR/ST okey="input", MT okay="output".)
|
||||
oaxis (int): Dimension to access output
|
||||
(for ASR/ST oaxis=0, for MT oaxis=0.)
|
||||
subsampling_factor (int): subsampling factor in encoder
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ctc_vis_fn,
|
||||
data,
|
||||
outdir,
|
||||
converter,
|
||||
transform,
|
||||
device,
|
||||
reverse=False,
|
||||
ikey="input",
|
||||
iaxis=0,
|
||||
okey="output",
|
||||
oaxis=0,
|
||||
subsampling_factor=1, ):
|
||||
self.ctc_vis_fn = ctc_vis_fn
|
||||
self.data = copy.deepcopy(data)
|
||||
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
|
||||
# key is utterance ID
|
||||
self.outdir = outdir
|
||||
self.converter = converter
|
||||
self.transform = transform
|
||||
self.device = device
|
||||
self.reverse = reverse
|
||||
self.ikey = ikey
|
||||
self.iaxis = iaxis
|
||||
self.okey = okey
|
||||
self.oaxis = oaxis
|
||||
self.factor = subsampling_factor
|
||||
if not os.path.exists(self.outdir):
|
||||
os.makedirs(self.outdir)
|
||||
|
||||
def __call__(self, trainer):
|
||||
"""Plot and save image file of ctc prob."""
|
||||
ctc_probs, uttid_list = self.get_ctc_probs()
|
||||
if isinstance(ctc_probs, list): # multi-encoder case
|
||||
num_encs = len(ctc_probs) - 1
|
||||
for i in range(num_encs):
|
||||
for idx, ctc_prob in enumerate(ctc_probs[i]):
|
||||
filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
|
||||
self.outdir, uttid_list[idx], i + 1, )
|
||||
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
||||
np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
|
||||
self.outdir, uttid_list[idx], i + 1, )
|
||||
np.save(np_filename.format(trainer), ctc_prob)
|
||||
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
|
||||
else:
|
||||
for idx, ctc_prob in enumerate(ctc_probs):
|
||||
filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
|
||||
uttid_list[idx], )
|
||||
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
||||
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
|
||||
self.outdir, uttid_list[idx], )
|
||||
np.save(np_filename.format(trainer), ctc_prob)
|
||||
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
|
||||
|
||||
def log_ctc_probs(self, logger, step):
|
||||
"""Add image files of ctc probs to the tensorboard."""
|
||||
ctc_probs, uttid_list = self.get_ctc_probs()
|
||||
if isinstance(ctc_probs, list): # multi-encoder case
|
||||
num_encs = len(ctc_probs) - 1
|
||||
for i in range(num_encs):
|
||||
for idx, ctc_prob in enumerate(ctc_probs[i]):
|
||||
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
||||
plot = self.draw_ctc_plot(ctc_prob)
|
||||
logger.add_figure(
|
||||
"%s_ctc%d" % (uttid_list[idx], i + 1),
|
||||
plot.gcf(),
|
||||
step, )
|
||||
else:
|
||||
for idx, ctc_prob in enumerate(ctc_probs):
|
||||
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
||||
plot = self.draw_ctc_plot(ctc_prob)
|
||||
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
|
||||
|
||||
def get_ctc_probs(self):
|
||||
"""Return CTC probs.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: CTC probs. float. Its shape would be
|
||||
differ from backend. (B, Tmax, vocab).
|
||||
|
||||
"""
|
||||
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
|
||||
batch = self.converter([return_batch], self.device)
|
||||
if isinstance(batch, tuple):
|
||||
probs = self.ctc_vis_fn(*batch)
|
||||
else:
|
||||
probs = self.ctc_vis_fn(**batch)
|
||||
return probs, uttid_list
|
||||
|
||||
def trim_ctc_prob(self, uttid, prob):
|
||||
"""Trim CTC posteriors accoding to input lengths."""
|
||||
enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
|
||||
if self.factor > 1:
|
||||
enc_len //= self.factor
|
||||
prob = prob[:enc_len]
|
||||
return prob
|
||||
|
||||
def draw_ctc_plot(self, ctc_prob):
|
||||
"""Plot the ctc_prob matrix.
|
||||
|
||||
Returns:
|
||||
matplotlib.pyplot: pyplot object with CTC prob matrix image.
|
||||
|
||||
"""
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
ctc_prob = ctc_prob.astype(np.float32)
|
||||
|
||||
plt.clf()
|
||||
topk_ids = np.argsort(ctc_prob, axis=1)
|
||||
n_frames, vocab = ctc_prob.shape
|
||||
times_probs = np.arange(n_frames)
|
||||
|
||||
plt.figure(figsize=(20, 8))
|
||||
|
||||
# NOTE: index 0 is reserved for blank
|
||||
for idx in set(topk_ids.reshape(-1).tolist()):
|
||||
if idx == 0:
|
||||
plt.plot(
|
||||
times_probs,
|
||||
ctc_prob[:, 0],
|
||||
":",
|
||||
label="<blank>",
|
||||
color="grey")
|
||||
else:
|
||||
plt.plot(times_probs, ctc_prob[:, idx])
|
||||
plt.xlabel(u"Input [frame]", fontsize=12)
|
||||
plt.ylabel("Posteriors", fontsize=12)
|
||||
plt.xticks(list(range(0, int(n_frames) + 1, 10)))
|
||||
plt.yticks(list(range(0, 2, 1)))
|
||||
plt.tight_layout()
|
||||
return plt
|
||||
|
||||
def _plot_and_save_ctc(self, ctc_prob, filename):
|
||||
plt = self.draw_ctc_plot(ctc_prob)
|
||||
plt.savefig(filename)
|
||||
plt.close()
|
@ -0,0 +1,61 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..reporter import DictSummary
|
||||
from .utils import get_trigger
|
||||
|
||||
|
||||
class CompareValueTrigger():
|
||||
"""Trigger invoked when key value getting bigger or lower than before.
|
||||
|
||||
Args:
|
||||
key (str) : Key of value.
|
||||
compare_fn ((float, float) -> bool) : Function to compare the values.
|
||||
trigger (tuple(int, str)) : Trigger that decide the comparison interval.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, key, compare_fn, trigger=(1, "epoch")):
|
||||
self._key = key
|
||||
self._best_value = None
|
||||
self._interval_trigger = get_trigger(trigger)
|
||||
self._init_summary()
|
||||
self._compare_fn = compare_fn
|
||||
|
||||
def __call__(self, trainer):
|
||||
"""Get value related to the key and compare with current value."""
|
||||
observation = trainer.observation
|
||||
summary = self._summary
|
||||
key = self._key
|
||||
if key in observation:
|
||||
summary.add({key: observation[key]})
|
||||
|
||||
if not self._interval_trigger(trainer):
|
||||
return False
|
||||
|
||||
stats = summary.compute_mean()
|
||||
value = float(stats[key]) # copy to CPU
|
||||
self._init_summary()
|
||||
|
||||
if self._best_value is None:
|
||||
# initialize best value
|
||||
self._best_value = value
|
||||
return False
|
||||
elif self._compare_fn(self._best_value, value):
|
||||
return True
|
||||
else:
|
||||
self._best_value = value
|
||||
return False
|
||||
|
||||
def _init_summary(self):
|
||||
self._summary = DictSummary()
|
@ -0,0 +1,28 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .interval_trigger import IntervalTrigger
|
||||
|
||||
|
||||
def never_fail_trigger(trainer):
|
||||
return False
|
||||
|
||||
|
||||
def get_trigger(trigger):
|
||||
if trigger is None:
|
||||
return never_fail_trigger
|
||||
if callable(trigger):
|
||||
return trigger
|
||||
else:
|
||||
trigger = IntervalTrigger(*trigger)
|
||||
return trigger
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,41 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def delta(feat, window):
|
||||
assert window > 0
|
||||
delta_feat = np.zeros_like(feat)
|
||||
for i in range(1, window + 1):
|
||||
delta_feat[:-i] += i * feat[i:]
|
||||
delta_feat[i:] += -i * feat[:-i]
|
||||
delta_feat[-i:] += i * feat[-1]
|
||||
delta_feat[:i] += -i * feat[0]
|
||||
delta_feat /= 2 * sum(i ** 2 for i in range(1, window + 1))
|
||||
return delta_feat
|
||||
|
||||
|
||||
def add_deltas(x, window=2, order=2):
|
||||
"""
|
||||
Args:
|
||||
x (np.ndarray): speech feat, (T, D).
|
||||
|
||||
Return:
|
||||
np.ndarray: (T, (1+order)*D)
|
||||
"""
|
||||
feats = [x]
|
||||
for _ in range(order):
|
||||
feats.append(delta(feats[-1], window))
|
||||
return np.concatenate(feats, axis=1)
|
||||
|
||||
|
||||
class AddDeltas():
|
||||
def __init__(self, window=2, order=2):
|
||||
self.window = window
|
||||
self.order = order
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(window={window}, order={order}".format(
|
||||
name=self.__class__.__name__, window=self.window, order=self.order
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return add_deltas(x, window=self.window, order=self.order)
|
@ -0,0 +1,45 @@
|
||||
import numpy
|
||||
|
||||
|
||||
class ChannelSelector():
|
||||
"""Select 1ch from multi-channel signal"""
|
||||
|
||||
def __init__(self, train_channel="random", eval_channel=0, axis=1):
|
||||
self.train_channel = train_channel
|
||||
self.eval_channel = eval_channel
|
||||
self.axis = axis
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(train_channel={train_channel}, "
|
||||
"eval_channel={eval_channel}, axis={axis})".format(
|
||||
name=self.__class__.__name__,
|
||||
train_channel=self.train_channel,
|
||||
eval_channel=self.eval_channel,
|
||||
axis=self.axis,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x, train=True):
|
||||
# Assuming x: [Time, Channel] by default
|
||||
|
||||
if x.ndim <= self.axis:
|
||||
# If the dimension is insufficient, then unsqueeze
|
||||
# (e.g [Time] -> [Time, 1])
|
||||
ind = tuple(
|
||||
slice(None) if i < x.ndim else None for i in range(self.axis + 1)
|
||||
)
|
||||
x = x[ind]
|
||||
|
||||
if train:
|
||||
channel = self.train_channel
|
||||
else:
|
||||
channel = self.eval_channel
|
||||
|
||||
if channel == "random":
|
||||
ch = numpy.random.randint(0, x.shape[self.axis])
|
||||
else:
|
||||
ch = channel
|
||||
|
||||
ind = tuple(slice(None) if i != self.axis else ch for i in range(x.ndim))
|
||||
return x[ind]
|
@ -0,0 +1,158 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
|
||||
import h5py
|
||||
import kaldiio
|
||||
import numpy as np
|
||||
|
||||
|
||||
class CMVN():
|
||||
"Apply Global/Spk CMVN/iverserCMVN."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stats,
|
||||
norm_means=True,
|
||||
norm_vars=False,
|
||||
filetype="mat",
|
||||
utt2spk=None,
|
||||
spk2utt=None,
|
||||
reverse=False,
|
||||
std_floor=1.0e-20, ):
|
||||
self.stats_file = stats
|
||||
self.norm_means = norm_means
|
||||
self.norm_vars = norm_vars
|
||||
self.reverse = reverse
|
||||
|
||||
if isinstance(stats, dict):
|
||||
stats_dict = dict(stats)
|
||||
else:
|
||||
# Use for global CMVN
|
||||
if filetype == "mat":
|
||||
stats_dict = {None: kaldiio.load_mat(stats)}
|
||||
# Use for global CMVN
|
||||
elif filetype == "npy":
|
||||
stats_dict = {None: np.load(stats)}
|
||||
# Use for speaker CMVN
|
||||
elif filetype == "ark":
|
||||
self.accept_uttid = True
|
||||
stats_dict = dict(kaldiio.load_ark(stats))
|
||||
# Use for speaker CMVN
|
||||
elif filetype == "hdf5":
|
||||
self.accept_uttid = True
|
||||
stats_dict = h5py.File(stats)
|
||||
else:
|
||||
raise ValueError("Not supporting filetype={}".format(filetype))
|
||||
|
||||
if utt2spk is not None:
|
||||
self.utt2spk = {}
|
||||
with io.open(utt2spk, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
utt, spk = line.rstrip().split(None, 1)
|
||||
self.utt2spk[utt] = spk
|
||||
elif spk2utt is not None:
|
||||
self.utt2spk = {}
|
||||
with io.open(spk2utt, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
spk, utts = line.rstrip().split(None, 1)
|
||||
for utt in utts.split():
|
||||
self.utt2spk[utt] = spk
|
||||
else:
|
||||
self.utt2spk = None
|
||||
|
||||
# Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1),
|
||||
# and the first vector contains the sum of feats and the second is
|
||||
# the sum of squares. The last value of the first, i.e. stats[0,-1],
|
||||
# is the number of samples for this statistics.
|
||||
self.bias = {}
|
||||
self.scale = {}
|
||||
for spk, stats in stats_dict.items():
|
||||
assert len(stats) == 2, stats.shape
|
||||
|
||||
count = stats[0, -1]
|
||||
|
||||
# If the feature has two or more dimensions
|
||||
if not (np.isscalar(count) or isinstance(count, (int, float))):
|
||||
# The first is only used
|
||||
count = count.flatten()[0]
|
||||
|
||||
mean = stats[0, :-1] / count
|
||||
# V(x) = E(x^2) - (E(x))^2
|
||||
var = stats[1, :-1] / count - mean * mean
|
||||
std = np.maximum(np.sqrt(var), std_floor)
|
||||
self.bias[spk] = -mean
|
||||
self.scale[spk] = 1 / std
|
||||
|
||||
def __repr__(self):
|
||||
return ("{name}(stats_file={stats_file}, "
|
||||
"norm_means={norm_means}, norm_vars={norm_vars}, "
|
||||
"reverse={reverse})".format(
|
||||
name=self.__class__.__name__,
|
||||
stats_file=self.stats_file,
|
||||
norm_means=self.norm_means,
|
||||
norm_vars=self.norm_vars,
|
||||
reverse=self.reverse, ))
|
||||
|
||||
def __call__(self, x, uttid=None):
|
||||
if self.utt2spk is not None:
|
||||
spk = self.utt2spk[uttid]
|
||||
else:
|
||||
spk = uttid
|
||||
|
||||
if not self.reverse:
|
||||
# apply cmvn
|
||||
if self.norm_means:
|
||||
x = np.add(x, self.bias[spk])
|
||||
if self.norm_vars:
|
||||
x = np.multiply(x, self.scale[spk])
|
||||
|
||||
else:
|
||||
# apply reverse cmvn
|
||||
if self.norm_vars:
|
||||
x = np.divide(x, self.scale[spk])
|
||||
if self.norm_means:
|
||||
x = np.subtract(x, self.bias[spk])
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class UtteranceCMVN():
|
||||
"Apply Utterance CMVN"
|
||||
|
||||
def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20):
|
||||
self.norm_means = norm_means
|
||||
self.norm_vars = norm_vars
|
||||
self.std_floor = std_floor
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format(
|
||||
name=self.__class__.__name__,
|
||||
norm_means=self.norm_means,
|
||||
norm_vars=self.norm_vars, )
|
||||
|
||||
def __call__(self, x, uttid=None):
|
||||
# x: [Time, Dim]
|
||||
square_sums = (x**2).sum(axis=0)
|
||||
mean = x.mean(axis=0)
|
||||
|
||||
if self.norm_means:
|
||||
x = np.subtract(x, mean)
|
||||
|
||||
if self.norm_vars:
|
||||
var = square_sums / x.shape[0] - mean**2
|
||||
std = np.maximum(np.sqrt(var), self.std_floor)
|
||||
x = np.divide(x, std)
|
||||
|
||||
return x
|
@ -0,0 +1,71 @@
|
||||
import inspect
|
||||
|
||||
from deepspeech.transform.transform_interface import TransformInterface
|
||||
from deepspeech.utils.check_kwargs import check_kwargs
|
||||
|
||||
|
||||
class FuncTrans(TransformInterface):
|
||||
"""Functional Transformation
|
||||
|
||||
WARNING:
|
||||
Builtin or C/C++ functions may not work properly
|
||||
because this class heavily depends on the `inspect` module.
|
||||
|
||||
Usage:
|
||||
|
||||
>>> def foo_bar(x, a=1, b=2):
|
||||
... '''Foo bar
|
||||
... :param x: input
|
||||
... :param int a: default 1
|
||||
... :param int b: default 2
|
||||
... '''
|
||||
... return x + a - b
|
||||
|
||||
|
||||
>>> class FooBar(FuncTrans):
|
||||
... _func = foo_bar
|
||||
... __doc__ = foo_bar.__doc__
|
||||
"""
|
||||
|
||||
_func = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
check_kwargs(self.func, kwargs)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.func(x, **self.kwargs)
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
fname = cls._func.__name__.replace("_", "-")
|
||||
group = parser.add_argument_group(fname + " transformation setting")
|
||||
for k, v in cls.default_params().items():
|
||||
# TODO(karita): get help and choices from docstring?
|
||||
attr = k.replace("_", "-")
|
||||
group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
|
||||
return parser
|
||||
|
||||
@property
|
||||
def func(self):
|
||||
return type(self)._func
|
||||
|
||||
@classmethod
|
||||
def default_params(cls):
|
||||
try:
|
||||
d = dict(inspect.signature(cls._func).parameters)
|
||||
except ValueError:
|
||||
d = dict()
|
||||
return {
|
||||
k: v.default for k, v in d.items() if v.default != inspect.Parameter.empty
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
params = self.default_params()
|
||||
params.update(**self.kwargs)
|
||||
ret = self.__class__.__name__ + "("
|
||||
if len(params) == 0:
|
||||
return ret + ")"
|
||||
for k, v in params.items():
|
||||
ret += "{}={}, ".format(k, v)
|
||||
return ret[:-2] + ")"
|
@ -0,0 +1,343 @@
|
||||
import librosa
|
||||
import numpy
|
||||
import scipy
|
||||
import soundfile
|
||||
|
||||
from deepspeech.io.reader import SoundHDF5File
|
||||
|
||||
class SpeedPerturbation():
|
||||
"""SpeedPerturbation
|
||||
|
||||
The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
|
||||
and sox-speed just to resample the input,
|
||||
i.e pitch and tempo are changed both.
|
||||
|
||||
"Why use speed option instead of tempo -s in SoX for speed perturbation"
|
||||
https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8
|
||||
|
||||
Warning:
|
||||
This function is very slow because of resampling.
|
||||
I recommmend to apply speed-perturb outside the training using sox.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lower=0.9,
|
||||
upper=1.1,
|
||||
utt2ratio=None,
|
||||
keep_length=True,
|
||||
res_type="kaiser_best",
|
||||
seed=None,
|
||||
):
|
||||
self.res_type = res_type
|
||||
self.keep_length = keep_length
|
||||
self.state = numpy.random.RandomState(seed)
|
||||
|
||||
if utt2ratio is not None:
|
||||
self.utt2ratio = {}
|
||||
# Use the scheduled ratio for each utterances
|
||||
self.utt2ratio_file = utt2ratio
|
||||
self.lower = None
|
||||
self.upper = None
|
||||
self.accept_uttid = True
|
||||
|
||||
with open(utt2ratio, "r") as f:
|
||||
for line in f:
|
||||
utt, ratio = line.rstrip().split(None, 1)
|
||||
ratio = float(ratio)
|
||||
self.utt2ratio[utt] = ratio
|
||||
else:
|
||||
self.utt2ratio = None
|
||||
# The ratio is given on runtime randomly
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
|
||||
def __repr__(self):
|
||||
if self.utt2ratio is None:
|
||||
return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format(
|
||||
self.__class__.__name__,
|
||||
self.lower,
|
||||
self.upper,
|
||||
self.keep_length,
|
||||
self.res_type,
|
||||
)
|
||||
else:
|
||||
return "{}({}, res_type={})".format(
|
||||
self.__class__.__name__, self.utt2ratio_file, self.res_type
|
||||
)
|
||||
|
||||
def __call__(self, x, uttid=None, train=True):
|
||||
if not train:
|
||||
return x
|
||||
|
||||
x = x.astype(numpy.float32)
|
||||
if self.accept_uttid:
|
||||
ratio = self.utt2ratio[uttid]
|
||||
else:
|
||||
ratio = self.state.uniform(self.lower, self.upper)
|
||||
|
||||
# Note1: resample requires the sampling-rate of input and output,
|
||||
# but actually only the ratio is used.
|
||||
y = librosa.resample(x, ratio, 1, res_type=self.res_type)
|
||||
|
||||
if self.keep_length:
|
||||
diff = abs(len(x) - len(y))
|
||||
if len(y) > len(x):
|
||||
# Truncate noise
|
||||
y = y[diff // 2 : -((diff + 1) // 2)]
|
||||
elif len(y) < len(x):
|
||||
# Assume the time-axis is the first: (Time, Channel)
|
||||
pad_width = [(diff // 2, (diff + 1) // 2)] + [
|
||||
(0, 0) for _ in range(y.ndim - 1)
|
||||
]
|
||||
y = numpy.pad(
|
||||
y, pad_width=pad_width, constant_values=0, mode="constant"
|
||||
)
|
||||
return y
|
||||
|
||||
|
||||
class BandpassPerturbation():
|
||||
"""BandpassPerturbation
|
||||
|
||||
Randomly dropout along the frequency axis.
|
||||
|
||||
The original idea comes from the following:
|
||||
"randomly-selected frequency band was cut off under the constraint of
|
||||
leaving at least 1,000 Hz band within the range of less than 4,000Hz."
|
||||
(The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
|
||||
everyday home environments using multiple microphone arrays;
|
||||
http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1,)):
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
self.state = numpy.random.RandomState(seed)
|
||||
# x_stft: (Time, Channel, Freq)
|
||||
self.axes = axes
|
||||
|
||||
def __repr__(self):
|
||||
return "{}(lower={}, upper={})".format(
|
||||
self.__class__.__name__, self.lower, self.upper
|
||||
)
|
||||
|
||||
def __call__(self, x_stft, uttid=None, train=True):
|
||||
if not train:
|
||||
return x_stft
|
||||
|
||||
if x_stft.ndim == 1:
|
||||
raise RuntimeError(
|
||||
"Input in time-freq domain: " "(Time, Channel, Freq) or (Time, Freq)"
|
||||
)
|
||||
|
||||
ratio = self.state.uniform(self.lower, self.upper)
|
||||
axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes]
|
||||
shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]
|
||||
|
||||
mask = self.state.randn(*shape) > ratio
|
||||
x_stft *= mask
|
||||
return x_stft
|
||||
|
||||
|
||||
class VolumePerturbation():
|
||||
def __init__(self, lower=-1.6, upper=1.6, utt2ratio=None, dbunit=True, seed=None):
|
||||
self.dbunit = dbunit
|
||||
self.utt2ratio_file = utt2ratio
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
self.state = numpy.random.RandomState(seed)
|
||||
|
||||
if utt2ratio is not None:
|
||||
# Use the scheduled ratio for each utterances
|
||||
self.utt2ratio = {}
|
||||
self.lower = None
|
||||
self.upper = None
|
||||
self.accept_uttid = True
|
||||
|
||||
with open(utt2ratio, "r") as f:
|
||||
for line in f:
|
||||
utt, ratio = line.rstrip().split(None, 1)
|
||||
ratio = float(ratio)
|
||||
self.utt2ratio[utt] = ratio
|
||||
else:
|
||||
# The ratio is given on runtime randomly
|
||||
self.utt2ratio = None
|
||||
|
||||
def __repr__(self):
|
||||
if self.utt2ratio is None:
|
||||
return "{}(lower={}, upper={}, dbunit={})".format(
|
||||
self.__class__.__name__, self.lower, self.upper, self.dbunit
|
||||
)
|
||||
else:
|
||||
return '{}("{}", dbunit={})'.format(
|
||||
self.__class__.__name__, self.utt2ratio_file, self.dbunit
|
||||
)
|
||||
|
||||
def __call__(self, x, uttid=None, train=True):
|
||||
if not train:
|
||||
return x
|
||||
|
||||
x = x.astype(numpy.float32)
|
||||
|
||||
if self.accept_uttid:
|
||||
ratio = self.utt2ratio[uttid]
|
||||
else:
|
||||
ratio = self.state.uniform(self.lower, self.upper)
|
||||
if self.dbunit:
|
||||
ratio = 10 ** (ratio / 20)
|
||||
return x * ratio
|
||||
|
||||
|
||||
class NoiseInjection():
|
||||
"""Add isotropic noise"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
utt2noise=None,
|
||||
lower=-20,
|
||||
upper=-5,
|
||||
utt2ratio=None,
|
||||
filetype="list",
|
||||
dbunit=True,
|
||||
seed=None,
|
||||
):
|
||||
self.utt2noise_file = utt2noise
|
||||
self.utt2ratio_file = utt2ratio
|
||||
self.filetype = filetype
|
||||
self.dbunit = dbunit
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
self.state = numpy.random.RandomState(seed)
|
||||
|
||||
if utt2ratio is not None:
|
||||
# Use the scheduled ratio for each utterances
|
||||
self.utt2ratio = {}
|
||||
with open(utt2noise, "r") as f:
|
||||
for line in f:
|
||||
utt, snr = line.rstrip().split(None, 1)
|
||||
snr = float(snr)
|
||||
self.utt2ratio[utt] = snr
|
||||
else:
|
||||
# The ratio is given on runtime randomly
|
||||
self.utt2ratio = None
|
||||
|
||||
if utt2noise is not None:
|
||||
self.utt2noise = {}
|
||||
if filetype == "list":
|
||||
with open(utt2noise, "r") as f:
|
||||
for line in f:
|
||||
utt, filename = line.rstrip().split(None, 1)
|
||||
signal, rate = soundfile.read(filename, dtype="int16")
|
||||
# Load all files in memory
|
||||
self.utt2noise[utt] = (signal, rate)
|
||||
|
||||
elif filetype == "sound.hdf5":
|
||||
self.utt2noise = SoundHDF5File(utt2noise, "r")
|
||||
else:
|
||||
raise ValueError(filetype)
|
||||
else:
|
||||
self.utt2noise = None
|
||||
|
||||
if utt2noise is not None and utt2ratio is not None:
|
||||
if set(self.utt2ratio) != set(self.utt2noise):
|
||||
raise RuntimeError(
|
||||
"The uttids mismatch between {} and {}".format(utt2ratio, utt2noise)
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
if self.utt2ratio is None:
|
||||
return "{}(lower={}, upper={}, dbunit={})".format(
|
||||
self.__class__.__name__, self.lower, self.upper, self.dbunit
|
||||
)
|
||||
else:
|
||||
return '{}("{}", dbunit={})'.format(
|
||||
self.__class__.__name__, self.utt2ratio_file, self.dbunit
|
||||
)
|
||||
|
||||
def __call__(self, x, uttid=None, train=True):
|
||||
if not train:
|
||||
return x
|
||||
x = x.astype(numpy.float32)
|
||||
|
||||
# 1. Get ratio of noise to signal in sound pressure level
|
||||
if uttid is not None and self.utt2ratio is not None:
|
||||
ratio = self.utt2ratio[uttid]
|
||||
else:
|
||||
ratio = self.state.uniform(self.lower, self.upper)
|
||||
|
||||
if self.dbunit:
|
||||
ratio = 10 ** (ratio / 20)
|
||||
scale = ratio * numpy.sqrt((x ** 2).mean())
|
||||
|
||||
# 2. Get noise
|
||||
if self.utt2noise is not None:
|
||||
# Get noise from the external source
|
||||
if uttid is not None:
|
||||
noise, rate = self.utt2noise[uttid]
|
||||
else:
|
||||
# Randomly select the noise source
|
||||
noise = self.state.choice(list(self.utt2noise.values()))
|
||||
# Normalize the level
|
||||
noise /= numpy.sqrt((noise ** 2).mean())
|
||||
|
||||
# Adjust the noise length
|
||||
diff = abs(len(x) - len(noise))
|
||||
offset = self.state.randint(0, diff)
|
||||
if len(noise) > len(x):
|
||||
# Truncate noise
|
||||
noise = noise[offset : -(diff - offset)]
|
||||
else:
|
||||
noise = numpy.pad(noise, pad_width=[offset, diff - offset], mode="wrap")
|
||||
|
||||
else:
|
||||
# Generate white noise
|
||||
noise = self.state.normal(0, 1, x.shape)
|
||||
|
||||
# 3. Add noise to signal
|
||||
return x + noise * scale
|
||||
|
||||
|
||||
class RIRConvolve():
|
||||
def __init__(self, utt2rir, filetype="list"):
|
||||
self.utt2rir_file = utt2rir
|
||||
self.filetype = filetype
|
||||
|
||||
self.utt2rir = {}
|
||||
if filetype == "list":
|
||||
with open(utt2rir, "r") as f:
|
||||
for line in f:
|
||||
utt, filename = line.rstrip().split(None, 1)
|
||||
signal, rate = soundfile.read(filename, dtype="int16")
|
||||
self.utt2rir[utt] = (signal, rate)
|
||||
|
||||
elif filetype == "sound.hdf5":
|
||||
self.utt2rir = SoundHDF5File(utt2rir, "r")
|
||||
else:
|
||||
raise NotImplementedError(filetype)
|
||||
|
||||
def __repr__(self):
|
||||
return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)
|
||||
|
||||
def __call__(self, x, uttid=None, train=True):
|
||||
if not train:
|
||||
return x
|
||||
|
||||
x = x.astype(numpy.float32)
|
||||
|
||||
if x.ndim != 1:
|
||||
# Must be single channel
|
||||
raise RuntimeError(
|
||||
"Input x must be one dimensional array, but got {}".format(x.shape)
|
||||
)
|
||||
|
||||
rir, rate = self.utt2rir[uttid]
|
||||
if rir.ndim == 2:
|
||||
# FIXME(kamo): Use chainer.convolution_1d?
|
||||
# return [Time, Channel]
|
||||
return numpy.stack(
|
||||
[scipy.convolve(x, r, mode="same") for r in rir], axis=-1
|
||||
)
|
||||
else:
|
||||
return scipy.convolve(x, rir, mode="same")
|
@ -0,0 +1,202 @@
|
||||
"""Spec Augment module for preprocessing i.e., data augmentation"""
|
||||
|
||||
import random
|
||||
|
||||
import numpy
|
||||
from PIL import Image
|
||||
from PIL.Image import BICUBIC
|
||||
|
||||
from deepspeech.transform.functional import FuncTrans
|
||||
|
||||
|
||||
def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
|
||||
"""time warp for spec augment
|
||||
|
||||
move random center frame by the random width ~ uniform(-window, window)
|
||||
:param numpy.ndarray x: spectrogram (time, freq)
|
||||
:param int max_time_warp: maximum time frames to warp
|
||||
:param bool inplace: overwrite x with the result
|
||||
:param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp"
|
||||
(slow, differentiable)
|
||||
:returns numpy.ndarray: time warped spectrogram (time, freq)
|
||||
"""
|
||||
window = max_time_warp
|
||||
if mode == "PIL":
|
||||
t = x.shape[0]
|
||||
if t - window <= window:
|
||||
return x
|
||||
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
|
||||
center = random.randrange(window, t - window)
|
||||
warped = random.randrange(center - window, center + window) + 1 # 1 ... t - 1
|
||||
|
||||
left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
|
||||
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), BICUBIC)
|
||||
if inplace:
|
||||
x[:warped] = left
|
||||
x[warped:] = right
|
||||
return x
|
||||
return numpy.concatenate((left, right), 0)
|
||||
elif mode == "sparse_image_warp":
|
||||
import paddle
|
||||
|
||||
from espnet.utils import spec_augment
|
||||
|
||||
# TODO(karita): make this differentiable again
|
||||
return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"unknown resize mode: "
|
||||
+ mode
|
||||
+ ", choose one from (PIL, sparse_image_warp)."
|
||||
)
|
||||
|
||||
|
||||
class TimeWarp(FuncTrans):
|
||||
_func = time_warp
|
||||
__doc__ = time_warp.__doc__
|
||||
|
||||
def __call__(self, x, train):
|
||||
if not train:
|
||||
return x
|
||||
return super().__call__(x)
|
||||
|
||||
|
||||
def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
|
||||
"""freq mask for spec agument
|
||||
|
||||
:param numpy.ndarray x: (time, freq)
|
||||
:param int n_mask: the number of masks
|
||||
:param bool inplace: overwrite
|
||||
:param bool replace_with_zero: pad zero on mask if true else use mean
|
||||
"""
|
||||
if inplace:
|
||||
cloned = x
|
||||
else:
|
||||
cloned = x.copy()
|
||||
|
||||
num_mel_channels = cloned.shape[1]
|
||||
fs = numpy.random.randint(0, F, size=(n_mask, 2))
|
||||
|
||||
for f, mask_end in fs:
|
||||
f_zero = random.randrange(0, num_mel_channels - f)
|
||||
mask_end += f_zero
|
||||
|
||||
# avoids randrange error if values are equal and range is empty
|
||||
if f_zero == f_zero + f:
|
||||
continue
|
||||
|
||||
if replace_with_zero:
|
||||
cloned[:, f_zero:mask_end] = 0
|
||||
else:
|
||||
cloned[:, f_zero:mask_end] = cloned.mean()
|
||||
return cloned
|
||||
|
||||
|
||||
class FreqMask(FuncTrans):
|
||||
_func = freq_mask
|
||||
__doc__ = freq_mask.__doc__
|
||||
|
||||
def __call__(self, x, train):
|
||||
if not train:
|
||||
return x
|
||||
return super().__call__(x)
|
||||
|
||||
|
||||
def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
|
||||
"""freq mask for spec agument
|
||||
|
||||
:param numpy.ndarray spec: (time, freq)
|
||||
:param int n_mask: the number of masks
|
||||
:param bool inplace: overwrite
|
||||
:param bool replace_with_zero: pad zero on mask if true else use mean
|
||||
"""
|
||||
if inplace:
|
||||
cloned = spec
|
||||
else:
|
||||
cloned = spec.copy()
|
||||
len_spectro = cloned.shape[0]
|
||||
ts = numpy.random.randint(0, T, size=(n_mask, 2))
|
||||
for t, mask_end in ts:
|
||||
# avoid randint range error
|
||||
if len_spectro - t <= 0:
|
||||
continue
|
||||
t_zero = random.randrange(0, len_spectro - t)
|
||||
|
||||
# avoids randrange error if values are equal and range is empty
|
||||
if t_zero == t_zero + t:
|
||||
continue
|
||||
|
||||
mask_end += t_zero
|
||||
if replace_with_zero:
|
||||
cloned[t_zero:mask_end] = 0
|
||||
else:
|
||||
cloned[t_zero:mask_end] = cloned.mean()
|
||||
return cloned
|
||||
|
||||
|
||||
class TimeMask(FuncTrans):
|
||||
_func = time_mask
|
||||
__doc__ = time_mask.__doc__
|
||||
|
||||
def __call__(self, x, train):
|
||||
if not train:
|
||||
return x
|
||||
return super().__call__(x)
|
||||
|
||||
|
||||
def spec_augment(
|
||||
x,
|
||||
resize_mode="PIL",
|
||||
max_time_warp=80,
|
||||
max_freq_width=27,
|
||||
n_freq_mask=2,
|
||||
max_time_width=100,
|
||||
n_time_mask=2,
|
||||
inplace=True,
|
||||
replace_with_zero=True,
|
||||
):
|
||||
"""spec agument
|
||||
|
||||
apply random time warping and time/freq masking
|
||||
default setting is based on LD (Librispeech double) in Table 2
|
||||
https://arxiv.org/pdf/1904.08779.pdf
|
||||
|
||||
:param numpy.ndarray x: (time, freq)
|
||||
:param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp"
|
||||
(slow, differentiable)
|
||||
:param int max_time_warp: maximum frames to warp the center frame in spectrogram (W)
|
||||
:param int freq_mask_width: maximum width of the random freq mask (F)
|
||||
:param int n_freq_mask: the number of the random freq mask (m_F)
|
||||
:param int time_mask_width: maximum width of the random time mask (T)
|
||||
:param int n_time_mask: the number of the random time mask (m_T)
|
||||
:param bool inplace: overwrite intermediate array
|
||||
:param bool replace_with_zero: pad zero on mask if true else use mean
|
||||
"""
|
||||
assert isinstance(x, numpy.ndarray)
|
||||
assert x.ndim == 2
|
||||
x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
|
||||
x = freq_mask(
|
||||
x,
|
||||
max_freq_width,
|
||||
n_freq_mask,
|
||||
inplace=inplace,
|
||||
replace_with_zero=replace_with_zero,
|
||||
)
|
||||
x = time_mask(
|
||||
x,
|
||||
max_time_width,
|
||||
n_time_mask,
|
||||
inplace=inplace,
|
||||
replace_with_zero=replace_with_zero,
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
class SpecAugment(FuncTrans):
|
||||
_func = spec_augment
|
||||
__doc__ = spec_augment.__doc__
|
||||
|
||||
def __call__(self, x, train):
|
||||
if not train:
|
||||
return x
|
||||
return super().__call__(x)
|
@ -0,0 +1,307 @@
|
||||
import librosa
|
||||
import numpy as np
|
||||
|
||||
|
||||
def stft(
|
||||
x, n_fft, n_shift, win_length=None, window="hann", center=True, pad_mode="reflect"
|
||||
):
|
||||
# x: [Time, Channel]
|
||||
if x.ndim == 1:
|
||||
single_channel = True
|
||||
# x: [Time] -> [Time, Channel]
|
||||
x = x[:, None]
|
||||
else:
|
||||
single_channel = False
|
||||
x = x.astype(np.float32)
|
||||
|
||||
# FIXME(kamo): librosa.stft can't use multi-channel?
|
||||
# x: [Time, Channel, Freq]
|
||||
x = np.stack(
|
||||
[
|
||||
librosa.stft(
|
||||
x[:, ch],
|
||||
n_fft=n_fft,
|
||||
hop_length=n_shift,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
center=center,
|
||||
pad_mode=pad_mode,
|
||||
).T
|
||||
for ch in range(x.shape[1])
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
if single_channel:
|
||||
# x: [Time, Channel, Freq] -> [Time, Freq]
|
||||
x = x[:, 0]
|
||||
return x
|
||||
|
||||
|
||||
def istft(x, n_shift, win_length=None, window="hann", center=True):
|
||||
# x: [Time, Channel, Freq]
|
||||
if x.ndim == 2:
|
||||
single_channel = True
|
||||
# x: [Time, Freq] -> [Time, Channel, Freq]
|
||||
x = x[:, None, :]
|
||||
else:
|
||||
single_channel = False
|
||||
|
||||
# x: [Time, Channel]
|
||||
x = np.stack(
|
||||
[
|
||||
librosa.istft(
|
||||
x[:, ch].T, # [Time, Freq] -> [Freq, Time]
|
||||
hop_length=n_shift,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
center=center,
|
||||
)
|
||||
for ch in range(x.shape[1])
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
if single_channel:
|
||||
# x: [Time, Channel] -> [Time]
|
||||
x = x[:, 0]
|
||||
return x
|
||||
|
||||
|
||||
def stft2logmelspectrogram(x_stft, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
|
||||
# x_stft: (Time, Channel, Freq) or (Time, Freq)
|
||||
fmin = 0 if fmin is None else fmin
|
||||
fmax = fs / 2 if fmax is None else fmax
|
||||
|
||||
# spc: (Time, Channel, Freq) or (Time, Freq)
|
||||
spc = np.abs(x_stft)
|
||||
# mel_basis: (Mel_freq, Freq)
|
||||
mel_basis = librosa.filters.mel(fs, n_fft, n_mels, fmin, fmax)
|
||||
# lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq)
|
||||
lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
|
||||
|
||||
return lmspc
|
||||
|
||||
|
||||
def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"):
|
||||
# x: (Time, Channel) -> spc: (Time, Channel, Freq)
|
||||
spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window))
|
||||
return spc
|
||||
|
||||
|
||||
def logmelspectrogram(
|
||||
x,
|
||||
fs,
|
||||
n_mels,
|
||||
n_fft,
|
||||
n_shift,
|
||||
win_length=None,
|
||||
window="hann",
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
eps=1e-10,
|
||||
pad_mode="reflect",
|
||||
):
|
||||
# stft: (Time, Channel, Freq) or (Time, Freq)
|
||||
x_stft = stft(
|
||||
x,
|
||||
n_fft=n_fft,
|
||||
n_shift=n_shift,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
return stft2logmelspectrogram(
|
||||
x_stft, fs=fs, n_mels=n_mels, n_fft=n_fft, fmin=fmin, fmax=fmax, eps=eps
|
||||
)
|
||||
|
||||
|
||||
class Spectrogram():
|
||||
def __init__(self, n_fft, n_shift, win_length=None, window="hann"):
|
||||
self.n_fft = n_fft
|
||||
self.n_shift = n_shift
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(n_fft={n_fft}, n_shift={n_shift}, "
|
||||
"win_length={win_length}, window={window})".format(
|
||||
name=self.__class__.__name__,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return spectrogram(
|
||||
x,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
)
|
||||
|
||||
|
||||
class LogMelSpectrogram():
|
||||
def __init__(
|
||||
self,
|
||||
fs,
|
||||
n_mels,
|
||||
n_fft,
|
||||
n_shift,
|
||||
win_length=None,
|
||||
window="hann",
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
eps=1e-10,
|
||||
):
|
||||
self.fs = fs
|
||||
self.n_mels = n_mels
|
||||
self.n_fft = n_fft
|
||||
self.n_shift = n_shift
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.eps = eps
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
|
||||
"n_shift={n_shift}, win_length={win_length}, window={window}, "
|
||||
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
|
||||
name=self.__class__.__name__,
|
||||
fs=self.fs,
|
||||
n_mels=self.n_mels,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
fmin=self.fmin,
|
||||
fmax=self.fmax,
|
||||
eps=self.eps,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return logmelspectrogram(
|
||||
x,
|
||||
fs=self.fs,
|
||||
n_mels=self.n_mels,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
)
|
||||
|
||||
|
||||
class Stft2LogMelSpectrogram():
|
||||
def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
|
||||
self.fs = fs
|
||||
self.n_mels = n_mels
|
||||
self.n_fft = n_fft
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.eps = eps
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
|
||||
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
|
||||
name=self.__class__.__name__,
|
||||
fs=self.fs,
|
||||
n_mels=self.n_mels,
|
||||
n_fft=self.n_fft,
|
||||
fmin=self.fmin,
|
||||
fmax=self.fmax,
|
||||
eps=self.eps,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return stft2logmelspectrogram(
|
||||
x,
|
||||
fs=self.fs,
|
||||
n_mels=self.n_mels,
|
||||
n_fft=self.n_fft,
|
||||
fmin=self.fmin,
|
||||
fmax=self.fmax,
|
||||
)
|
||||
|
||||
|
||||
class Stft():
|
||||
def __init__(
|
||||
self,
|
||||
n_fft,
|
||||
n_shift,
|
||||
win_length=None,
|
||||
window="hann",
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
):
|
||||
self.n_fft = n_fft
|
||||
self.n_shift = n_shift
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
self.center = center
|
||||
self.pad_mode = pad_mode
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(n_fft={n_fft}, n_shift={n_shift}, "
|
||||
"win_length={win_length}, window={window},"
|
||||
"center={center}, pad_mode={pad_mode})".format(
|
||||
name=self.__class__.__name__,
|
||||
n_fft=self.n_fft,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
center=self.center,
|
||||
pad_mode=self.pad_mode,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return stft(
|
||||
x,
|
||||
self.n_fft,
|
||||
self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
center=self.center,
|
||||
pad_mode=self.pad_mode,
|
||||
)
|
||||
|
||||
|
||||
class IStft():
|
||||
def __init__(self, n_shift, win_length=None, window="hann", center=True):
|
||||
self.n_shift = n_shift
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
self.center = center
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(n_shift={n_shift}, "
|
||||
"win_length={win_length}, window={window},"
|
||||
"center={center})".format(
|
||||
name=self.__class__.__name__,
|
||||
n_shift=self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
center=self.center,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return istft(
|
||||
x,
|
||||
self.n_shift,
|
||||
win_length=self.win_length,
|
||||
window=self.window,
|
||||
center=self.center,
|
||||
)
|
@ -0,0 +1,20 @@
|
||||
# TODO(karita): add this to all the transform impl.
|
||||
class TransformInterface:
|
||||
"""Transform Interface"""
|
||||
|
||||
def __call__(self, x):
|
||||
raise NotImplementedError("__call__ method is not implemented")
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
return parser
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + "()"
|
||||
|
||||
|
||||
class Identity(TransformInterface):
|
||||
"""Identity Function"""
|
||||
|
||||
def __call__(self, x):
|
||||
return x
|
@ -0,0 +1,149 @@
|
||||
"""Transformation module."""
|
||||
from collections.abc import Sequence
|
||||
from collections import OrderedDict
|
||||
import copy
|
||||
from inspect import signature
|
||||
import io
|
||||
import logging
|
||||
|
||||
import yaml
|
||||
|
||||
from deepspeech.utils.dynamic_import import dynamic_import
|
||||
|
||||
|
||||
# TODO(karita): inherit TransformInterface
|
||||
# TODO(karita): register cmd arguments in asr_train.py
|
||||
import_alias = dict(
|
||||
identity="deepspeech.transform.transform_interface:Identity",
|
||||
time_warp="deepspeech.transform.spec_augment:TimeWarp",
|
||||
time_mask="deepspeech.transform.spec_augment:TimeMask",
|
||||
freq_mask="deepspeech.transform.spec_augment:FreqMask",
|
||||
spec_augment="deepspeech.transform.spec_augment:SpecAugment",
|
||||
speed_perturbation="deepspeech.transform.perturb:SpeedPerturbation",
|
||||
volume_perturbation="deepspeech.transform.perturb:VolumePerturbation",
|
||||
noise_injection="deepspeech.transform.perturb:NoiseInjection",
|
||||
bandpass_perturbation="deepspeech.transform.perturb:BandpassPerturbation",
|
||||
rir_convolve="deepspeech.transform.perturb:RIRConvolve",
|
||||
delta="deepspeech.transform.add_deltas:AddDeltas",
|
||||
cmvn="deepspeech.transform.cmvn:CMVN",
|
||||
utterance_cmvn="deepspeech.transform.cmvn:UtteranceCMVN",
|
||||
fbank="deepspeech.transform.spectrogram:LogMelSpectrogram",
|
||||
spectrogram="deepspeech.transform.spectrogram:Spectrogram",
|
||||
stft="deepspeech.transform.spectrogram:Stft",
|
||||
istft="deepspeech.transform.spectrogram:IStft",
|
||||
stft2fbank="deepspeech.transform.spectrogram:Stft2LogMelSpectrogram",
|
||||
wpe="deepspeech.transform.wpe:WPE",
|
||||
channel_selector="deepspeech.transform.channel_selector:ChannelSelector",
|
||||
)
|
||||
|
||||
|
||||
class Transformation():
|
||||
"""Apply some functions to the mini-batch
|
||||
|
||||
Examples:
|
||||
>>> kwargs = {"process": [{"type": "fbank",
|
||||
... "n_mels": 80,
|
||||
... "fs": 16000},
|
||||
... {"type": "cmvn",
|
||||
... "stats": "data/train/cmvn.ark",
|
||||
... "norm_vars": True},
|
||||
... {"type": "delta", "window": 2, "order": 2}]}
|
||||
>>> transform = Transformation(kwargs)
|
||||
>>> bs = 10
|
||||
>>> xs = [np.random.randn(100, 80).astype(np.float32)
|
||||
... for _ in range(bs)]
|
||||
>>> xs = transform(xs)
|
||||
"""
|
||||
|
||||
def __init__(self, conffile=None):
|
||||
if conffile is not None:
|
||||
if isinstance(conffile, dict):
|
||||
self.conf = copy.deepcopy(conffile)
|
||||
else:
|
||||
with io.open(conffile, encoding="utf-8") as f:
|
||||
self.conf = yaml.safe_load(f)
|
||||
assert isinstance(self.conf, dict), type(self.conf)
|
||||
else:
|
||||
self.conf = {"mode": "sequential", "process": []}
|
||||
|
||||
self.functions = OrderedDict()
|
||||
if self.conf.get("mode", "sequential") == "sequential":
|
||||
for idx, process in enumerate(self.conf["process"]):
|
||||
assert isinstance(process, dict), type(process)
|
||||
opts = dict(process)
|
||||
process_type = opts.pop("type")
|
||||
class_obj = dynamic_import(process_type, import_alias)
|
||||
# TODO(karita): assert issubclass(class_obj, TransformInterface)
|
||||
try:
|
||||
self.functions[idx] = class_obj(**opts)
|
||||
except TypeError:
|
||||
try:
|
||||
signa = signature(class_obj)
|
||||
except ValueError:
|
||||
# Some function, e.g. built-in function, are failed
|
||||
pass
|
||||
else:
|
||||
logging.error(
|
||||
"Expected signature: {}({})".format(
|
||||
class_obj.__name__, signa
|
||||
)
|
||||
)
|
||||
raise
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Not supporting mode={}".format(self.conf["mode"])
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
rep = "\n" + "\n".join(
|
||||
" {}: {}".format(k, v) for k, v in self.functions.items()
|
||||
)
|
||||
return "{}({})".format(self.__class__.__name__, rep)
|
||||
|
||||
def __call__(self, xs, uttid_list=None, **kwargs):
|
||||
"""Return new mini-batch
|
||||
|
||||
:param Union[Sequence[np.ndarray], np.ndarray] xs:
|
||||
:param Union[Sequence[str], str] uttid_list:
|
||||
:return: batch:
|
||||
:rtype: List[np.ndarray]
|
||||
"""
|
||||
if not isinstance(xs, Sequence):
|
||||
is_batch = False
|
||||
xs = [xs]
|
||||
else:
|
||||
is_batch = True
|
||||
|
||||
if isinstance(uttid_list, str):
|
||||
uttid_list = [uttid_list for _ in range(len(xs))]
|
||||
|
||||
if self.conf.get("mode", "sequential") == "sequential":
|
||||
for idx in range(len(self.conf["process"])):
|
||||
func = self.functions[idx]
|
||||
# TODO(karita): use TrainingTrans and UttTrans to check __call__ args
|
||||
# Derive only the args which the func has
|
||||
try:
|
||||
param = signature(func).parameters
|
||||
except ValueError:
|
||||
# Some function, e.g. built-in function, are failed
|
||||
param = {}
|
||||
_kwargs = {k: v for k, v in kwargs.items() if k in param}
|
||||
try:
|
||||
if uttid_list is not None and "uttid" in param:
|
||||
xs = [func(x, u, **_kwargs) for x, u in zip(xs, uttid_list)]
|
||||
else:
|
||||
xs = [func(x, **_kwargs) for x in xs]
|
||||
except Exception:
|
||||
logging.fatal(
|
||||
"Catch a exception from {}th func: {}".format(idx, func)
|
||||
)
|
||||
raise
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Not supporting mode={}".format(self.conf["mode"])
|
||||
)
|
||||
|
||||
if is_batch:
|
||||
return xs
|
||||
else:
|
||||
return xs[0]
|
@ -0,0 +1,45 @@
|
||||
from nara_wpe.wpe import wpe
|
||||
|
||||
|
||||
class WPE(object):
|
||||
def __init__(
|
||||
self, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode="full"
|
||||
):
|
||||
self.taps = taps
|
||||
self.delay = delay
|
||||
self.iterations = iterations
|
||||
self.psd_context = psd_context
|
||||
self.statistics_mode = statistics_mode
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}(taps={taps}, delay={delay}"
|
||||
"iterations={iterations}, psd_context={psd_context}, "
|
||||
"statistics_mode={statistics_mode})".format(
|
||||
name=self.__class__.__name__,
|
||||
taps=self.taps,
|
||||
delay=self.delay,
|
||||
iterations=self.iterations,
|
||||
psd_context=self.psd_context,
|
||||
statistics_mode=self.statistics_mode,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, xs):
|
||||
"""Return enhanced
|
||||
|
||||
:param np.ndarray xs: (Time, Channel, Frequency)
|
||||
:return: enhanced_xs
|
||||
:rtype: np.ndarray
|
||||
|
||||
"""
|
||||
# nara_wpe.wpe: (F, C, T)
|
||||
xs = wpe(
|
||||
xs.transpose((2, 1, 0)),
|
||||
taps=self.taps,
|
||||
delay=self.delay,
|
||||
iterations=self.iterations,
|
||||
psd_context=self.psd_context,
|
||||
statistics_mode=self.statistics_mode,
|
||||
)
|
||||
return xs.transpose(2, 1, 0)
|
@ -0,0 +1,52 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["label_smoothing_dist"]
|
||||
|
||||
|
||||
# TODO(takaaki-hori): add different smoothing methods
|
||||
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
|
||||
"""Obtain label distribution for loss smoothing.
|
||||
|
||||
:param odim:
|
||||
:param lsm_type:
|
||||
:param blank:
|
||||
:param transcript:
|
||||
:return:
|
||||
"""
|
||||
if transcript is not None:
|
||||
with open(transcript, "rb") as f:
|
||||
trans_json = json.load(f)["utts"]
|
||||
|
||||
if lsm_type == "unigram":
|
||||
assert transcript is not None, (
|
||||
"transcript is required for %s label smoothing" % lsm_type)
|
||||
labelcount = np.zeros(odim)
|
||||
for k, v in trans_json.items():
|
||||
ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
|
||||
# to avoid an error when there is no text in an uttrance
|
||||
if len(ids) > 0:
|
||||
labelcount[ids] += 1
|
||||
labelcount[odim - 1] = len(transcript) # count <eos>
|
||||
labelcount[labelcount == 0] = 1 # flooring
|
||||
labelcount[blank] = 0 # remove counts for blank
|
||||
labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
|
||||
else:
|
||||
logging.error("Error: unexpected label smoothing type: %s" % lsm_type)
|
||||
sys.exit()
|
||||
|
||||
return labeldist
|
@ -0,0 +1,20 @@
|
||||
import inspect
|
||||
|
||||
|
||||
def check_kwargs(func, kwargs, name=None):
|
||||
"""check kwargs are valid for func
|
||||
|
||||
If kwargs are invalid, raise TypeError as same as python default
|
||||
:param function func: function to be validated
|
||||
:param dict kwargs: keyword arguments for func
|
||||
:param str name: name used in TypeError (default is func name)
|
||||
"""
|
||||
try:
|
||||
params = inspect.signature(func).parameters
|
||||
except ValueError:
|
||||
return
|
||||
if name is None:
|
||||
name = func.__name__
|
||||
for k in kwargs.keys():
|
||||
if k not in params:
|
||||
raise TypeError(f"{name}() got an unexpected keyword argument '{k}'")
|
@ -0,0 +1,241 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import h5py
|
||||
import kaldiio
|
||||
import soundfile
|
||||
|
||||
from deepspeech.io.reader import SoundHDF5File
|
||||
|
||||
|
||||
def file_reader_helper(
|
||||
rspecifier: str,
|
||||
filetype: str="mat",
|
||||
return_shape: bool=False,
|
||||
segments: str=None, ):
|
||||
"""Read uttid and array in kaldi style
|
||||
|
||||
This function might be a bit confusing as "ark" is used
|
||||
for HDF5 to imitate "kaldi-rspecifier".
|
||||
|
||||
Args:
|
||||
rspecifier: Give as "ark:feats.ark" or "scp:feats.scp"
|
||||
filetype: "mat" is kaldi-martix, "hdf5": HDF5
|
||||
return_shape: Return the shape of the matrix,
|
||||
instead of the matrix. This can reduce IO cost for HDF5.
|
||||
segments (str): The file format is
|
||||
"<segment-id> <recording-id> <start-time> <end-time>\n"
|
||||
"e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5\n"
|
||||
Returns:
|
||||
Generator[Tuple[str, np.ndarray], None, None]:
|
||||
|
||||
Examples:
|
||||
Read from kaldi-matrix ark file:
|
||||
|
||||
>>> for u, array in file_reader_helper('ark:feats.ark', 'mat'):
|
||||
... array
|
||||
|
||||
Read from HDF5 file:
|
||||
|
||||
>>> for u, array in file_reader_helper('ark:feats.h5', 'hdf5'):
|
||||
... array
|
||||
|
||||
"""
|
||||
if filetype == "mat":
|
||||
return KaldiReader(
|
||||
rspecifier, return_shape=return_shape, segments=segments)
|
||||
elif filetype == "hdf5":
|
||||
return HDF5Reader(rspecifier, return_shape=return_shape)
|
||||
elif filetype == "sound.hdf5":
|
||||
return SoundHDF5Reader(rspecifier, return_shape=return_shape)
|
||||
elif filetype == "sound":
|
||||
return SoundReader(rspecifier, return_shape=return_shape)
|
||||
else:
|
||||
raise NotImplementedError(f"filetype={filetype}")
|
||||
|
||||
|
||||
class KaldiReader:
|
||||
def __init__(self, rspecifier, return_shape=False, segments=None):
|
||||
self.rspecifier = rspecifier
|
||||
self.return_shape = return_shape
|
||||
self.segments = segments
|
||||
|
||||
def __iter__(self):
|
||||
with kaldiio.ReadHelper(
|
||||
self.rspecifier, segments=self.segments) as reader:
|
||||
for key, array in reader:
|
||||
if self.return_shape:
|
||||
array = array.shape
|
||||
yield key, array
|
||||
|
||||
|
||||
class HDF5Reader:
|
||||
def __init__(self, rspecifier, return_shape=False):
|
||||
if ":" not in rspecifier:
|
||||
raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'.
|
||||
format(self.rspecifier))
|
||||
self.rspecifier = rspecifier
|
||||
self.ark_or_scp, self.filepath = self.rspecifier.split(":", 1)
|
||||
if self.ark_or_scp not in ["ark", "scp"]:
|
||||
raise ValueError(f"Must be scp or ark: {self.ark_or_scp}")
|
||||
|
||||
self.return_shape = return_shape
|
||||
|
||||
def __iter__(self):
|
||||
if self.ark_or_scp == "scp":
|
||||
hdf5_dict = {}
|
||||
with open(self.filepath, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
key, value = line.rstrip().split(None, 1)
|
||||
|
||||
if ":" not in value:
|
||||
raise RuntimeError(
|
||||
"scp file for hdf5 should be like: "
|
||||
'"uttid filepath.h5:key": {}({})'.format(
|
||||
line, self.filepath))
|
||||
path, h5_key = value.split(":", 1)
|
||||
|
||||
hdf5_file = hdf5_dict.get(path)
|
||||
if hdf5_file is None:
|
||||
try:
|
||||
hdf5_file = h5py.File(path, "r")
|
||||
except Exception:
|
||||
logging.error("Error when loading {}".format(path))
|
||||
raise
|
||||
hdf5_dict[path] = hdf5_file
|
||||
|
||||
try:
|
||||
data = hdf5_file[h5_key]
|
||||
except Exception:
|
||||
logging.error("Error when loading {} with key={}".
|
||||
format(path, h5_key))
|
||||
raise
|
||||
|
||||
if self.return_shape:
|
||||
yield key, data.shape
|
||||
else:
|
||||
yield key, data[()]
|
||||
|
||||
# Closing all files
|
||||
for k in hdf5_dict:
|
||||
try:
|
||||
hdf5_dict[k].close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
else:
|
||||
if self.filepath == "-":
|
||||
# Required h5py>=2.9
|
||||
filepath = io.BytesIO(sys.stdin.buffer.read())
|
||||
else:
|
||||
filepath = self.filepath
|
||||
with h5py.File(filepath, "r") as f:
|
||||
for key in f:
|
||||
if self.return_shape:
|
||||
yield key, f[key].shape
|
||||
else:
|
||||
yield key, f[key][()]
|
||||
|
||||
|
||||
class SoundHDF5Reader:
|
||||
def __init__(self, rspecifier, return_shape=False):
|
||||
if ":" not in rspecifier:
|
||||
raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'.
|
||||
format(rspecifier))
|
||||
self.ark_or_scp, self.filepath = rspecifier.split(":", 1)
|
||||
if self.ark_or_scp not in ["ark", "scp"]:
|
||||
raise ValueError(f"Must be scp or ark: {self.ark_or_scp}")
|
||||
self.return_shape = return_shape
|
||||
|
||||
def __iter__(self):
|
||||
if self.ark_or_scp == "scp":
|
||||
hdf5_dict = {}
|
||||
with open(self.filepath, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
key, value = line.rstrip().split(None, 1)
|
||||
|
||||
if ":" not in value:
|
||||
raise RuntimeError(
|
||||
"scp file for hdf5 should be like: "
|
||||
'"uttid filepath.h5:key": {}({})'.format(
|
||||
line, self.filepath))
|
||||
path, h5_key = value.split(":", 1)
|
||||
|
||||
hdf5_file = hdf5_dict.get(path)
|
||||
if hdf5_file is None:
|
||||
try:
|
||||
hdf5_file = SoundHDF5File(path, "r")
|
||||
except Exception:
|
||||
logging.error("Error when loading {}".format(path))
|
||||
raise
|
||||
hdf5_dict[path] = hdf5_file
|
||||
|
||||
try:
|
||||
data = hdf5_file[h5_key]
|
||||
except Exception:
|
||||
logging.error("Error when loading {} with key={}".
|
||||
format(path, h5_key))
|
||||
raise
|
||||
|
||||
# Change Tuple[ndarray, int] -> Tuple[int, ndarray]
|
||||
# (soundfile style -> scipy style)
|
||||
array, rate = data
|
||||
if self.return_shape:
|
||||
array = array.shape
|
||||
yield key, (rate, array)
|
||||
|
||||
# Closing all files
|
||||
for k in hdf5_dict:
|
||||
try:
|
||||
hdf5_dict[k].close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
else:
|
||||
if self.filepath == "-":
|
||||
# Required h5py>=2.9
|
||||
filepath = io.BytesIO(sys.stdin.buffer.read())
|
||||
else:
|
||||
filepath = self.filepath
|
||||
for key, (a, r) in SoundHDF5File(filepath, "r").items():
|
||||
if self.return_shape:
|
||||
a = a.shape
|
||||
yield key, (r, a)
|
||||
|
||||
|
||||
class SoundReader:
|
||||
def __init__(self, rspecifier, return_shape=False):
|
||||
if ":" not in rspecifier:
|
||||
raise ValueError('Give "rspecifier" such as "scp:some.scp: {}"'.
|
||||
format(rspecifier))
|
||||
self.ark_or_scp, self.filepath = rspecifier.split(":", 1)
|
||||
if self.ark_or_scp != "scp":
|
||||
raise ValueError('Only supporting "scp" for sound file: {}'.format(
|
||||
self.ark_or_scp))
|
||||
self.return_shape = return_shape
|
||||
|
||||
def __iter__(self):
|
||||
with open(self.filepath, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
key, sound_file_path = line.rstrip().split(None, 1)
|
||||
# Assume PCM16
|
||||
array, rate = soundfile.read(sound_file_path, dtype="int16")
|
||||
# Change Tuple[ndarray, int] -> Tuple[int, ndarray]
|
||||
# (soundfile style -> scipy style)
|
||||
if self.return_shape:
|
||||
array = array.shape
|
||||
yield key, (rate, array)
|
@ -0,0 +1,70 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from distutils.util import strtobool as dist_strtobool
|
||||
|
||||
import numpy
|
||||
|
||||
|
||||
def strtobool(x):
|
||||
# distutils.util.strtobool returns integer, but it's confusing,
|
||||
return bool(dist_strtobool(x))
|
||||
|
||||
|
||||
def get_commandline_args():
|
||||
extra_chars = [
|
||||
" ",
|
||||
";",
|
||||
"&",
|
||||
"(",
|
||||
")",
|
||||
"|",
|
||||
"^",
|
||||
"<",
|
||||
">",
|
||||
"?",
|
||||
"*",
|
||||
"[",
|
||||
"]",
|
||||
"$",
|
||||
"`",
|
||||
'"',
|
||||
"\\",
|
||||
"!",
|
||||
"{",
|
||||
"}",
|
||||
]
|
||||
|
||||
# Escape the extra characters for shell
|
||||
argv = [
|
||||
arg.replace("'", "'\\''") if all(char not in arg
|
||||
for char in extra_chars) else
|
||||
"'" + arg.replace("'", "'\\''") + "'" for arg in sys.argv
|
||||
]
|
||||
|
||||
return sys.executable + " " + " ".join(argv)
|
||||
|
||||
|
||||
def is_scipy_wav_style(value):
|
||||
# If Tuple[int, numpy.ndarray] or not
|
||||
return (isinstance(value, Sequence) and len(value) == 2 and
|
||||
isinstance(value[0], int) and isinstance(value[1], numpy.ndarray))
|
||||
|
||||
|
||||
def assert_scipy_wav_style(value):
|
||||
assert is_scipy_wav_style(
|
||||
value), "Must be Tuple[int, numpy.ndarray], but got {}".format(
|
||||
type(value) if not isinstance(value, Sequence) else "{}[{}]".format(
|
||||
type(value), ", ".join(str(type(v)) for v in value)))
|
@ -0,0 +1,293 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import h5py
|
||||
import kaldiio
|
||||
import numpy
|
||||
import soundfile
|
||||
|
||||
from deepspeech.io.reader import SoundHDF5File
|
||||
from deepspeech.utils.cli_utils import assert_scipy_wav_style
|
||||
|
||||
|
||||
def file_writer_helper(
|
||||
wspecifier: str,
|
||||
filetype: str="mat",
|
||||
write_num_frames: str=None,
|
||||
compress: bool=False,
|
||||
compression_method: int=2,
|
||||
pcm_format: str="wav", ):
|
||||
"""Write matrices in kaldi style
|
||||
|
||||
Args:
|
||||
wspecifier: e.g. ark,scp:out.ark,out.scp
|
||||
filetype: "mat" is kaldi-martix, "hdf5": HDF5
|
||||
write_num_frames: e.g. 'ark,t:num_frames.txt'
|
||||
compress: Compress or not
|
||||
compression_method: Specify compression level
|
||||
|
||||
Write in kaldi-matrix-ark with "kaldi-scp" file:
|
||||
|
||||
>>> with file_writer_helper('ark,scp:out.ark,out.scp') as f:
|
||||
>>> f['uttid'] = array
|
||||
|
||||
This "scp" has the following format:
|
||||
|
||||
uttidA out.ark:1234
|
||||
uttidB out.ark:2222
|
||||
|
||||
where, 1234 and 2222 points the strating byte address of the matrix.
|
||||
(For detail, see official documentation of Kaldi)
|
||||
|
||||
Write in HDF5 with "scp" file:
|
||||
|
||||
>>> with file_writer_helper('ark,scp:out.h5,out.scp', 'hdf5') as f:
|
||||
>>> f['uttid'] = array
|
||||
|
||||
This "scp" file is created as:
|
||||
|
||||
uttidA out.h5:uttidA
|
||||
uttidB out.h5:uttidB
|
||||
|
||||
HDF5 can be, unlike "kaldi-ark", accessed to any keys,
|
||||
so originally "scp" is not required for random-reading.
|
||||
Nevertheless we create "scp" for HDF5 because it is useful
|
||||
for some use-case. e.g. Concatenation, Splitting.
|
||||
|
||||
"""
|
||||
if filetype == "mat":
|
||||
return KaldiWriter(
|
||||
wspecifier,
|
||||
write_num_frames=write_num_frames,
|
||||
compress=compress,
|
||||
compression_method=compression_method, )
|
||||
elif filetype == "hdf5":
|
||||
return HDF5Writer(
|
||||
wspecifier, write_num_frames=write_num_frames, compress=compress)
|
||||
elif filetype == "sound.hdf5":
|
||||
return SoundHDF5Writer(
|
||||
wspecifier,
|
||||
write_num_frames=write_num_frames,
|
||||
pcm_format=pcm_format)
|
||||
elif filetype == "sound":
|
||||
return SoundWriter(
|
||||
wspecifier,
|
||||
write_num_frames=write_num_frames,
|
||||
pcm_format=pcm_format)
|
||||
else:
|
||||
raise NotImplementedError(f"filetype={filetype}")
|
||||
|
||||
|
||||
class BaseWriter:
|
||||
def __setitem__(self, key, value):
|
||||
raise NotImplementedError
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
self.writer.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self.writer_scp is not None:
|
||||
try:
|
||||
self.writer_scp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self.writer_nframe is not None:
|
||||
try:
|
||||
self.writer_nframe.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_num_frames_writer(write_num_frames: str):
|
||||
"""get_num_frames_writer
|
||||
|
||||
Examples:
|
||||
>>> get_num_frames_writer('ark,t:num_frames.txt')
|
||||
"""
|
||||
if write_num_frames is not None:
|
||||
if ":" not in write_num_frames:
|
||||
raise ValueError('Must include ":", write_num_frames={}'.format(
|
||||
write_num_frames))
|
||||
|
||||
nframes_type, nframes_file = write_num_frames.split(":", 1)
|
||||
if nframes_type != "ark,t":
|
||||
raise ValueError("Only supporting text mode. "
|
||||
"e.g. --write-num-frames=ark,t:foo.txt :"
|
||||
"{}".format(nframes_type))
|
||||
|
||||
return open(nframes_file, "w", encoding="utf-8")
|
||||
|
||||
|
||||
class KaldiWriter(BaseWriter):
|
||||
def __init__(self,
|
||||
wspecifier,
|
||||
write_num_frames=None,
|
||||
compress=False,
|
||||
compression_method=2):
|
||||
if compress:
|
||||
self.writer = kaldiio.WriteHelper(
|
||||
wspecifier, compression_method=compression_method)
|
||||
else:
|
||||
self.writer = kaldiio.WriteHelper(wspecifier)
|
||||
self.writer_scp = None
|
||||
if write_num_frames is not None:
|
||||
self.writer_nframe = get_num_frames_writer(write_num_frames)
|
||||
else:
|
||||
self.writer_nframe = None
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.writer[key] = value
|
||||
if self.writer_nframe is not None:
|
||||
self.writer_nframe.write(f"{key} {len(value)}\n")
|
||||
|
||||
|
||||
def parse_wspecifier(wspecifier: str) -> Dict[str, str]:
|
||||
"""Parse wspecifier to dict
|
||||
|
||||
Examples:
|
||||
>>> parse_wspecifier('ark,scp:out.ark,out.scp')
|
||||
{'ark': 'out.ark', 'scp': 'out.scp'}
|
||||
|
||||
"""
|
||||
ark_scp, filepath = wspecifier.split(":", 1)
|
||||
if ark_scp not in ["ark", "scp,ark", "ark,scp"]:
|
||||
raise ValueError("{} is not allowed: {}".format(ark_scp, wspecifier))
|
||||
ark_scps = ark_scp.split(",")
|
||||
filepaths = filepath.split(",")
|
||||
if len(ark_scps) != len(filepaths):
|
||||
raise ValueError("Mismatch: {} and {}".format(ark_scp, filepath))
|
||||
spec_dict = dict(zip(ark_scps, filepaths))
|
||||
return spec_dict
|
||||
|
||||
|
||||
class HDF5Writer(BaseWriter):
|
||||
"""HDF5Writer
|
||||
|
||||
Examples:
|
||||
>>> with HDF5Writer('ark:out.h5', compress=True) as f:
|
||||
... f['key'] = array
|
||||
"""
|
||||
|
||||
def __init__(self, wspecifier, write_num_frames=None, compress=False):
|
||||
spec_dict = parse_wspecifier(wspecifier)
|
||||
self.filename = spec_dict["ark"]
|
||||
|
||||
if compress:
|
||||
self.kwargs = {"compression": "gzip"}
|
||||
else:
|
||||
self.kwargs = {}
|
||||
self.writer = h5py.File(spec_dict["ark"], "w")
|
||||
if "scp" in spec_dict:
|
||||
self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8")
|
||||
else:
|
||||
self.writer_scp = None
|
||||
if write_num_frames is not None:
|
||||
self.writer_nframe = get_num_frames_writer(write_num_frames)
|
||||
else:
|
||||
self.writer_nframe = None
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.writer.create_dataset(key, data=value, **self.kwargs)
|
||||
|
||||
if self.writer_scp is not None:
|
||||
self.writer_scp.write(f"{key} {self.filename}:{key}\n")
|
||||
if self.writer_nframe is not None:
|
||||
self.writer_nframe.write(f"{key} {len(value)}\n")
|
||||
|
||||
|
||||
class SoundHDF5Writer(BaseWriter):
|
||||
"""SoundHDF5Writer
|
||||
|
||||
Examples:
|
||||
>>> fs = 16000
|
||||
>>> with SoundHDF5Writer('ark:out.h5') as f:
|
||||
... f['key'] = fs, array
|
||||
"""
|
||||
|
||||
def __init__(self, wspecifier, write_num_frames=None, pcm_format="wav"):
|
||||
self.pcm_format = pcm_format
|
||||
spec_dict = parse_wspecifier(wspecifier)
|
||||
self.filename = spec_dict["ark"]
|
||||
self.writer = SoundHDF5File(
|
||||
spec_dict["ark"], "w", format=self.pcm_format)
|
||||
if "scp" in spec_dict:
|
||||
self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8")
|
||||
else:
|
||||
self.writer_scp = None
|
||||
if write_num_frames is not None:
|
||||
self.writer_nframe = get_num_frames_writer(write_num_frames)
|
||||
else:
|
||||
self.writer_nframe = None
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
assert_scipy_wav_style(value)
|
||||
# Change Tuple[int, ndarray] -> Tuple[ndarray, int]
|
||||
# (scipy style -> soundfile style)
|
||||
value = (value[1], value[0])
|
||||
self.writer.create_dataset(key, data=value)
|
||||
|
||||
if self.writer_scp is not None:
|
||||
self.writer_scp.write(f"{key} {self.filename}:{key}\n")
|
||||
if self.writer_nframe is not None:
|
||||
self.writer_nframe.write(f"{key} {len(value[0])}\n")
|
||||
|
||||
|
||||
class SoundWriter(BaseWriter):
|
||||
"""SoundWriter
|
||||
|
||||
Examples:
|
||||
>>> fs = 16000
|
||||
>>> with SoundWriter('ark,scp:outdir,out.scp') as f:
|
||||
... f['key'] = fs, array
|
||||
"""
|
||||
|
||||
def __init__(self, wspecifier, write_num_frames=None, pcm_format="wav"):
|
||||
self.pcm_format = pcm_format
|
||||
spec_dict = parse_wspecifier(wspecifier)
|
||||
# e.g. ark,scp:dirname,wav.scp
|
||||
# -> The wave files are found in dirname/*.wav
|
||||
self.dirname = spec_dict["ark"]
|
||||
Path(self.dirname).mkdir(parents=True, exist_ok=True)
|
||||
self.writer = None
|
||||
|
||||
if "scp" in spec_dict:
|
||||
self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8")
|
||||
else:
|
||||
self.writer_scp = None
|
||||
if write_num_frames is not None:
|
||||
self.writer_nframe = get_num_frames_writer(write_num_frames)
|
||||
else:
|
||||
self.writer_nframe = None
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
assert_scipy_wav_style(value)
|
||||
rate, signal = value
|
||||
wavfile = Path(self.dirname) / (key + "." + self.pcm_format)
|
||||
soundfile.write(wavfile, signal.astype(numpy.int16), rate)
|
||||
|
||||
if self.writer_scp is not None:
|
||||
self.writer_scp.write(f"{key} {wavfile}\n")
|
||||
if self.writer_nframe is not None:
|
||||
self.writer_nframe.write(f"{key} {len(signal)}\n")
|
@ -0,0 +1,4 @@
|
||||
dump
|
||||
fbank
|
||||
exp
|
||||
data
|
@ -1,122 +0,0 @@
|
||||
# https://yaml.org/type/float.html
|
||||
data:
|
||||
train_manifest: data/manifest.train
|
||||
dev_manifest: data/manifest.dev
|
||||
test_manifest: data/manifest.test
|
||||
min_input_len: 0.5
|
||||
max_input_len: 20.0
|
||||
min_output_len: 0.0
|
||||
max_output_len: 400.0
|
||||
min_output_input_ratio: 0.05
|
||||
max_output_input_ratio: 10.0
|
||||
|
||||
collator:
|
||||
vocab_filepath: data/vocab.txt
|
||||
unit_type: 'spm'
|
||||
spm_model_prefix: 'data/bpe_unigram_5000'
|
||||
mean_std_filepath: ""
|
||||
augmentation_config: conf/augmentation.json
|
||||
batch_size: 16
|
||||
raw_wav: True # use raw_wav or kaldi feature
|
||||
spectrum_type: fbank #linear, mfcc, fbank
|
||||
feat_dim: 80
|
||||
delta_delta: False
|
||||
dither: 1.0
|
||||
target_sample_rate: 16000
|
||||
max_freq: None
|
||||
n_fft: None
|
||||
stride_ms: 10.0
|
||||
window_ms: 25.0
|
||||
use_dB_normalization: True
|
||||
target_dB: -20
|
||||
random_seed: 0
|
||||
keep_transcription_text: False
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 2
|
||||
|
||||
|
||||
# network architecture
|
||||
model:
|
||||
cmvn_file: "data/mean_std.json"
|
||||
cmvn_file_type: "json"
|
||||
# encoder related
|
||||
encoder: conformer
|
||||
encoder_conf:
|
||||
output_size: 256 # dimension of attention
|
||||
attention_heads: 4
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
|
||||
normalize_before: True
|
||||
use_cnn_module: True
|
||||
cnn_module_kernel: 15
|
||||
activation_type: 'swish'
|
||||
pos_enc_layer_type: 'rel_pos'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
causal: True
|
||||
use_dynamic_chunk: true
|
||||
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
|
||||
use_dynamic_left_chunk: false
|
||||
|
||||
# decoder related
|
||||
decoder: transformer
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
# hybrid CTC/attention
|
||||
model_conf:
|
||||
ctc_weight: 0.3
|
||||
ctc_dropoutrate: 0.0
|
||||
ctc_grad_norm_type: null
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
|
||||
|
||||
training:
|
||||
n_epoch: 240
|
||||
accum_grad: 8
|
||||
global_grad_clip: 5.0
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
weight_decay: 1e-06
|
||||
scheduler: warmuplr # pytorch v1.1.0+ required
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
lr_decay: 1.0
|
||||
log_interval: 100
|
||||
checkpoint:
|
||||
kbest_n: 50
|
||||
latest_n: 5
|
||||
|
||||
|
||||
decoding:
|
||||
batch_size: 128
|
||||
error_rate_type: wer
|
||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
||||
alpha: 2.5
|
||||
beta: 0.3
|
||||
beam_size: 10
|
||||
cutoff_prob: 1.0
|
||||
cutoff_top_n: 0
|
||||
num_proc_bsearch: 8
|
||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
||||
# <0: for decoding, use full chunk.
|
||||
# >0: for decoding, use fixed chunk size as set.
|
||||
# 0: used for training, it's prohibited here.
|
||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
||||
simulate_streaming: true # simulate streaming inference. Defaults to False.
|
||||
|
||||
|
@ -1,115 +0,0 @@
|
||||
# https://yaml.org/type/float.html
|
||||
data:
|
||||
train_manifest: data/manifest.train
|
||||
dev_manifest: data/manifest.dev
|
||||
test_manifest: data/manifest.test
|
||||
min_input_len: 0.5 # second
|
||||
max_input_len: 20.0 # second
|
||||
min_output_len: 0.0 # tokens
|
||||
max_output_len: 400.0 # tokens
|
||||
min_output_input_ratio: 0.05
|
||||
max_output_input_ratio: 10.0
|
||||
|
||||
collator:
|
||||
vocab_filepath: data/vocab.txt
|
||||
unit_type: 'spm'
|
||||
spm_model_prefix: 'data/bpe_unigram_5000'
|
||||
mean_std_filepath: ""
|
||||
augmentation_config: conf/augmentation.json
|
||||
batch_size: 64
|
||||
raw_wav: True # use raw_wav or kaldi feature
|
||||
spectrum_type: fbank #linear, mfcc, fbank
|
||||
feat_dim: 80
|
||||
delta_delta: False
|
||||
dither: 1.0
|
||||
target_sample_rate: 16000
|
||||
max_freq: None
|
||||
n_fft: None
|
||||
stride_ms: 10.0
|
||||
window_ms: 25.0
|
||||
use_dB_normalization: True
|
||||
target_dB: -20
|
||||
random_seed: 0
|
||||
keep_transcription_text: False
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 2
|
||||
|
||||
|
||||
# network architecture
|
||||
model:
|
||||
cmvn_file: "data/mean_std.json"
|
||||
cmvn_file_type: "json"
|
||||
# encoder related
|
||||
encoder: transformer
|
||||
encoder_conf:
|
||||
output_size: 256 # dimension of attention
|
||||
attention_heads: 4
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
|
||||
normalize_before: true
|
||||
use_dynamic_chunk: true
|
||||
use_dynamic_left_chunk: false
|
||||
|
||||
# decoder related
|
||||
decoder: transformer
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
# hybrid CTC/attention
|
||||
model_conf:
|
||||
ctc_weight: 0.3
|
||||
ctc_dropoutrate: 0.0
|
||||
ctc_grad_norm_type: null
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
|
||||
|
||||
training:
|
||||
n_epoch: 120
|
||||
accum_grad: 1
|
||||
global_grad_clip: 5.0
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
weight_decay: 1e-06
|
||||
scheduler: warmuplr # pytorch v1.1.0+ required
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
lr_decay: 1.0
|
||||
log_interval: 100
|
||||
checkpoint:
|
||||
kbest_n: 50
|
||||
latest_n: 5
|
||||
|
||||
|
||||
decoding:
|
||||
batch_size: 64
|
||||
error_rate_type: wer
|
||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
||||
alpha: 2.5
|
||||
beta: 0.3
|
||||
beam_size: 10
|
||||
cutoff_prob: 1.0
|
||||
cutoff_top_n: 0
|
||||
num_proc_bsearch: 8
|
||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
||||
# <0: for decoding, use full chunk.
|
||||
# >0: for decoding, use fixed chunk size as set.
|
||||
# 0: used for training, it's prohibited here.
|
||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
||||
simulate_streaming: true # simulate streaming inference. Defaults to False.
|
||||
|
||||
|
@ -1,118 +0,0 @@
|
||||
# https://yaml.org/type/float.html
|
||||
data:
|
||||
train_manifest: data/manifest.train
|
||||
dev_manifest: data/manifest.dev
|
||||
test_manifest: data/manifest.test-clean
|
||||
min_input_len: 0.5 # seconds
|
||||
max_input_len: 20.0 # seconds
|
||||
min_output_len: 0.0 # tokens
|
||||
max_output_len: 400.0 # tokens
|
||||
min_output_input_ratio: 0.05
|
||||
max_output_input_ratio: 10.0
|
||||
|
||||
collator:
|
||||
vocab_filepath: data/vocab.txt
|
||||
unit_type: 'spm'
|
||||
spm_model_prefix: 'data/bpe_unigram_5000'
|
||||
mean_std_filepath: ""
|
||||
augmentation_config: conf/augmentation.json
|
||||
batch_size: 16
|
||||
raw_wav: True # use raw_wav or kaldi feature
|
||||
spectrum_type: fbank #linear, mfcc, fbank
|
||||
feat_dim: 80
|
||||
delta_delta: False
|
||||
dither: 1.0
|
||||
target_sample_rate: 16000
|
||||
max_freq: None
|
||||
n_fft: None
|
||||
stride_ms: 10.0
|
||||
window_ms: 25.0
|
||||
use_dB_normalization: True
|
||||
target_dB: -20
|
||||
random_seed: 0
|
||||
keep_transcription_text: False
|
||||
sortagrad: True
|
||||
shuffle_method: batch_shuffle
|
||||
num_workers: 2
|
||||
|
||||
|
||||
# network architecture
|
||||
model:
|
||||
cmvn_file: "data/mean_std.json"
|
||||
cmvn_file_type: "json"
|
||||
# encoder related
|
||||
encoder: conformer
|
||||
encoder_conf:
|
||||
output_size: 256 # dimension of attention
|
||||
attention_heads: 4
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
|
||||
normalize_before: True
|
||||
use_cnn_module: True
|
||||
cnn_module_kernel: 15
|
||||
activation_type: 'swish'
|
||||
pos_enc_layer_type: 'rel_pos'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
|
||||
# decoder related
|
||||
decoder: transformer
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
# hybrid CTC/attention
|
||||
model_conf:
|
||||
ctc_weight: 0.3
|
||||
ctc_dropoutrate: 0.0
|
||||
ctc_grad_norm_type: null
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
|
||||
|
||||
training:
|
||||
n_epoch: 120
|
||||
accum_grad: 8
|
||||
global_grad_clip: 3.0
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.004
|
||||
weight_decay: 1e-06
|
||||
scheduler: warmuplr # pytorch v1.1.0+ required
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
lr_decay: 1.0
|
||||
log_interval: 100
|
||||
checkpoint:
|
||||
kbest_n: 50
|
||||
latest_n: 5
|
||||
|
||||
|
||||
decoding:
|
||||
batch_size: 64
|
||||
error_rate_type: wer
|
||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
||||
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
|
||||
alpha: 2.5
|
||||
beta: 0.3
|
||||
beam_size: 10
|
||||
cutoff_prob: 1.0
|
||||
cutoff_top_n: 0
|
||||
num_proc_bsearch: 8
|
||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
||||
# <0: for decoding, use full chunk.
|
||||
# >0: for decoding, use fixed chunk size as set.
|
||||
# 0: used for training, it's prohibited here.
|
||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
||||
simulate_streaming: False # simulate streaming inference. Defaults to False.
|
||||
|
||||
|
@ -0,0 +1,2 @@
|
||||
--sample-frequency=16000
|
||||
--num-mel-bins=80
|
@ -0,0 +1,13 @@
|
||||
model_module: transformer
|
||||
model:
|
||||
n_vocab: 5002
|
||||
pos_enc: null
|
||||
embed_unit: 128
|
||||
att_unit: 512
|
||||
head: 8
|
||||
unit: 2048
|
||||
layer: 16
|
||||
dropout_rate: 0.5
|
||||
emb_dropout_rate: 0.0
|
||||
att_dropout_rate: 0.0
|
||||
tie_weights: False
|
@ -0,0 +1 @@
|
||||
--sample-frequency=16000
|
@ -0,0 +1,85 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2014 Vassil Panayotov
|
||||
# 2014 Johns Hopkins University (author: Daniel Povey)
|
||||
# Apache 2.0
|
||||
|
||||
if [ "$#" -ne 2 ]; then
|
||||
echo "Usage: $0 <src-dir> <dst-dir>"
|
||||
echo "e.g.: $0 /export/a15/vpanayotov/data/LibriSpeech/dev-clean data/dev-clean"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
src=$1
|
||||
dst=$2
|
||||
|
||||
# all utterances are FLAC compressed
|
||||
if ! which flac >&/dev/null; then
|
||||
echo "Please install 'flac' on ALL worker nodes!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
spk_file=$src/../SPEAKERS.TXT
|
||||
|
||||
mkdir -p $dst || exit 1
|
||||
|
||||
[ ! -d $src ] && echo "$0: no such directory $src" && exit 1
|
||||
[ ! -f $spk_file ] && echo "$0: expected file $spk_file to exist" && exit 1
|
||||
|
||||
|
||||
wav_scp=$dst/wav.scp; [[ -f "$wav_scp" ]] && rm $wav_scp
|
||||
trans=$dst/text; [[ -f "$trans" ]] && rm $trans
|
||||
utt2spk=$dst/utt2spk; [[ -f "$utt2spk" ]] && rm $utt2spk
|
||||
spk2gender=$dst/spk2gender; [[ -f $spk2gender ]] && rm $spk2gender
|
||||
|
||||
for reader_dir in $(find -L $src -mindepth 1 -maxdepth 1 -type d | sort); do
|
||||
reader=$(basename $reader_dir)
|
||||
if ! [ $reader -eq $reader ]; then # not integer.
|
||||
echo "$0: unexpected subdirectory name $reader"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
reader_gender=$(egrep "^$reader[ ]+\|" $spk_file | awk -F'|' '{gsub(/[ ]+/, ""); print tolower($2)}')
|
||||
if [ "$reader_gender" != 'm' ] && [ "$reader_gender" != 'f' ]; then
|
||||
echo "Unexpected gender: '$reader_gender'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
for chapter_dir in $(find -L $reader_dir/ -mindepth 1 -maxdepth 1 -type d | sort); do
|
||||
chapter=$(basename $chapter_dir)
|
||||
if ! [ "$chapter" -eq "$chapter" ]; then
|
||||
echo "$0: unexpected chapter-subdirectory name $chapter"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
find -L $chapter_dir/ -iname "*.flac" | sort | xargs -I% basename % .flac | \
|
||||
awk -v "dir=$chapter_dir" '{printf "%s flac -c -d -s %s/%s.flac |\n", $0, dir, $0}' >>$wav_scp|| exit 1
|
||||
|
||||
chapter_trans=$chapter_dir/${reader}-${chapter}.trans.txt
|
||||
[ ! -f $chapter_trans ] && echo "$0: expected file $chapter_trans to exist" && exit 1
|
||||
cat $chapter_trans >>$trans
|
||||
|
||||
# NOTE: For now we are using per-chapter utt2spk. That is each chapter is considered
|
||||
# to be a different speaker. This is done for simplicity and because we want
|
||||
# e.g. the CMVN to be calculated per-chapter
|
||||
awk -v "reader=$reader" -v "chapter=$chapter" '{printf "%s %s-%s\n", $1, reader, chapter}' \
|
||||
<$chapter_trans >>$utt2spk || exit 1
|
||||
|
||||
# reader -> gender map (again using per-chapter granularity)
|
||||
echo "${reader}-${chapter} $reader_gender" >>$spk2gender
|
||||
done
|
||||
done
|
||||
|
||||
spk2utt=$dst/spk2utt
|
||||
utils/utt2spk_to_spk2utt.pl <$utt2spk >$spk2utt || exit 1
|
||||
|
||||
ntrans=$(wc -l <$trans)
|
||||
nutt2spk=$(wc -l <$utt2spk)
|
||||
! [ "$ntrans" -eq "$nutt2spk" ] && \
|
||||
echo "Inconsistent #transcripts($ntrans) and #utt2spk($nutt2spk)" && exit 1
|
||||
|
||||
utils/validate_data_dir.sh --no-feats $dst || exit 1
|
||||
|
||||
echo "$0: successfully prepared data in $dst"
|
||||
|
||||
exit 0
|
@ -0,0 +1 @@
|
||||
../../../tools/kaldi/egs/wsj/s5/steps/
|
@ -1 +1 @@
|
||||
../../../utils/
|
||||
../../../tools/kaldi/egs/wsj/s5/utils
|
@ -0,0 +1,149 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import logging
|
||||
from distutils.util import strtobool
|
||||
|
||||
import kaldiio
|
||||
import numpy
|
||||
|
||||
from deepspeech.transform.cmvn import CMVN
|
||||
from deepspeech.utils.cli_readers import file_reader_helper
|
||||
from deepspeech.utils.cli_utils import get_commandline_args
|
||||
from deepspeech.utils.cli_utils import is_scipy_wav_style
|
||||
from deepspeech.utils.cli_writers import file_writer_helper
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="apply mean-variance normalization to files",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
|
||||
|
||||
parser.add_argument(
|
||||
"--verbose", "-V", default=0, type=int, help="Verbose option")
|
||||
parser.add_argument(
|
||||
"--in-filetype",
|
||||
type=str,
|
||||
default="mat",
|
||||
choices=["mat", "hdf5", "sound.hdf5", "sound"],
|
||||
help="Specify the file format for the rspecifier. "
|
||||
'"mat" is the matrix format in kaldi', )
|
||||
parser.add_argument(
|
||||
"--stats-filetype",
|
||||
type=str,
|
||||
default="mat",
|
||||
choices=["mat", "hdf5", "npy"],
|
||||
help="Specify the file format for the rspecifier. "
|
||||
'"mat" is the matrix format in kaldi', )
|
||||
parser.add_argument(
|
||||
"--out-filetype",
|
||||
type=str,
|
||||
default="mat",
|
||||
choices=["mat", "hdf5"],
|
||||
help="Specify the file format for the wspecifier. "
|
||||
'"mat" is the matrix format in kaldi', )
|
||||
|
||||
parser.add_argument(
|
||||
"--norm-means",
|
||||
type=strtobool,
|
||||
default=True,
|
||||
help="Do variance normalization or not.", )
|
||||
parser.add_argument(
|
||||
"--norm-vars",
|
||||
type=strtobool,
|
||||
default=False,
|
||||
help="Do variance normalization or not.", )
|
||||
parser.add_argument(
|
||||
"--reverse",
|
||||
type=strtobool,
|
||||
default=False,
|
||||
help="Do reverse mode or not")
|
||||
parser.add_argument(
|
||||
"--spk2utt",
|
||||
type=str,
|
||||
help="A text file of speaker to utterance-list map. "
|
||||
"(Don't give rspecifier format, such as "
|
||||
'"ark:spk2utt")', )
|
||||
parser.add_argument(
|
||||
"--utt2spk",
|
||||
type=str,
|
||||
help="A text file of utterance to speaker map. "
|
||||
"(Don't give rspecifier format, such as "
|
||||
'"ark:utt2spk")', )
|
||||
parser.add_argument(
|
||||
"--write-num-frames",
|
||||
type=str,
|
||||
help="Specify wspecifer for utt2num_frames")
|
||||
parser.add_argument(
|
||||
"--compress",
|
||||
type=strtobool,
|
||||
default=False,
|
||||
help="Save in compressed format")
|
||||
parser.add_argument(
|
||||
"--compression-method",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Specify the method(if mat) or "
|
||||
"gzip-level(if hdf5)", )
|
||||
parser.add_argument(
|
||||
"stats_rspecifier_or_rxfilename",
|
||||
help="Input stats. e.g. ark:stats.ark or stats.mat", )
|
||||
parser.add_argument(
|
||||
"rspecifier", type=str, help="Read specifier id. e.g. ark:some.ark")
|
||||
parser.add_argument(
|
||||
"wspecifier", type=str, help="Write specifier id. e.g. ark:some.ark")
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
|
||||
# logging info
|
||||
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||
if args.verbose > 0:
|
||||
logging.basicConfig(level=logging.INFO, format=logfmt)
|
||||
else:
|
||||
logging.basicConfig(level=logging.WARN, format=logfmt)
|
||||
logging.info(get_commandline_args())
|
||||
|
||||
if ":" in args.stats_rspecifier_or_rxfilename:
|
||||
is_rspcifier = True
|
||||
if args.stats_filetype == "npy":
|
||||
stats_filetype = "hdf5"
|
||||
else:
|
||||
stats_filetype = args.stats_filetype
|
||||
|
||||
stats_dict = dict(
|
||||
file_reader_helper(args.stats_rspecifier_or_rxfilename,
|
||||
stats_filetype))
|
||||
else:
|
||||
is_rspcifier = False
|
||||
if args.stats_filetype == "mat":
|
||||
stats = kaldiio.load_mat(args.stats_rspecifier_or_rxfilename)
|
||||
else:
|
||||
stats = numpy.load(args.stats_rspecifier_or_rxfilename)
|
||||
stats_dict = {None: stats}
|
||||
|
||||
cmvn = CMVN(
|
||||
stats=stats_dict,
|
||||
norm_means=args.norm_means,
|
||||
norm_vars=args.norm_vars,
|
||||
utt2spk=args.utt2spk,
|
||||
spk2utt=args.spk2utt,
|
||||
reverse=args.reverse, )
|
||||
|
||||
with file_writer_helper(
|
||||
args.wspecifier,
|
||||
filetype=args.out_filetype,
|
||||
write_num_frames=args.write_num_frames,
|
||||
compress=args.compress,
|
||||
compression_method=args.compression_method, ) as writer:
|
||||
for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
|
||||
if is_scipy_wav_style(mat):
|
||||
# If data is sound file, then got as Tuple[int, ndarray]
|
||||
rate, mat = mat
|
||||
mat = cmvn(mat, utt if is_rspcifier else None)
|
||||
writer[utt] = mat
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,65 @@
|
||||
#!/usr/bin/env python3
|
||||
# encoding: utf-8
|
||||
# Copyright 2021 Kyoto University (Hirofumi Inaguma)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
import argparse
|
||||
import codecs
|
||||
import glob
|
||||
import os
|
||||
|
||||
from dateutil import parser
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="calculate real time factor (RTF)")
|
||||
parser.add_argument(
|
||||
"--log-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path to logging directory", )
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
args = get_parser().parse_args()
|
||||
|
||||
audio_sec = 0
|
||||
decode_sec = 0
|
||||
n_utt = 0
|
||||
|
||||
audio_durations = []
|
||||
start_times = []
|
||||
end_times = []
|
||||
for x in glob.glob(os.path.join(args.log_dir, "decode.*.log")):
|
||||
with codecs.open(x, "r", "utf-8") as f:
|
||||
for line in f:
|
||||
x = line.strip()
|
||||
# 2021-10-25 08:22:04.052 | INFO | xxx:recog_v2:188 - feat: (1570, 83)
|
||||
if "feat:" in x:
|
||||
dur = int(x.split("(")[1].split(',')[0])
|
||||
audio_durations += [dur]
|
||||
start_times += [parser.parse(x.split("|")[0])]
|
||||
elif "total log probability:" in x:
|
||||
end_times += [parser.parse(x.split("|")[0])]
|
||||
assert len(audio_durations) == len(end_times), (len(audio_durations),
|
||||
len(end_times), )
|
||||
assert len(start_times) == len(end_times), (len(start_times),
|
||||
len(end_times))
|
||||
|
||||
audio_sec += sum(audio_durations) / 100 # [sec]
|
||||
decode_sec += sum([(end - start).total_seconds()
|
||||
for start, end in zip(start_times, end_times)])
|
||||
n_utt += len(audio_durations)
|
||||
|
||||
print("Total audio duration: %.3f [sec]" % audio_sec)
|
||||
print("Total decoding time: %.3f [sec]" % decode_sec)
|
||||
rtf = decode_sec / audio_sec if audio_sec > 0 else 0
|
||||
print("RTF: %.3f" % rtf)
|
||||
latency = decode_sec * 1000 / n_utt if n_utt > 0 else 0
|
||||
print("Latency: %.3f [ms/sentence]" % latency)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,186 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import kaldiio
|
||||
import numpy as np
|
||||
|
||||
from deepspeech.transform.transformation import Transformation
|
||||
from deepspeech.utils.cli_readers import file_reader_helper
|
||||
from deepspeech.utils.cli_utils import get_commandline_args
|
||||
from deepspeech.utils.cli_utils import is_scipy_wav_style
|
||||
from deepspeech.utils.cli_writers import file_writer_helper
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compute cepstral mean and "
|
||||
"variance normalization statistics"
|
||||
"If wspecifier provided: per-utterance by default, "
|
||||
"or per-speaker if"
|
||||
"spk2utt option provided; if wxfilename: global",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
|
||||
parser.add_argument(
|
||||
"--spk2utt",
|
||||
type=str,
|
||||
help="A text file of speaker to utterance-list map. "
|
||||
"(Don't give rspecifier format, such as "
|
||||
'"ark:utt2spk")', )
|
||||
parser.add_argument(
|
||||
"--verbose", "-V", default=0, type=int, help="Verbose option")
|
||||
parser.add_argument(
|
||||
"--in-filetype",
|
||||
type=str,
|
||||
default="mat",
|
||||
choices=["mat", "hdf5", "sound.hdf5", "sound"],
|
||||
help="Specify the file format for the rspecifier. "
|
||||
'"mat" is the matrix format in kaldi', )
|
||||
parser.add_argument(
|
||||
"--out-filetype",
|
||||
type=str,
|
||||
default="mat",
|
||||
choices=["mat", "hdf5", "npy"],
|
||||
help="Specify the file format for the wspecifier. "
|
||||
'"mat" is the matrix format in kaldi', )
|
||||
parser.add_argument(
|
||||
"--preprocess-conf",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The configuration file for the pre-processing", )
|
||||
parser.add_argument(
|
||||
"rspecifier",
|
||||
type=str,
|
||||
help="Read specifier for feats. e.g. ark:some.ark")
|
||||
parser.add_argument(
|
||||
"wspecifier_or_wxfilename",
|
||||
type=str,
|
||||
help="Write specifier. e.g. ark:some.ark")
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
|
||||
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||
if args.verbose > 0:
|
||||
logging.basicConfig(level=logging.INFO, format=logfmt)
|
||||
else:
|
||||
logging.basicConfig(level=logging.WARN, format=logfmt)
|
||||
logging.info(get_commandline_args())
|
||||
|
||||
is_wspecifier = ":" in args.wspecifier_or_wxfilename
|
||||
|
||||
if is_wspecifier:
|
||||
if args.spk2utt is not None:
|
||||
logging.info("Performing as speaker CMVN mode")
|
||||
utt2spk_dict = {}
|
||||
with open(args.spk2utt) as f:
|
||||
for line in f:
|
||||
spk, utts = line.rstrip().split(None, 1)
|
||||
for utt in utts.split():
|
||||
utt2spk_dict[utt] = spk
|
||||
|
||||
def utt2spk(x):
|
||||
return utt2spk_dict[x]
|
||||
|
||||
else:
|
||||
logging.info("Performing as utterance CMVN mode")
|
||||
|
||||
def utt2spk(x):
|
||||
return x
|
||||
|
||||
if args.out_filetype == "npy":
|
||||
logging.warning("--out-filetype npy is allowed only for "
|
||||
"Global CMVN mode, changing to hdf5")
|
||||
args.out_filetype = "hdf5"
|
||||
|
||||
else:
|
||||
logging.info("Performing as global CMVN mode")
|
||||
if args.spk2utt is not None:
|
||||
logging.warning("spk2utt is not used for global CMVN mode")
|
||||
|
||||
def utt2spk(x):
|
||||
return None
|
||||
|
||||
if args.out_filetype == "hdf5":
|
||||
logging.warning("--out-filetype hdf5 is not allowed for "
|
||||
"Global CMVN mode, changing to npy")
|
||||
args.out_filetype = "npy"
|
||||
|
||||
if args.preprocess_conf is not None:
|
||||
preprocessing = Transformation(args.preprocess_conf)
|
||||
logging.info("Apply preprocessing: {}".format(preprocessing))
|
||||
else:
|
||||
preprocessing = None
|
||||
|
||||
# Calculate stats for each speaker
|
||||
counts = {}
|
||||
sum_feats = {}
|
||||
square_sum_feats = {}
|
||||
|
||||
idx = 0
|
||||
for idx, (utt, matrix) in enumerate(
|
||||
file_reader_helper(args.rspecifier, args.in_filetype), 1):
|
||||
if is_scipy_wav_style(matrix):
|
||||
# If data is sound file, then got as Tuple[int, ndarray]
|
||||
rate, matrix = matrix
|
||||
if preprocessing is not None:
|
||||
matrix = preprocessing(matrix, uttid_list=utt)
|
||||
|
||||
spk = utt2spk(utt)
|
||||
|
||||
# Init at the first seen of the spk
|
||||
if spk not in counts:
|
||||
counts[spk] = 0
|
||||
feat_shape = matrix.shape[1:]
|
||||
# Accumulate in double precision
|
||||
sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64)
|
||||
square_sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64)
|
||||
|
||||
counts[spk] += matrix.shape[0]
|
||||
sum_feats[spk] += matrix.sum(axis=0)
|
||||
square_sum_feats[spk] += (matrix**2).sum(axis=0)
|
||||
logging.info("Processed {} utterances".format(idx))
|
||||
assert idx > 0, idx
|
||||
|
||||
cmvn_stats = {}
|
||||
for spk in counts:
|
||||
feat_shape = sum_feats[spk].shape
|
||||
cmvn_shape = (2, feat_shape[0] + 1) + feat_shape[1:]
|
||||
_cmvn_stats = np.empty(cmvn_shape, dtype=np.float64)
|
||||
_cmvn_stats[0, :-1] = sum_feats[spk]
|
||||
_cmvn_stats[1, :-1] = square_sum_feats[spk]
|
||||
|
||||
_cmvn_stats[0, -1] = counts[spk]
|
||||
_cmvn_stats[1, -1] = 0.0
|
||||
|
||||
# You can get the mean and std as following,
|
||||
# >>> N = _cmvn_stats[0, -1]
|
||||
# >>> mean = _cmvn_stats[0, :-1] / N
|
||||
# >>> std = np.sqrt(_cmvn_stats[1, :-1] / N - mean ** 2)
|
||||
|
||||
cmvn_stats[spk] = _cmvn_stats
|
||||
|
||||
# Per utterance or speaker CMVN
|
||||
if is_wspecifier:
|
||||
with file_writer_helper(
|
||||
args.wspecifier_or_wxfilename,
|
||||
filetype=args.out_filetype) as writer:
|
||||
for spk, mat in cmvn_stats.items():
|
||||
writer[spk] = mat
|
||||
|
||||
# Global CMVN
|
||||
else:
|
||||
matrix = cmvn_stats[None]
|
||||
if args.out_filetype == "npy":
|
||||
np.save(args.wspecifier_or_wxfilename, matrix)
|
||||
elif args.out_filetype == "mat":
|
||||
# Kaldi supports only matrix or vector
|
||||
kaldiio.save_mat(args.wspecifier_or_wxfilename, matrix)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Not supporting: --out-filetype {}".format(args.out_filetype))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import logging
|
||||
from distutils.util import strtobool
|
||||
|
||||
from deepspeech.transform.transformation import Transformation
|
||||
from deepspeech.utils.cli_readers import file_reader_helper
|
||||
from deepspeech.utils.cli_utils import get_commandline_args
|
||||
from deepspeech.utils.cli_utils import is_scipy_wav_style
|
||||
from deepspeech.utils.cli_writers import file_writer_helper
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="copy feature with preprocessing",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
|
||||
|
||||
parser.add_argument(
|
||||
"--verbose", "-V", default=0, type=int, help="Verbose option")
|
||||
parser.add_argument(
|
||||
"--in-filetype",
|
||||
type=str,
|
||||
default="mat",
|
||||
choices=["mat", "hdf5", "sound.hdf5", "sound"],
|
||||
help="Specify the file format for the rspecifier. "
|
||||
'"mat" is the matrix format in kaldi', )
|
||||
parser.add_argument(
|
||||
"--out-filetype",
|
||||
type=str,
|
||||
default="mat",
|
||||
choices=["mat", "hdf5", "sound.hdf5", "sound"],
|
||||
help="Specify the file format for the wspecifier. "
|
||||
'"mat" is the matrix format in kaldi', )
|
||||
parser.add_argument(
|
||||
"--write-num-frames",
|
||||
type=str,
|
||||
help="Specify wspecifer for utt2num_frames")
|
||||
parser.add_argument(
|
||||
"--compress",
|
||||
type=strtobool,
|
||||
default=False,
|
||||
help="Save in compressed format")
|
||||
parser.add_argument(
|
||||
"--compression-method",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Specify the method(if mat) or "
|
||||
"gzip-level(if hdf5)", )
|
||||
parser.add_argument(
|
||||
"--preprocess-conf",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The configuration file for the pre-processing", )
|
||||
parser.add_argument(
|
||||
"rspecifier",
|
||||
type=str,
|
||||
help="Read specifier for feats. e.g. ark:some.ark")
|
||||
parser.add_argument(
|
||||
"wspecifier", type=str, help="Write specifier. e.g. ark:some.ark")
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# logging info
|
||||
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||
if args.verbose > 0:
|
||||
logging.basicConfig(level=logging.INFO, format=logfmt)
|
||||
else:
|
||||
logging.basicConfig(level=logging.WARN, format=logfmt)
|
||||
logging.info(get_commandline_args())
|
||||
|
||||
if args.preprocess_conf is not None:
|
||||
preprocessing = Transformation(args.preprocess_conf)
|
||||
logging.info("Apply preprocessing: {}".format(preprocessing))
|
||||
else:
|
||||
preprocessing = None
|
||||
|
||||
with file_writer_helper(
|
||||
args.wspecifier,
|
||||
filetype=args.out_filetype,
|
||||
write_num_frames=args.write_num_frames,
|
||||
compress=args.compress,
|
||||
compression_method=args.compression_method, ) as writer:
|
||||
for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
|
||||
if is_scipy_wav_style(mat):
|
||||
# If data is sound file, then got as Tuple[int, ndarray]
|
||||
rate, mat = mat
|
||||
|
||||
if preprocessing is not None:
|
||||
mat = preprocessing(mat, uttid_list=utt)
|
||||
|
||||
# shape = (Time, Channel)
|
||||
if args.out_filetype in ["sound.hdf5", "sound"]:
|
||||
# Write Tuple[int, numpy.ndarray] (scipy style)
|
||||
writer[utt] = (rate, mat)
|
||||
else:
|
||||
writer[utt] = mat
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
echo "$0 $*" >&2 # Print the command line for logging
|
||||
. ./path.sh
|
||||
|
||||
nj=1
|
||||
cmd=run.pl
|
||||
nlsyms=""
|
||||
lang=""
|
||||
feat="" # feat.scp
|
||||
oov="<unk>"
|
||||
bpecode=""
|
||||
allow_one_column=false
|
||||
verbose=0
|
||||
trans_type=char
|
||||
filetype=""
|
||||
preprocess_conf=""
|
||||
category=""
|
||||
out="" # If omitted, write in stdout
|
||||
|
||||
text=""
|
||||
multilingual=false
|
||||
|
||||
help_message=$(cat << EOF
|
||||
Usage: $0 <data-dir> <dict>
|
||||
e.g. $0 data/train data/lang_1char/train_units.txt
|
||||
Options:
|
||||
--nj <nj> # number of parallel jobs
|
||||
--cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs.
|
||||
--feat <feat-scp> # feat.scp or feat1.scp,feat2.scp,...
|
||||
--oov <oov-word> # Default: <unk>
|
||||
--out <outputfile> # If omitted, write in stdout
|
||||
--filetype <mat|hdf5|sound.hdf5> # Specify the format of feats file
|
||||
--preprocess-conf <json> # Apply preprocess to feats when creating shape.scp
|
||||
--verbose <num> # Default: 0
|
||||
EOF
|
||||
)
|
||||
. utils/parse_options.sh
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "${help_message}" 1>&2
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
dir=$1
|
||||
dic=$2
|
||||
tmpdir=$(mktemp -d ${dir}/tmp-XXXXX)
|
||||
trap 'rm -rf ${tmpdir}' EXIT
|
||||
|
||||
if [ -z ${text} ]; then
|
||||
text=${dir}/text
|
||||
fi
|
||||
|
||||
# 1. Create scp files for inputs
|
||||
# These are not necessary for decoding mode, and make it as an option
|
||||
input=
|
||||
if [ -n "${feat}" ]; then
|
||||
_feat_scps=$(echo "${feat}" | tr ',' ' ' )
|
||||
read -r -a feat_scps <<< $_feat_scps
|
||||
num_feats=${#feat_scps[@]}
|
||||
|
||||
for (( i=1; i<=num_feats; i++ )); do
|
||||
feat=${feat_scps[$((i-1))]}
|
||||
mkdir -p ${tmpdir}/input_${i}
|
||||
input+="input_${i} "
|
||||
cat ${feat} > ${tmpdir}/input_${i}/feat.scp
|
||||
|
||||
# Dump in the "legacy" style JSON format
|
||||
if [ -n "${filetype}" ]; then
|
||||
awk -v filetype=${filetype} '{print $1 " " filetype}' ${feat} \
|
||||
> ${tmpdir}/input_${i}/filetype.scp
|
||||
fi
|
||||
|
||||
feat_to_shape.sh --cmd "${cmd}" --nj ${nj} \
|
||||
--filetype "${filetype}" \
|
||||
--preprocess-conf "${preprocess_conf}" \
|
||||
--verbose ${verbose} ${feat} ${tmpdir}/input_${i}/shape.scp
|
||||
done
|
||||
fi
|
||||
|
||||
# 2. Create scp files for outputs
|
||||
mkdir -p ${tmpdir}/output
|
||||
if [ -n "${bpecode}" ]; then
|
||||
if [ ${multilingual} = true ]; then
|
||||
# remove a space before the language ID
|
||||
paste -d " " <(awk '{print $1}' ${text}) <(cut -f 2- -d" " ${text} \
|
||||
| spm_encode --model=${bpecode} --output_format=piece | cut -f 2- -d" ") \
|
||||
> ${tmpdir}/output/token.scp
|
||||
else
|
||||
paste -d " " <(awk '{print $1}' ${text}) <(cut -f 2- -d" " ${text} \
|
||||
| spm_encode --model=${bpecode} --output_format=piece) \
|
||||
> ${tmpdir}/output/token.scp
|
||||
fi
|
||||
elif [ -n "${nlsyms}" ]; then
|
||||
text2token.py -s 1 -n 1 -l ${nlsyms} ${text} --trans_type ${trans_type} > ${tmpdir}/output/token.scp
|
||||
else
|
||||
text2token.py -s 1 -n 1 ${text} --trans_type ${trans_type} > ${tmpdir}/output/token.scp
|
||||
fi
|
||||
< ${tmpdir}/output/token.scp utils/sym2int.pl --map-oov ${oov} -f 2- ${dic} > ${tmpdir}/output/tokenid.scp
|
||||
# +2 comes from CTC blank and EOS
|
||||
vocsize=$(tail -n 1 ${dic} | awk '{print $2}')
|
||||
odim=$(echo "$vocsize + 2" | bc)
|
||||
< ${tmpdir}/output/tokenid.scp awk -v odim=${odim} '{print $1 " " NF-1 "," odim}' > ${tmpdir}/output/shape.scp
|
||||
|
||||
cat ${text} > ${tmpdir}/output/text.scp
|
||||
|
||||
|
||||
# 3. Create scp files for the others
|
||||
mkdir -p ${tmpdir}/other
|
||||
if [ ${multilingual} == true ]; then
|
||||
awk '{
|
||||
n = split($1,S,"[-]");
|
||||
lang=S[n];
|
||||
print $1 " " lang
|
||||
}' ${text} > ${tmpdir}/other/lang.scp
|
||||
elif [ -n "${lang}" ]; then
|
||||
awk -v lang=${lang} '{print $1 " " lang}' ${text} > ${tmpdir}/other/lang.scp
|
||||
fi
|
||||
|
||||
if [ -n "${category}" ]; then
|
||||
awk -v category=${category} '{print $1 " " category}' ${dir}/text \
|
||||
> ${tmpdir}/other/category.scp
|
||||
fi
|
||||
cat ${dir}/utt2spk > ${tmpdir}/other/utt2spk.scp
|
||||
|
||||
# 4. Merge scp files into a JSON file
|
||||
opts=""
|
||||
if [ -n "${feat}" ]; then
|
||||
intypes="${input} output other"
|
||||
else
|
||||
intypes="output other"
|
||||
fi
|
||||
for intype in ${intypes}; do
|
||||
if [ -z "$(find "${tmpdir}/${intype}" -name "*.scp")" ]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
if [ ${intype} != other ]; then
|
||||
opts+="--${intype%_*}-scps "
|
||||
else
|
||||
opts+="--scps "
|
||||
fi
|
||||
|
||||
for x in "${tmpdir}/${intype}"/*.scp; do
|
||||
k=$(basename ${x} .scp)
|
||||
if [ ${k} = shape ]; then
|
||||
opts+="shape:${x}:shape "
|
||||
else
|
||||
opts+="${k}:${x} "
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
if ${allow_one_column}; then
|
||||
opts+="--allow-one-column true "
|
||||
else
|
||||
opts+="--allow-one-column false "
|
||||
fi
|
||||
|
||||
if [ -n "${out}" ]; then
|
||||
opts+="-O ${out}"
|
||||
fi
|
||||
merge_scp2json.py --verbose ${verbose} ${opts}
|
||||
|
||||
rm -fr ${tmpdir}
|
@ -0,0 +1,95 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2017 Nagoya University (Tomoki Hayashi)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
echo "$0 $*" # Print the command line for logging
|
||||
. ./path.sh
|
||||
|
||||
cmd=run.pl
|
||||
do_delta=false
|
||||
nj=1
|
||||
verbose=0
|
||||
compress=true
|
||||
write_utt2num_frames=true
|
||||
filetype='mat' # mat or hdf5
|
||||
help_message="Usage: $0 <scp> <cmvnark> <logdir> <dumpdir>"
|
||||
|
||||
. utils/parse_options.sh
|
||||
|
||||
scp=$1
|
||||
cvmnark=$2
|
||||
logdir=$3
|
||||
dumpdir=$4
|
||||
|
||||
if [ $# != 4 ]; then
|
||||
echo "${help_message}"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
mkdir -p ${logdir}
|
||||
mkdir -p ${dumpdir}
|
||||
|
||||
dumpdir=$(perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' ${dumpdir} ${PWD})
|
||||
|
||||
for n in $(seq ${nj}); do
|
||||
# the next command does nothing unless $dumpdir/storage/ exists, see
|
||||
# utils/create_data_link.pl for more info.
|
||||
utils/create_data_link.pl ${dumpdir}/feats.${n}.ark
|
||||
done
|
||||
|
||||
if ${write_utt2num_frames}; then
|
||||
write_num_frames_opt="--write-num-frames=ark,t:$dumpdir/utt2num_frames.JOB"
|
||||
else
|
||||
write_num_frames_opt=
|
||||
fi
|
||||
|
||||
# split scp file
|
||||
split_scps=""
|
||||
for n in $(seq ${nj}); do
|
||||
split_scps="$split_scps $logdir/feats.$n.scp"
|
||||
done
|
||||
|
||||
utils/split_scp.pl ${scp} ${split_scps} || exit 1;
|
||||
|
||||
# dump features
|
||||
if ${do_delta}; then
|
||||
${cmd} JOB=1:${nj} ${logdir}/dump_feature.JOB.log \
|
||||
apply-cmvn --norm-vars=true ${cvmnark} scp:${logdir}/feats.JOB.scp ark:- \| \
|
||||
add-deltas ark:- ark:- \| \
|
||||
copy-feats.py --verbose ${verbose} --out-filetype ${filetype} \
|
||||
--compress=${compress} --compression-method=2 ${write_num_frames_opt} \
|
||||
ark:- ark,scp:${dumpdir}/feats.JOB.ark,${dumpdir}/feats.JOB.scp \
|
||||
|| exit 1
|
||||
else
|
||||
${cmd} JOB=1:${nj} ${logdir}/dump_feature.JOB.log \
|
||||
apply-cmvn --norm-vars=true ${cvmnark} scp:${logdir}/feats.JOB.scp ark:- \| \
|
||||
copy-feats.py --verbose ${verbose} --out-filetype ${filetype} \
|
||||
--compress=${compress} --compression-method=2 ${write_num_frames_opt} \
|
||||
ark:- ark,scp:${dumpdir}/feats.JOB.ark,${dumpdir}/feats.JOB.scp \
|
||||
|| exit 1
|
||||
fi
|
||||
|
||||
# concatenate scp files
|
||||
for n in $(seq ${nj}); do
|
||||
cat ${dumpdir}/feats.${n}.scp || exit 1;
|
||||
done > ${dumpdir}/feats.scp || exit 1
|
||||
|
||||
if ${write_utt2num_frames}; then
|
||||
for n in $(seq ${nj}); do
|
||||
cat ${dumpdir}/utt2num_frames.${n} || exit 1;
|
||||
done > ${dumpdir}/utt2num_frames || exit 1
|
||||
rm ${dumpdir}/utt2num_frames.* 2>/dev/null
|
||||
fi
|
||||
|
||||
# Write the filetype, this will be used for data2json.sh
|
||||
echo ${filetype} > ${dumpdir}/filetype
|
||||
|
||||
|
||||
# remove temp scps
|
||||
rm ${logdir}/feats.*.scp 2>/dev/null
|
||||
if [ ${verbose} -eq 1 ]; then
|
||||
echo "Succeeded dumping features for training"
|
||||
fi
|
@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from deepspeech.transform.transformation import Transformation
|
||||
from deepspeech.utils.cli_readers import file_reader_helper
|
||||
from deepspeech.utils.cli_utils import get_commandline_args
|
||||
from deepspeech.utils.cli_utils import is_scipy_wav_style
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="convert feature to its shape",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
|
||||
parser.add_argument(
|
||||
"--filetype",
|
||||
type=str,
|
||||
default="mat",
|
||||
choices=["mat", "hdf5", "sound.hdf5", "sound"],
|
||||
help="Specify the file format for the rspecifier. "
|
||||
'"mat" is the matrix format in kaldi',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preprocess-conf",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The configuration file for the pre-processing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"rspecifier", type=str, help="Read specifier for feats. e.g. ark:some.ark"
|
||||
)
|
||||
parser.add_argument(
|
||||
"out",
|
||||
nargs="?",
|
||||
type=argparse.FileType("w"),
|
||||
default=sys.stdout,
|
||||
help="The output filename. " "If omitted, then output to sys.stdout",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# logging info
|
||||
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||
if args.verbose > 0:
|
||||
logging.basicConfig(level=logging.INFO, format=logfmt)
|
||||
else:
|
||||
logging.basicConfig(level=logging.WARN, format=logfmt)
|
||||
logging.info(get_commandline_args())
|
||||
|
||||
if args.preprocess_conf is not None:
|
||||
preprocessing = Transformation(args.preprocess_conf)
|
||||
logging.info("Apply preprocessing: {}".format(preprocessing))
|
||||
else:
|
||||
preprocessing = None
|
||||
|
||||
# There are no necessary for matrix without preprocessing,
|
||||
# so change to file_reader_helper to return shape.
|
||||
# This make sense only with filetype="hdf5".
|
||||
for utt, mat in file_reader_helper(
|
||||
args.rspecifier, args.filetype, return_shape=preprocessing is None
|
||||
):
|
||||
if preprocessing is not None:
|
||||
if is_scipy_wav_style(mat):
|
||||
# If data is sound file, then got as Tuple[int, ndarray]
|
||||
rate, mat = mat
|
||||
mat = preprocessing(mat, uttid_list=utt)
|
||||
shape_str = ",".join(map(str, mat.shape))
|
||||
else:
|
||||
if len(mat) == 2 and isinstance(mat[1], tuple):
|
||||
# If data is sound file, Tuple[int, Tuple[int, ...]]
|
||||
rate, mat = mat
|
||||
shape_str = ",".join(map(str, mat))
|
||||
args.out.write("{} {}\n".format(utt, shape_str))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Begin configuration section.
|
||||
nj=4
|
||||
cmd=run.pl
|
||||
verbose=0
|
||||
filetype=""
|
||||
preprocess_conf=""
|
||||
# End configuration section.
|
||||
|
||||
help_message=$(cat << EOF
|
||||
Usage: $0 [options] <input-scp> <output-scp> [<log-dir>]
|
||||
e.g.: $0 data/train/feats.scp data/train/shape.scp data/train/log
|
||||
Options:
|
||||
--nj <nj> # number of parallel jobs
|
||||
--cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs.
|
||||
--filetype <mat|hdf5|sound.hdf5> # Specify the format of feats file
|
||||
--preprocess-conf <json> # Apply preprocess to feats when creating shape.scp
|
||||
--verbose <num> # Default: 0
|
||||
EOF
|
||||
)
|
||||
|
||||
echo "$0 $*" 1>&2 # Print the command line for logging
|
||||
|
||||
. parse_options.sh || exit 1;
|
||||
|
||||
if [ $# -lt 2 ] || [ $# -gt 3 ]; then
|
||||
echo "${help_message}" 1>&2
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
scp=$1
|
||||
outscp=$2
|
||||
data=$(dirname ${scp})
|
||||
if [ $# -eq 3 ]; then
|
||||
logdir=$3
|
||||
else
|
||||
logdir=${data}/log
|
||||
fi
|
||||
mkdir -p ${logdir}
|
||||
|
||||
nj=$((nj<$(<"${scp}" wc -l)?nj:$(<"${scp}" wc -l)))
|
||||
split_scps=""
|
||||
for n in $(seq ${nj}); do
|
||||
split_scps="${split_scps} ${logdir}/feats.${n}.scp"
|
||||
done
|
||||
|
||||
utils/split_scp.pl ${scp} ${split_scps}
|
||||
|
||||
if [ -n "${preprocess_conf}" ]; then
|
||||
preprocess_opt="--preprocess-conf ${preprocess_conf}"
|
||||
else
|
||||
preprocess_opt=""
|
||||
fi
|
||||
if [ -n "${filetype}" ]; then
|
||||
filetype_opt="--filetype ${filetype}"
|
||||
else
|
||||
filetype_opt=""
|
||||
fi
|
||||
|
||||
${cmd} JOB=1:${nj} ${logdir}/feat_to_shape.JOB.log \
|
||||
feat-to-shape.py --verbose ${verbose} ${preprocess_opt} ${filetype_opt} \
|
||||
scp:${logdir}/feats.JOB.scp ${logdir}/shape.JOB.scp
|
||||
|
||||
# concatenate the .scp files together.
|
||||
for n in $(seq ${nj}); do
|
||||
cat ${logdir}/shape.${n}.scp
|
||||
done > ${outscp}
|
||||
|
||||
rm -f ${logdir}/feats.*.scp 2>/dev/null
|
@ -0,0 +1,289 @@
|
||||
#!/usr/bin/env python3
|
||||
# encoding: utf-8
|
||||
import argparse
|
||||
import codecs
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from distutils.util import strtobool
|
||||
from io import open
|
||||
|
||||
from deepspeech.utils.cli_utils import get_commandline_args
|
||||
|
||||
PY2 = sys.version_info[0] == 2
|
||||
sys.stdin = codecs.getreader("utf-8")(sys.stdin if PY2 else sys.stdin.buffer)
|
||||
sys.stdout = codecs.getwriter("utf-8")(sys.stdout if PY2 else sys.stdout.buffer)
|
||||
|
||||
|
||||
# Special types:
|
||||
def shape(x):
|
||||
"""Change str to List[int]
|
||||
|
||||
>>> shape('3,5')
|
||||
[3, 5]
|
||||
>>> shape(' [3, 5] ')
|
||||
[3, 5]
|
||||
|
||||
"""
|
||||
|
||||
# x: ' [3, 5] ' -> '3, 5'
|
||||
x = x.strip()
|
||||
if x[0] == "[":
|
||||
x = x[1:]
|
||||
if x[-1] == "]":
|
||||
x = x[:-1]
|
||||
|
||||
return list(map(int, x.split(",")))
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Given each file paths with such format as "
|
||||
"<key>:<file>:<type>. type> can be omitted and the default "
|
||||
'is "str". e.g. {} '
|
||||
"--input-scps feat:data/feats.scp shape:data/utt2feat_shape:shape "
|
||||
"--input-scps feat:data/feats2.scp shape:data/utt2feat2_shape:shape "
|
||||
"--output-scps text:data/text shape:data/utt2text_shape:shape "
|
||||
"--scps utt2spk:data/utt2spk".format(sys.argv[0]),
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
|
||||
parser.add_argument(
|
||||
"--input-scps",
|
||||
type=str,
|
||||
nargs="*",
|
||||
action="append",
|
||||
default=[],
|
||||
help="Json files for the inputs", )
|
||||
parser.add_argument(
|
||||
"--output-scps",
|
||||
type=str,
|
||||
nargs="*",
|
||||
action="append",
|
||||
default=[],
|
||||
help="Json files for the outputs", )
|
||||
parser.add_argument(
|
||||
"--scps",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=[],
|
||||
help="The json files except for the input and outputs", )
|
||||
parser.add_argument(
|
||||
"--verbose", "-V", default=1, type=int, help="Verbose option")
|
||||
parser.add_argument(
|
||||
"--allow-one-column",
|
||||
type=strtobool,
|
||||
default=False,
|
||||
help="Allow one column in input scp files. "
|
||||
"In this case, the value will be empty string.", )
|
||||
parser.add_argument(
|
||||
"--out",
|
||||
"-O",
|
||||
type=str,
|
||||
help="The output filename. "
|
||||
"If omitted, then output to sys.stdout", )
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.scps = [args.scps]
|
||||
|
||||
# logging info
|
||||
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||
if args.verbose > 0:
|
||||
logging.basicConfig(level=logging.INFO, format=logfmt)
|
||||
else:
|
||||
logging.basicConfig(level=logging.WARN, format=logfmt)
|
||||
logging.info(get_commandline_args())
|
||||
|
||||
# List[List[Tuple[str, str, Callable[[str], Any], str, str]]]
|
||||
input_infos = []
|
||||
output_infos = []
|
||||
infos = []
|
||||
for lis_list, key_scps_list in [
|
||||
(input_infos, args.input_scps),
|
||||
(output_infos, args.output_scps),
|
||||
(infos, args.scps),
|
||||
]:
|
||||
for key_scps in key_scps_list:
|
||||
lis = []
|
||||
for key_scp in key_scps:
|
||||
sps = key_scp.split(":")
|
||||
if len(sps) == 2:
|
||||
key, scp = sps
|
||||
type_func = None
|
||||
type_func_str = "none"
|
||||
elif len(sps) == 3:
|
||||
key, scp, type_func_str = sps
|
||||
fail = False
|
||||
|
||||
try:
|
||||
# type_func: Callable[[str], Any]
|
||||
# e.g. type_func_str = "int" -> type_func = int
|
||||
type_func = eval(type_func_str)
|
||||
except Exception:
|
||||
raise RuntimeError(
|
||||
"Unknown type: {}".format(type_func_str))
|
||||
|
||||
if not callable(type_func):
|
||||
raise RuntimeError(
|
||||
"Unknown type: {}".format(type_func_str))
|
||||
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Format <key>:<filepath> "
|
||||
"or <key>:<filepath>:<type> "
|
||||
"e.g. feat:data/feat.scp "
|
||||
"or shape:data/feat.scp:shape: {}".format(key_scp))
|
||||
|
||||
for item in lis:
|
||||
if key == item[0]:
|
||||
raise RuntimeError('The key "{}" is duplicated: {} {}'.
|
||||
format(key, item[3], key_scp))
|
||||
|
||||
lis.append((key, scp, type_func, key_scp, type_func_str))
|
||||
lis_list.append(lis)
|
||||
|
||||
# Open scp files
|
||||
input_fscps = [[open(i[1], "r", encoding="utf-8") for i in il]
|
||||
for il in input_infos]
|
||||
output_fscps = [[open(i[1], "r", encoding="utf-8") for i in il]
|
||||
for il in output_infos]
|
||||
fscps = [[open(i[1], "r", encoding="utf-8") for i in il] for il in infos]
|
||||
|
||||
# Note(kamo): What is done here?
|
||||
# The final goal is creating a JSON file such as.
|
||||
# {
|
||||
# "utts": {
|
||||
# "sample_id1": {(omitted)},
|
||||
# "sample_id2": {(omitted)},
|
||||
# ....
|
||||
# }
|
||||
# }
|
||||
#
|
||||
# To reduce memory usage, reading the input text files for each lines
|
||||
# and writing JSON elements per samples.
|
||||
if args.out is None:
|
||||
out = sys.stdout
|
||||
else:
|
||||
out = open(args.out, "w", encoding="utf-8")
|
||||
out.write('{\n "utts": {\n')
|
||||
nutt = 0
|
||||
while True:
|
||||
nutt += 1
|
||||
# List[List[str]]
|
||||
input_lines = [[f.readline() for f in fl] for fl in input_fscps]
|
||||
output_lines = [[f.readline() for f in fl] for fl in output_fscps]
|
||||
lines = [[f.readline() for f in fl] for fl in fscps]
|
||||
|
||||
# Get the first line
|
||||
concat = sum(input_lines + output_lines + lines, [])
|
||||
if len(concat) == 0:
|
||||
break
|
||||
first = concat[0]
|
||||
|
||||
# Sanity check: Must be sorted by the first column and have same keys
|
||||
count = 0
|
||||
for ls_list in (input_lines, output_lines, lines):
|
||||
for ls in ls_list:
|
||||
for line in ls:
|
||||
if line == "" or first == "":
|
||||
if line != first:
|
||||
concat = sum(input_infos + output_infos + infos, [])
|
||||
raise RuntimeError("The number of lines mismatch "
|
||||
'between: "{}" and "{}"'.format(
|
||||
concat[0][1],
|
||||
concat[count][1]))
|
||||
|
||||
elif line.split()[0] != first.split()[0]:
|
||||
concat = sum(input_infos + output_infos + infos, [])
|
||||
raise RuntimeError(
|
||||
"The keys are mismatch at {}th line "
|
||||
'between "{}" and "{}":\n>>> {}\n>>> {}'.format(
|
||||
nutt,
|
||||
concat[0][1],
|
||||
concat[count][1],
|
||||
first.rstrip(),
|
||||
line.rstrip(), ))
|
||||
count += 1
|
||||
|
||||
# The end of file
|
||||
if first == "":
|
||||
if nutt != 1:
|
||||
out.write("\n")
|
||||
break
|
||||
if nutt != 1:
|
||||
out.write(",\n")
|
||||
|
||||
entry = {}
|
||||
for inout, _lines, _infos in [
|
||||
("input", input_lines, input_infos),
|
||||
("output", output_lines, output_infos),
|
||||
("other", lines, infos),
|
||||
]:
|
||||
|
||||
lis = []
|
||||
for idx, (line_list, info_list) in enumerate(
|
||||
zip(_lines, _infos), 1):
|
||||
if inout == "input":
|
||||
d = {"name": "input{}".format(idx)}
|
||||
elif inout == "output":
|
||||
d = {"name": "target{}".format(idx)}
|
||||
else:
|
||||
d = {}
|
||||
|
||||
# info_list: List[Tuple[str, str, Callable]]
|
||||
# line_list: List[str]
|
||||
for line, info in zip(line_list, info_list):
|
||||
sps = line.split(None, 1)
|
||||
if len(sps) < 2:
|
||||
if not args.allow_one_column:
|
||||
raise RuntimeError(
|
||||
"Format error {}th line in {}: "
|
||||
' Expecting "<key> <value>":\n>>> {}'.format(
|
||||
nutt, info[1], line))
|
||||
uttid = sps[0]
|
||||
value = ""
|
||||
else:
|
||||
uttid, value = sps
|
||||
|
||||
key = info[0]
|
||||
type_func = info[2]
|
||||
value = value.rstrip()
|
||||
|
||||
if type_func is not None:
|
||||
try:
|
||||
# type_func: Callable[[str], Any]
|
||||
value = type_func(value)
|
||||
except Exception:
|
||||
logging.error(
|
||||
'"{}" is an invalid function '
|
||||
"for the {} th line in {}: \n>>> {}".format(
|
||||
info[4], nutt, info[1], line))
|
||||
raise
|
||||
|
||||
d[key] = value
|
||||
lis.append(d)
|
||||
|
||||
if inout != "other":
|
||||
entry[inout] = lis
|
||||
else:
|
||||
# If key == 'other'. only has the first item
|
||||
entry.update(lis[0])
|
||||
|
||||
entry = json.dumps(
|
||||
entry,
|
||||
indent=4,
|
||||
ensure_ascii=False,
|
||||
sort_keys=True,
|
||||
separators=(",", ": "))
|
||||
# Add indent
|
||||
indent = " " * 2
|
||||
entry = ("\n" + indent).join(entry.split("\n"))
|
||||
|
||||
uttid = first.split()[0]
|
||||
out.write(' "{}": {}'.format(uttid, entry))
|
||||
|
||||
out.write(" }\n}\n")
|
||||
|
||||
logging.info("{} entries in {}".format(nutt, out.name))
|
@ -0,0 +1,59 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# koried, 10/29/2012
|
||||
|
||||
# Reduce a data set based on a list of turn-ids
|
||||
|
||||
help_message="usage: $0 srcdir turnlist destdir"
|
||||
|
||||
if [ $1 == "--help" ]; then
|
||||
echo "${help_message}"
|
||||
exit 0;
|
||||
fi
|
||||
|
||||
if [ $# != 3 ]; then
|
||||
echo "${help_message}"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
srcdir=$1
|
||||
reclist=$2
|
||||
destdir=$3
|
||||
|
||||
if [ ! -f ${srcdir}/utt2spk ]; then
|
||||
echo "$0: no such file $srcdir/utt2spk"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
function do_filtering {
|
||||
# assumes the utt2spk and spk2utt files already exist.
|
||||
[ -f ${srcdir}/feats.scp ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/feats.scp >${destdir}/feats.scp
|
||||
[ -f ${srcdir}/wav.scp ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/wav.scp >${destdir}/wav.scp
|
||||
[ -f ${srcdir}/text ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/text >${destdir}/text
|
||||
[ -f ${srcdir}/utt2num_frames ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/utt2num_frames >${destdir}/utt2num_frames
|
||||
[ -f ${srcdir}/spk2gender ] && utils/filter_scp.pl ${destdir}/spk2utt <${srcdir}/spk2gender >${destdir}/spk2gender
|
||||
[ -f ${srcdir}/cmvn.scp ] && utils/filter_scp.pl ${destdir}/spk2utt <${srcdir}/cmvn.scp >${destdir}/cmvn.scp
|
||||
if [ -f ${srcdir}/segments ]; then
|
||||
utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/segments >${destdir}/segments
|
||||
awk '{print $2;}' ${destdir}/segments | sort | uniq > ${destdir}/reco # recordings.
|
||||
# The next line would override the command above for wav.scp, which would be incorrect.
|
||||
[ -f ${srcdir}/wav.scp ] && utils/filter_scp.pl ${destdir}/reco <${srcdir}/wav.scp >${destdir}/wav.scp
|
||||
[ -f ${srcdir}/reco2file_and_channel ] && \
|
||||
utils/filter_scp.pl ${destdir}/reco <${srcdir}/reco2file_and_channel >${destdir}/reco2file_and_channel
|
||||
|
||||
# Filter the STM file for proper sclite scoring (this will also remove the comments lines)
|
||||
[ -f ${srcdir}/stm ] && utils/filter_scp.pl ${destdir}/reco < ${srcdir}/stm > ${destdir}/stm
|
||||
rm ${destdir}/reco
|
||||
fi
|
||||
srcutts=$(wc -l < ${srcdir}/utt2spk)
|
||||
destutts=$(wc -l < ${destdir}/utt2spk)
|
||||
echo "Reduced #utt from $srcutts to $destutts"
|
||||
}
|
||||
|
||||
mkdir -p ${destdir}
|
||||
|
||||
# filter the utt2spk based on the set of recordings
|
||||
utils/filter_scp.pl ${reclist} < ${srcdir}/utt2spk > ${destdir}/utt2spk
|
||||
|
||||
utils/utt2spk_to_spk2utt.pl < ${destdir}/utt2spk > ${destdir}/spk2utt
|
||||
do_filtering;
|
@ -0,0 +1,62 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
. ./path.sh
|
||||
|
||||
maxframes=2000
|
||||
minframes=10
|
||||
maxchars=200
|
||||
minchars=0
|
||||
nlsyms=""
|
||||
no_feat=false
|
||||
trans_type=char
|
||||
|
||||
help_message="usage: $0 olddatadir newdatadir"
|
||||
|
||||
. utils/parse_options.sh || exit 1;
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "${help_message}"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
sdir=$1
|
||||
odir=$2
|
||||
mkdir -p ${odir}/tmp
|
||||
|
||||
if [ ${no_feat} = true ]; then
|
||||
# for machine translation
|
||||
cut -d' ' -f 1 ${sdir}/text > ${odir}/tmp/reclist1
|
||||
else
|
||||
echo "extract utterances having less than $maxframes or more than $minframes frames"
|
||||
utils/data/get_utt2num_frames.sh ${sdir}
|
||||
< ${sdir}/utt2num_frames awk -v maxframes="$maxframes" '{ if ($2 < maxframes) print }' \
|
||||
| awk -v minframes="$minframes" '{ if ($2 > minframes) print }' \
|
||||
| awk '{print $1}' > ${odir}/tmp/reclist1
|
||||
fi
|
||||
|
||||
echo "extract utterances having less than $maxchars or more than $minchars characters"
|
||||
# counting number of chars. Use (NF - 1) instead of NF to exclude the utterance ID column
|
||||
if [ -z ${nlsyms} ]; then
|
||||
text2token.py -s 1 -n 1 ${sdir}/text --trans_type ${trans_type} \
|
||||
| awk -v maxchars="$maxchars" '{ if (NF - 1 < maxchars) print }' \
|
||||
| awk -v minchars="$minchars" '{ if (NF - 1 > minchars) print }' \
|
||||
| awk '{print $1}' > ${odir}/tmp/reclist2
|
||||
else
|
||||
text2token.py -l ${nlsyms} -s 1 -n 1 ${sdir}/text --trans_type ${trans_type} \
|
||||
| awk -v maxchars="$maxchars" '{ if (NF - 1 < maxchars) print }' \
|
||||
| awk -v minchars="$minchars" '{ if (NF - 1 > minchars) print }' \
|
||||
| awk '{print $1}' > ${odir}/tmp/reclist2
|
||||
fi
|
||||
|
||||
# extract common lines
|
||||
comm -12 <(sort ${odir}/tmp/reclist1) <(sort ${odir}/tmp/reclist2) > ${odir}/tmp/reclist
|
||||
|
||||
reduce_data_dir.sh ${sdir} ${odir}/tmp/reclist ${odir}
|
||||
utils/fix_data_dir.sh ${odir}
|
||||
|
||||
oldnum=$(wc -l ${sdir}/feats.scp | awk '{print $1}')
|
||||
newnum=$(wc -l ${odir}/feats.scp | awk '{print $1}')
|
||||
echo "change from $oldnum to $newnum"
|
@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
import argparse
|
||||
import codecs
|
||||
import re
|
||||
import sys
|
||||
|
||||
is_python2 = sys.version_info[0] == 2
|
||||
|
||||
|
||||
def exist_or_not(i, match_pos):
|
||||
start_pos = None
|
||||
end_pos = None
|
||||
for pos in match_pos:
|
||||
if pos[0] <= i < pos[1]:
|
||||
start_pos = pos[0]
|
||||
end_pos = pos[1]
|
||||
break
|
||||
|
||||
return start_pos, end_pos
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="convert raw text to tokenized text",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
|
||||
parser.add_argument(
|
||||
"--nchar",
|
||||
"-n",
|
||||
default=1,
|
||||
type=int,
|
||||
help="number of characters to split, i.e., \
|
||||
aabb -> a a b b with -n 1 and aa bb with -n 2", )
|
||||
parser.add_argument(
|
||||
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns")
|
||||
parser.add_argument(
|
||||
"--space", default="<space>", type=str, help="space symbol")
|
||||
parser.add_argument(
|
||||
"--non-lang-syms",
|
||||
"-l",
|
||||
default=None,
|
||||
type=str,
|
||||
help="list of non-linguistic symobles, e.g., <NOISE> etc.", )
|
||||
parser.add_argument(
|
||||
"text", type=str, default=False, nargs="?", help="input text")
|
||||
parser.add_argument(
|
||||
"--trans_type",
|
||||
"-t",
|
||||
type=str,
|
||||
default="char",
|
||||
choices=["char", "phn"],
|
||||
help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
|
||||
If trans_type is char,
|
||||
read from SI1279.WRD file -> "bricks are an alternative"
|
||||
Else if trans_type is phn,
|
||||
read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
|
||||
sil t er n ih sil t ih v sil" """, )
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
rs = []
|
||||
if args.non_lang_syms is not None:
|
||||
with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
|
||||
nls = [x.rstrip() for x in f.readlines()]
|
||||
rs = [re.compile(re.escape(x)) for x in nls]
|
||||
|
||||
if args.text:
|
||||
f = codecs.open(args.text, encoding="utf-8")
|
||||
else:
|
||||
f = codecs.getreader("utf-8")(sys.stdin
|
||||
if is_python2 else sys.stdin.buffer)
|
||||
|
||||
sys.stdout = codecs.getwriter("utf-8")(sys.stdout
|
||||
if is_python2 else sys.stdout.buffer)
|
||||
line = f.readline()
|
||||
n = args.nchar
|
||||
while line:
|
||||
x = line.split()
|
||||
print(" ".join(x[:args.skip_ncols]), end=" ")
|
||||
a = " ".join(x[args.skip_ncols:])
|
||||
|
||||
# get all matched positions
|
||||
match_pos = []
|
||||
for r in rs:
|
||||
i = 0
|
||||
while i >= 0:
|
||||
m = r.search(a, i)
|
||||
if m:
|
||||
match_pos.append([m.start(), m.end()])
|
||||
i = m.end()
|
||||
else:
|
||||
break
|
||||
|
||||
if args.trans_type == "phn":
|
||||
a = a.split(" ")
|
||||
else:
|
||||
if len(match_pos) > 0:
|
||||
chars = []
|
||||
i = 0
|
||||
while i < len(a):
|
||||
start_pos, end_pos = exist_or_not(i, match_pos)
|
||||
if start_pos is not None:
|
||||
chars.append(a[start_pos:end_pos])
|
||||
i = end_pos
|
||||
else:
|
||||
chars.append(a[i])
|
||||
i += 1
|
||||
a = chars
|
||||
|
||||
a = [a[j:j + n] for j in range(0, len(a), n)]
|
||||
|
||||
a_flat = []
|
||||
for z in a:
|
||||
a_flat.append("".join(z))
|
||||
|
||||
a_chars = [z.replace(" ", args.space) for z in a_flat]
|
||||
if args.trans_type == "phn":
|
||||
a_chars = [z.replace("sil", args.space) for z in a_chars]
|
||||
print(" ".join(a_chars))
|
||||
line = f.readline()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in new issue