Merge pull request #929 from PaddlePaddle/join_ctc

[lm] transformer lm & kaldi data process
pull/945/head
Hui Zhang 3 years ago committed by GitHub
commit e8bc9a2a08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

2
.gitignore vendored

@ -24,5 +24,7 @@ tools/montreal-forced-aligner/
tools/Montreal-Forced-Aligner/
tools/sctk
tools/sctk-20159b5/
tools/kaldi
tools/OpenBLAS/
*output/

@ -24,6 +24,7 @@ from .utils import add_results_to_json
from deepspeech.exps import dynamic_import_tester
from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.models.asr_interface import ASRInterface
from deepspeech.models.lm.transformer import TransformerLM
from deepspeech.utils.log import Log
# from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import torch_load
@ -48,6 +49,21 @@ def load_trained_model(args):
model = exp.model
return model, char_list, exp, confs
def get_config(config_path):
stream = open(config_path, mode='r', encoding="utf-8")
config = yaml.load(stream, Loader=yaml.FullLoader)
stream.close()
return config
def load_trained_lm(args):
lm_args = get_config(args.rnnlm_conf)
# NOTE: for a compatibility with less than 0.5.0 version models
lm_model_module = getattr(lm_args, "model_module", "default")
lm_class = dynamic_import_lm(lm_model_module)
lm = lm_class(lm_args.model)
model_dict = paddle.load(args.rnnlm)
lm.set_state_dict(model_dict)
return lm
def recog_v2(args):
"""Decode with custom models that implements ScorerInterface.
@ -78,12 +94,7 @@ def recog_v2(args):
preprocess_args={"train": False}, )
if args.rnnlm:
lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
# NOTE: for a compatibility with less than 0.5.0 version models
lm_model_module = getattr(lm_args, "model_module", "default")
lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
lm = lm_class(len(char_list), lm_args)
torch_load(args.rnnlm, lm)
lm = load_trained_lm(args)
lm.eval()
else:
lm = None

@ -21,9 +21,6 @@ from distutils.util import strtobool
import configargparse
import numpy as np
from deepspeech.decoders.recog import recog_v2
def get_parser():
"""Get default arguments."""
parser = configargparse.ArgumentParser(
@ -359,7 +356,7 @@ def main(args):
if args.num_encs == 1:
# Experimental API that supports custom LMs
if args.api == "v2":
from deepspeech.decoders.recog import recog_v2
recog_v2(args)
else:
raise ValueError("Only support --api v2")

@ -318,6 +318,18 @@ class CTCPrefixScore():
r[0, 0] = xs[0]
r[0, 1] = self.logzero
else:
# Although the code does not exactly follow Algorithm 2,
# we don't have to change it because we can assume
# r_t(h)=0 for t < |h| in CTC forward computation
# (Note: we assume here that index t starts with 0).
# The purpose of this difference is to reduce the number of for-loops.
# https://github.com/espnet/espnet/pull/3655
# where we start to accumulate r_t(h) from t=|h|
# and iterate r_t(h) = (r_{t-1}(h) + ...) to T-1,
# avoiding accumulating zeros for t=1~|h|-1.
# Thus, we need to set r_{|h|-1}(h) = 0,
# i.e., r[output_length-1] = logzero, for initialization.
# This is just for reducing the computation.
r[output_length - 1] = self.logzero
# prepare forward probabilities for the last label

@ -13,6 +13,7 @@
# limitations under the License.
"""Contains the data augmentation pipeline."""
import json
import os
from collections.abc import Sequence
from inspect import signature
from pprint import pformat
@ -90,9 +91,8 @@ class AugmentationPipeline():
effect.
Params:
augmentation_config(str): Augmentation configuration in json string.
preprocess_conf(str): Augmentation configuration in `json file` or `json string`.
random_seed(int): Random seed.
train(bool): whether is train mode.
Raises:
ValueError: If the augmentation json config is in incorrect format".
@ -100,11 +100,18 @@ class AugmentationPipeline():
SPEC_TYPES = {'specaug'}
def __init__(self, augmentation_config: str, random_seed: int=0):
def __init__(self, preprocess_conf: str, random_seed: int=0):
self._rng = np.random.RandomState(random_seed)
self.conf = {'mode': 'sequential', 'process': []}
if augmentation_config:
process = json.loads(augmentation_config)
if preprocess_conf:
if os.path.isfile(preprocess_conf):
# json file
with open(preprocess_conf, 'r') as fin:
json_string = fin.read()
else:
# json string
json_string = preprocess_conf
process = json.loads(json_string)
self.conf['process'] += process
self._augmentors, self._rates = self._parse_pipeline_from('all')

@ -105,7 +105,7 @@ class SpeechCollatorBase():
self._local_data = TarLocalData(tar2info={}, tar2object={})
self.augmentation = AugmentationPipeline(
augmentation_config=aug_file.read(), random_seed=random_seed)
preprocess_conf=aug_file.read(), random_seed=random_seed)
self._normalizer = FeatureNormalizer(
mean_std_filepath) if mean_std_filepath else None

@ -17,7 +17,7 @@ import kaldiio
import numpy as np
import soundfile
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline as Transformation
from deepspeech.utils.log import Log
__all__ = ["LoadInputsAndTargets"]
@ -66,8 +66,7 @@ class LoadInputsAndTargets():
raise ValueError("Only asr are allowed: mode={}".format(mode))
if preprocess_conf is not None:
with open(preprocess_conf, 'r') as fin:
self.preprocessing = AugmentationPipeline(fin.read())
self.preprocessing = Transformation(preprocess_conf)
logger.warning(
"[Experimental feature] Some preprocessing will be done "
"for the mini-batch creation using {}".format(

@ -18,7 +18,7 @@ from deepspeech.utils.dynamic_import import dynamic_import
class ASRInterface:
"""ASR Interface for ESPnet model implementation."""
"""ASR Interface model implementation."""
@staticmethod
def add_arguments(parser):
@ -103,14 +103,14 @@ class ASRInterface:
@property
def attention_plot_class(self):
"""Get attention plot class."""
from espnet.asr.asr_utils import PlotAttentionReport
from deepspeech.training.extensions.plot import PlotAttentionReport
return PlotAttentionReport
@property
def ctc_plot_class(self):
"""Get CTC plot class."""
from espnet.asr.asr_utils import PlotCTCReport
from deepspeech.training.extensions.plot import PlotCTCReport
return PlotCTCReport

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,263 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import List
from typing import Tuple
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
from deepspeech.models.lm_interface import LMInterface
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.modules.mask import subsequent_mask
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
def __init__(
self,
n_vocab: int,
pos_enc: str=None,
embed_unit: int=128,
att_unit: int=256,
head: int=2,
unit: int=1024,
layer: int=4,
dropout_rate: float=0.5,
emb_dropout_rate: float=0.0,
att_dropout_rate: float=0.0,
tie_weights: bool=False,
**kwargs):
nn.Layer.__init__(self)
if pos_enc == "sinusoidal":
pos_enc_layer_type = "abs_pos"
elif pos_enc is None:
pos_enc_layer_type = "no_pos"
else:
raise ValueError(f"unknown pos-enc option: {pos_enc}")
self.embed = nn.Embedding(n_vocab, embed_unit)
if emb_dropout_rate == 0.0:
self.embed_drop = None
else:
self.embed_drop = nn.Dropout(emb_dropout_rate)
self.encoder = TransformerEncoder(
input_size=embed_unit,
output_size=att_unit,
attention_heads=head,
linear_units=unit,
num_blocks=layer,
dropout_rate=dropout_rate,
attention_dropout_rate=att_dropout_rate,
input_layer="linear",
pos_enc_layer_type=pos_enc_layer_type,
concat_after=False,
static_chunk_size=1,
use_dynamic_chunk=False,
use_dynamic_left_chunk=False)
self.decoder = nn.Linear(att_unit, n_vocab)
logger.info("Tie weights set to {}".format(tie_weights))
logger.info("Dropout set to {}".format(dropout_rate))
logger.info("Emb Dropout set to {}".format(emb_dropout_rate))
logger.info("Att Dropout set to {}".format(att_dropout_rate))
if tie_weights:
assert (
att_unit == embed_unit
), "Tie Weights: True need embedding and final dimensions to match"
self.decoder.weight = self.embed.weight
def _target_mask(self, ys_in_pad):
ys_mask = ys_in_pad != 0
m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0)
return ys_mask.unsqueeze(-2) & m
def forward(self, x: paddle.Tensor, t: paddle.Tensor
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute LM loss value from buffer sequences.
Args:
x (paddle.Tensor): Input ids. (batch, len)
t (paddle.Tensor): Target ids. (batch, len)
Returns:
tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: Tuple of
loss to backward (scalar),
negative log-likelihood of t: -log p(t) (scalar) and
the number of elements in x (scalar)
Notes:
The last two return values are used
in perplexity: p(t)^{-n} = exp(-log p(t) / n)
"""
xm = x != 0
xlen = xm.sum(axis=1)
if self.embed_drop is not None:
emb = self.embed_drop(self.embed(x))
else:
emb = self.embed(x)
h, _ = self.encoder(emb, xlen)
y = self.decoder(h)
loss = F.cross_entropy(
y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
mask = xm.to(dtype=loss.dtype)
logp = loss * mask.view(-1)
logp = logp.sum()
count = mask.sum()
return logp / count, logp, count
# beam search API (see ScorerInterface)
def score(self, y: paddle.Tensor, state: Any,
x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]:
"""Score new token.
Args:
y (paddle.Tensor): 1D paddle.int64 prefix tokens.
state: Scorer state for prefix tokens
x (paddle.Tensor): encoder feature that generates ys.
Returns:
tuple[paddle.Tensor, Any]: Tuple of
paddle.float32 scores for next token (n_vocab)
and next state for ys
"""
y = y.unsqueeze(0)
if self.embed_drop is not None:
emb = self.embed_drop(self.embed(y))
else:
emb = self.embed(y)
h, _, cache = self.encoder.forward_one_step(
emb, self._target_mask(y), cache=state)
h = self.decoder(h[:, -1])
logp = F.log_softmax(h).squeeze(0)
return logp, cache
# batch beam search API (see BatchScorerInterface)
def batch_score(self,
ys: paddle.Tensor,
states: List[Any],
xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
"""Score new token batch (required).
Args:
ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (paddle.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[paddle.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, n_vocab)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.encoder.encoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [
paddle.stack([states[b][i] for b in range(n_batch)])
for i in range(n_layers)
]
if self.embed_drop is not None:
emb = self.embed_drop(self.embed(ys))
else:
emb = self.embed(ys)
# batch decoding
h, _, states = self.encoder.forward_one_step(
emb, self._target_mask(ys), cache=batch_state)
h = self.decoder(h[:, -1])
logp = F.log_softmax(h)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)]
for b in range(n_batch)]
return logp, state_list
if __name__ == "__main__":
tlm = TransformerLM(
n_vocab=5002,
pos_enc=None,
embed_unit=128,
att_unit=512,
head=8,
unit=2048,
layer=16,
dropout_rate=0.5, )
# n_vocab: int,
# pos_enc: str=None,
# embed_unit: int=128,
# att_unit: int=256,
# head: int=2,
# unit: int=1024,
# layer: int=4,
# dropout_rate: float=0.5,
# emb_dropout_rate: float = 0.0,
# att_dropout_rate: float = 0.0,
# tie_weights: bool = False,):
paddle.set_device("cpu")
model_dict = paddle.load("transformerLM.pdparams")
tlm.set_state_dict(model_dict)
tlm.eval()
#Test the score
input2 = np.array([5])
input2 = paddle.to_tensor(input2)
state = None
output, state = tlm.score(input2, state, None)
input3 = np.array([5, 10])
input3 = paddle.to_tensor(input3)
output, state = tlm.score(input3, state, None)
input4 = np.array([5, 10, 0])
input4 = paddle.to_tensor(input4)
output, state = tlm.score(input4, state, None)
print("output", output)
"""
#Test the batch score
batch_size = 2
inp2 = np.array([[5], [10]])
inp2 = paddle.to_tensor(inp2)
output, states = tlm.batch_score(
inp2, [(None,None,0)] * batch_size)
inp3 = np.array([[100], [30]])
inp3 = paddle.to_tensor(inp3)
output, states = tlm.batch_score(
inp3, states)
print("output", output)
#print("cache", cache)
#np.save("output_pd.npy", output)
"""

@ -0,0 +1,82 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Language model interface."""
import argparse
from deepspeech.decoders.scorers.scorer_interface import ScorerInterface
from deepspeech.utils.dynamic_import import dynamic_import
class LMInterface(ScorerInterface):
"""LM Interface model implementation."""
@staticmethod
def add_arguments(parser):
"""Add arguments to command line argument parser."""
return parser
@classmethod
def build(cls, n_vocab: int, **kwargs):
"""Initialize this class with python-level args.
Args:
idim (int): The number of vocabulary.
Returns:
LMinterface: A new instance of LMInterface.
"""
args = argparse.Namespace(**kwargs)
return cls(n_vocab, args)
def forward(self, x, t):
"""Compute LM loss value from buffer sequences.
Args:
x (torch.Tensor): Input ids. (batch, len)
t (torch.Tensor): Target ids. (batch, len)
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
loss to backward (scalar),
negative log-likelihood of t: -log p(t) (scalar) and
the number of elements in x (scalar)
Notes:
The last two return values are used
in perplexity: p(t)^{-n} = exp(-log p(t) / n)
"""
raise NotImplementedError("forward method is not implemented")
predefined_lms = {
"transformer": "deepspeech.models.lm.transformer:TransformerLM",
}
def dynamic_import_lm(module):
"""Import LM class dynamically.
Args:
module (str): module_name:class_name or alias in `predefined_lms`
Returns:
type: LM class
"""
model_class = dynamic_import(module, predefined_lms)
assert issubclass(model_class,
LMInterface), f"{module} does not implement LMInterface"
return model_class

@ -0,0 +1,75 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ST Interface module."""
from .asr_interface import ASRInterface
from deepspeech.utils.dynamic_import import dynamic_import
class STInterface(ASRInterface):
"""ST Interface model implementation.
NOTE: This class is inherited from ASRInterface to enable joint translation
and recognition when performing multi-task learning with the ASR task.
"""
def translate(self,
x,
trans_args,
char_list=None,
rnnlm=None,
ensemble_models=[]):
"""Recognize x for evaluation.
:param ndarray x: input acouctic feature (B, T, D) or (T, D)
:param namespace trans_args: argment namespace contraining options
:param list char_list: list of characters
:param paddle.nn.Layer rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("translate method is not implemented")
def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
"""Beam search implementation for batch.
:param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
:param namespace trans_args: argument namespace containing options
:param list char_list: list of characters
:param paddle.nn.Layer rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("Batch decoding is not supported yet.")
predefined_st = {
"transformer": "deepspeech.models.u2_st:U2STModel",
}
def dynamic_import_st(module):
"""Import ST models dynamically.
Args:
module (str): module_name:class_name or alias in `predefined_st`
Returns:
type: ST class
"""
model_class = dynamic_import(module, predefined_st)
assert issubclass(model_class,
STInterface), f"{module} does not implement STInterface"
return model_class

@ -0,0 +1,15 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .u2_st import U2STInferModel
from .u2_st import U2STModel

@ -22,10 +22,52 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["PositionalEncoding", "RelPositionalEncoding"]
__all__ = [
"PositionalEncodingInterface", "NoPositionalEncoding", "PositionalEncoding",
"RelPositionalEncoding"
]
class PositionalEncoding(nn.Layer):
class PositionalEncodingInterface:
def forward(self, x: paddle.Tensor,
offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute positional encoding.
Args:
x (paddle.Tensor): Input tensor (batch, time, `*`).
Returns:
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
raise NotImplementedError("forward method is not implemented")
def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
""" For getting encoding in a streaming fashion
Args:
offset (int): start offset
size (int): requried size of position encoding
Returns:
paddle.Tensor: Corresponding position encoding
"""
raise NotImplementedError("position_encoding method is not implemented")
class NoPositionalEncoding(nn.Layer, PositionalEncodingInterface):
def __init__(self,
d_model: int,
dropout_rate: float,
max_len: int=5000,
reverse: bool=False):
nn.Layer.__init__(self)
def forward(self, x: paddle.Tensor,
offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
return x, None
def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
return None
class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
def __init__(self,
d_model: int,
dropout_rate: float,
@ -40,7 +82,7 @@ class PositionalEncoding(nn.Layer):
max_len (int, optional): maximum input length. Defaults to 5000.
reverse (bool, optional): Not used. Defaults to False.
"""
super().__init__()
nn.Layer.__init__(self)
self.d_model = d_model
self.max_len = max_len
self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
@ -85,7 +127,7 @@ class PositionalEncoding(nn.Layer):
offset (int): start offset
size (int): requried size of position encoding
Returns:
paddle.Tensor: Corresponding encoding
paddle.Tensor: Corresponding position encoding
"""
assert offset + size < self.max_len
return self.dropout(self.pe[:, offset:offset + size])

@ -24,6 +24,7 @@ from deepspeech.modules.activation import get_activation
from deepspeech.modules.attention import MultiHeadedAttention
from deepspeech.modules.attention import RelPositionMultiHeadedAttention
from deepspeech.modules.conformer_convolution import ConvolutionModule
from deepspeech.modules.embedding import NoPositionalEncoding
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.embedding import RelPositionalEncoding
from deepspeech.modules.encoder_layer import ConformerEncoderLayer
@ -76,7 +77,7 @@ class BaseEncoder(nn.Layer):
input_layer (str): input layer type.
optional [linear, conv2d, conv2d6, conv2d8]
pos_enc_layer_type (str): Encoder positional encoding layer type.
opitonal [abs_pos, scaled_abs_pos, rel_pos]
opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
normalize_before (bool):
True: use layer_norm before each sub-block of a layer.
False: use layer_norm after each sub-block of a layer.
@ -101,6 +102,8 @@ class BaseEncoder(nn.Layer):
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "rel_pos":
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "no_pos":
pos_enc_class = NoPositionalEncoding
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
@ -370,6 +373,41 @@ class TransformerEncoder(BaseEncoder):
concat_after=concat_after) for _ in range(num_blocks)
])
def forward_one_step(
self,
xs: paddle.Tensor,
masks: paddle.Tensor,
cache=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Encode input frame.
Args:
xs (paddle.Tensor): (Prefix) Input tensor. (B, T, D)
masks (paddle.Tensor): Mask tensor. (B, T, T)
cache (List[paddle.Tensor]): List of cache tensors.
Returns:
paddle.Tensor: Output tensor.
paddle.Tensor: Mask tensor.
List[paddle.Tensor]: List of new cache tensors.
"""
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)
if cache is None:
cache = [None for _ in range(len(self.encoders))]
new_cache = []
for c, e in zip(cache, self.encoders):
xs, masks, _ = e(xs, masks, output_cache=c)
new_cache.append(xs)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks, new_cache
class ConformerEncoder(BaseEncoder):
"""Conformer encoder module."""

@ -71,7 +71,7 @@ class TransformerEncoderLayer(nn.Layer):
self,
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
pos_emb: Optional[paddle.Tensor]=None,
mask_pad: Optional[paddle.Tensor]=None,
output_cache: Optional[paddle.Tensor]=None,
cnn_cache: Optional[paddle.Tensor]=None,
@ -82,8 +82,8 @@ class TransformerEncoderLayer(nn.Layer):
mask (paddle.Tensor): Mask tensor for the input (#batch, time).
pos_emb (paddle.Tensor): just for interface compatibility
to ConformerEncoderLayer
mask_pad (paddle.Tensor): does not used in transformer layer,
just for unified api with conformer.
mask_pad (paddle.Tensor): not used here, it's for interface
compatibility to ConformerEncoderLayer
output_cache (paddle.Tensor): Cache tensor of the output
(#batch, time2, size), time2 < time in x.
cnn_cache (paddle.Tensor): not used here, it's for interface

@ -60,7 +60,8 @@ class LinearNoSubsampling(BaseSubsampling):
self.out = nn.Sequential(
nn.Linear(idim, odim),
nn.LayerNorm(odim, epsilon=1e-12),
nn.Dropout(dropout_rate), )
nn.Dropout(dropout_rate),
nn.ReLU(), )
self.right_context = 0
self.subsampling_rate = 1
@ -83,7 +84,12 @@ class LinearNoSubsampling(BaseSubsampling):
return x, pos_emb, x_mask
class Conv2dSubsampling4(BaseSubsampling):
class Conv2dSubsampling(BaseSubsampling):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class Conv2dSubsampling4(Conv2dSubsampling):
"""Convolutional 2D subsampling (to 1/4 length)."""
def __init__(self,
@ -134,7 +140,7 @@ class Conv2dSubsampling4(BaseSubsampling):
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
class Conv2dSubsampling6(BaseSubsampling):
class Conv2dSubsampling6(Conv2dSubsampling):
"""Convolutional 2D subsampling (to 1/6 length)."""
def __init__(self,
@ -187,7 +193,7 @@ class Conv2dSubsampling6(BaseSubsampling):
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]
class Conv2dSubsampling8(BaseSubsampling):
class Conv2dSubsampling8(Conv2dSubsampling):
"""Convolutional 2D subsampling (to 1/8 length)."""
def __init__(self,

@ -0,0 +1,418 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import numpy as np
from . import extension
class PlotAttentionReport(extension.Extension):
"""Plot attention reporter.
Args:
att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
Function of attention visualization.
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
outdir (str): Directory to save figures.
converter (espnet.asr.*_backend.asr.CustomConverter):
Function to convert data.
device (int | torch.device): Device.
reverse (bool): If True, input and output length are reversed.
ikey (str): Key to access input
(for ASR/ST ikey="input", for MT ikey="output".)
iaxis (int): Dimension to access input
(for ASR/ST iaxis=0, for MT iaxis=1.)
okey (str): Key to access output
(for ASR/ST okey="input", MT okay="output".)
oaxis (int): Dimension to access output
(for ASR/ST oaxis=0, for MT oaxis=0.)
subsampling_factor (int): subsampling factor in encoder
"""
def __init__(
self,
att_vis_fn,
data,
outdir,
converter,
transform,
device,
reverse=False,
ikey="input",
iaxis=0,
okey="output",
oaxis=0,
subsampling_factor=1, ):
self.att_vis_fn = att_vis_fn
self.data = copy.deepcopy(data)
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
# key is utterance ID
self.outdir = outdir
self.converter = converter
self.transform = transform
self.device = device
self.reverse = reverse
self.ikey = ikey
self.iaxis = iaxis
self.okey = okey
self.oaxis = oaxis
self.factor = subsampling_factor
if not os.path.exists(self.outdir):
os.makedirs(self.outdir)
def __call__(self, trainer):
"""Plot and save image file of att_ws matrix."""
att_ws, uttid_list = self.get_attention_weights()
if isinstance(att_ws, list): # multi-encoder case
num_encs = len(att_ws) - 1
# atts
for i in range(num_encs):
for idx, att_w in enumerate(att_ws[i]):
filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
self.outdir, uttid_list[idx], i + 1, )
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
self.outdir, uttid_list[idx], i + 1, )
np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention(att_w,
filename.format(trainer))
# han
for idx, att_w in enumerate(att_ws[num_encs]):
filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
self.outdir, uttid_list[idx], )
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
self.outdir, uttid_list[idx], )
np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention(
att_w, filename.format(trainer), han_mode=True)
else:
for idx, att_w in enumerate(att_ws):
filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
uttid_list[idx], )
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
self.outdir, uttid_list[idx], )
np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention(att_w, filename.format(trainer))
def log_attentions(self, logger, step):
"""Add image files of att_ws matrix to the tensorboard."""
att_ws, uttid_list = self.get_attention_weights()
if isinstance(att_ws, list): # multi-encoder case
num_encs = len(att_ws) - 1
# atts
for i in range(num_encs):
for idx, att_w in enumerate(att_ws[i]):
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
plot = self.draw_attention_plot(att_w)
logger.add_figure(
"%s_att%d" % (uttid_list[idx], i + 1),
plot.gcf(),
step, )
# han
for idx, att_w in enumerate(att_ws[num_encs]):
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
plot = self.draw_han_plot(att_w)
logger.add_figure(
"%s_han" % (uttid_list[idx]),
plot.gcf(),
step, )
else:
for idx, att_w in enumerate(att_ws):
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
plot = self.draw_attention_plot(att_w)
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
def get_attention_weights(self):
"""Return attention weights.
Returns:
numpy.ndarray: attention weights. float. Its shape would be
differ from backend.
* pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2)
other case => (B, Lmax, Tmax).
* chainer-> (B, Lmax, Tmax)
"""
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
batch = self.converter([return_batch], self.device)
if isinstance(batch, tuple):
att_ws = self.att_vis_fn(*batch)
else:
att_ws = self.att_vis_fn(**batch)
return att_ws, uttid_list
def trim_attention_weight(self, uttid, att_w):
"""Transform attention matrix with regard to self.reverse."""
if self.reverse:
enc_key, enc_axis = self.okey, self.oaxis
dec_key, dec_axis = self.ikey, self.iaxis
else:
enc_key, enc_axis = self.ikey, self.iaxis
dec_key, dec_axis = self.okey, self.oaxis
dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
if self.factor > 1:
enc_len //= self.factor
if len(att_w.shape) == 3:
att_w = att_w[:, :dec_len, :enc_len]
else:
att_w = att_w[:dec_len, :enc_len]
return att_w
def draw_attention_plot(self, att_w):
"""Plot the att_w matrix.
Returns:
matplotlib.pyplot: pyplot object with attention matrix image.
"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
plt.clf()
att_w = att_w.astype(np.float32)
if len(att_w.shape) == 3:
for h, aw in enumerate(att_w, 1):
plt.subplot(1, len(att_w), h)
plt.imshow(aw, aspect="auto")
plt.xlabel("Encoder Index")
plt.ylabel("Decoder Index")
else:
plt.imshow(att_w, aspect="auto")
plt.xlabel("Encoder Index")
plt.ylabel("Decoder Index")
plt.tight_layout()
return plt
def draw_han_plot(self, att_w):
"""Plot the att_w matrix for hierarchical attention.
Returns:
matplotlib.pyplot: pyplot object with attention matrix image.
"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
plt.clf()
if len(att_w.shape) == 3:
for h, aw in enumerate(att_w, 1):
legends = []
plt.subplot(1, len(att_w), h)
for i in range(aw.shape[1]):
plt.plot(aw[:, i])
legends.append("Att{}".format(i))
plt.ylim([0, 1.0])
plt.xlim([0, aw.shape[0]])
plt.grid(True)
plt.ylabel("Attention Weight")
plt.xlabel("Decoder Index")
plt.legend(legends)
else:
legends = []
for i in range(att_w.shape[1]):
plt.plot(att_w[:, i])
legends.append("Att{}".format(i))
plt.ylim([0, 1.0])
plt.xlim([0, att_w.shape[0]])
plt.grid(True)
plt.ylabel("Attention Weight")
plt.xlabel("Decoder Index")
plt.legend(legends)
plt.tight_layout()
return plt
def _plot_and_save_attention(self, att_w, filename, han_mode=False):
if han_mode:
plt = self.draw_han_plot(att_w)
else:
plt = self.draw_attention_plot(att_w)
plt.savefig(filename)
plt.close()
class PlotCTCReport(extension.Extension):
"""Plot CTC reporter.
Args:
ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
Function of CTC visualization.
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
outdir (str): Directory to save figures.
converter (espnet.asr.*_backend.asr.CustomConverter):
Function to convert data.
device (int | torch.device): Device.
reverse (bool): If True, input and output length are reversed.
ikey (str): Key to access input
(for ASR/ST ikey="input", for MT ikey="output".)
iaxis (int): Dimension to access input
(for ASR/ST iaxis=0, for MT iaxis=1.)
okey (str): Key to access output
(for ASR/ST okey="input", MT okay="output".)
oaxis (int): Dimension to access output
(for ASR/ST oaxis=0, for MT oaxis=0.)
subsampling_factor (int): subsampling factor in encoder
"""
def __init__(
self,
ctc_vis_fn,
data,
outdir,
converter,
transform,
device,
reverse=False,
ikey="input",
iaxis=0,
okey="output",
oaxis=0,
subsampling_factor=1, ):
self.ctc_vis_fn = ctc_vis_fn
self.data = copy.deepcopy(data)
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
# key is utterance ID
self.outdir = outdir
self.converter = converter
self.transform = transform
self.device = device
self.reverse = reverse
self.ikey = ikey
self.iaxis = iaxis
self.okey = okey
self.oaxis = oaxis
self.factor = subsampling_factor
if not os.path.exists(self.outdir):
os.makedirs(self.outdir)
def __call__(self, trainer):
"""Plot and save image file of ctc prob."""
ctc_probs, uttid_list = self.get_ctc_probs()
if isinstance(ctc_probs, list): # multi-encoder case
num_encs = len(ctc_probs) - 1
for i in range(num_encs):
for idx, ctc_prob in enumerate(ctc_probs[i]):
filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
self.outdir, uttid_list[idx], i + 1, )
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
self.outdir, uttid_list[idx], i + 1, )
np.save(np_filename.format(trainer), ctc_prob)
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
else:
for idx, ctc_prob in enumerate(ctc_probs):
filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
uttid_list[idx], )
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
self.outdir, uttid_list[idx], )
np.save(np_filename.format(trainer), ctc_prob)
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
def log_ctc_probs(self, logger, step):
"""Add image files of ctc probs to the tensorboard."""
ctc_probs, uttid_list = self.get_ctc_probs()
if isinstance(ctc_probs, list): # multi-encoder case
num_encs = len(ctc_probs) - 1
for i in range(num_encs):
for idx, ctc_prob in enumerate(ctc_probs[i]):
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
plot = self.draw_ctc_plot(ctc_prob)
logger.add_figure(
"%s_ctc%d" % (uttid_list[idx], i + 1),
plot.gcf(),
step, )
else:
for idx, ctc_prob in enumerate(ctc_probs):
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
plot = self.draw_ctc_plot(ctc_prob)
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
def get_ctc_probs(self):
"""Return CTC probs.
Returns:
numpy.ndarray: CTC probs. float. Its shape would be
differ from backend. (B, Tmax, vocab).
"""
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
batch = self.converter([return_batch], self.device)
if isinstance(batch, tuple):
probs = self.ctc_vis_fn(*batch)
else:
probs = self.ctc_vis_fn(**batch)
return probs, uttid_list
def trim_ctc_prob(self, uttid, prob):
"""Trim CTC posteriors accoding to input lengths."""
enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
if self.factor > 1:
enc_len //= self.factor
prob = prob[:enc_len]
return prob
def draw_ctc_plot(self, ctc_prob):
"""Plot the ctc_prob matrix.
Returns:
matplotlib.pyplot: pyplot object with CTC prob matrix image.
"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
ctc_prob = ctc_prob.astype(np.float32)
plt.clf()
topk_ids = np.argsort(ctc_prob, axis=1)
n_frames, vocab = ctc_prob.shape
times_probs = np.arange(n_frames)
plt.figure(figsize=(20, 8))
# NOTE: index 0 is reserved for blank
for idx in set(topk_ids.reshape(-1).tolist()):
if idx == 0:
plt.plot(
times_probs,
ctc_prob[:, 0],
":",
label="<blank>",
color="grey")
else:
plt.plot(times_probs, ctc_prob[:, idx])
plt.xlabel(u"Input [frame]", fontsize=12)
plt.ylabel("Posteriors", fontsize=12)
plt.xticks(list(range(0, int(n_frames) + 1, 10)))
plt.yticks(list(range(0, 2, 1)))
plt.tight_layout()
return plt
def _plot_and_save_ctc(self, ctc_prob, filename):
plt = self.draw_ctc_plot(ctc_prob)
plt.savefig(filename)
plt.close()

@ -11,18 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interval_trigger import IntervalTrigger
def never_fail_trigger(trainer):
return False
def get_trigger(trigger):
if trigger is None:
return never_fail_trigger
if callable(trigger):
return trigger
else:
trigger = IntervalTrigger(*trigger)
return trigger

@ -0,0 +1,61 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..reporter import DictSummary
from .utils import get_trigger
class CompareValueTrigger():
"""Trigger invoked when key value getting bigger or lower than before.
Args:
key (str) : Key of value.
compare_fn ((float, float) -> bool) : Function to compare the values.
trigger (tuple(int, str)) : Trigger that decide the comparison interval.
"""
def __init__(self, key, compare_fn, trigger=(1, "epoch")):
self._key = key
self._best_value = None
self._interval_trigger = get_trigger(trigger)
self._init_summary()
self._compare_fn = compare_fn
def __call__(self, trainer):
"""Get value related to the key and compare with current value."""
observation = trainer.observation
summary = self._summary
key = self._key
if key in observation:
summary.add({key: observation[key]})
if not self._interval_trigger(trainer):
return False
stats = summary.compute_mean()
value = float(stats[key]) # copy to CPU
self._init_summary()
if self._best_value is None:
# initialize best value
self._best_value = value
return False
elif self._compare_fn(self._best_value, value):
return True
else:
self._best_value = value
return False
def _init_summary(self):
self._summary = DictSummary()

@ -30,3 +30,12 @@ class TimeTrigger():
return True
else:
return False
def state_dict(self):
state_dict = {
"next_time": self._next_time,
}
return state_dict
def set_state_dict(self, state_dict):
self._next_time = state_dict['next_time']

@ -0,0 +1,28 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interval_trigger import IntervalTrigger
def never_fail_trigger(trainer):
return False
def get_trigger(trigger):
if trigger is None:
return never_fail_trigger
if callable(trigger):
return trigger
else:
trigger = IntervalTrigger(*trigger)
return trigger

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,41 @@
import numpy as np
def delta(feat, window):
assert window > 0
delta_feat = np.zeros_like(feat)
for i in range(1, window + 1):
delta_feat[:-i] += i * feat[i:]
delta_feat[i:] += -i * feat[:-i]
delta_feat[-i:] += i * feat[-1]
delta_feat[:i] += -i * feat[0]
delta_feat /= 2 * sum(i ** 2 for i in range(1, window + 1))
return delta_feat
def add_deltas(x, window=2, order=2):
"""
Args:
x (np.ndarray): speech feat, (T, D).
Return:
np.ndarray: (T, (1+order)*D)
"""
feats = [x]
for _ in range(order):
feats.append(delta(feats[-1], window))
return np.concatenate(feats, axis=1)
class AddDeltas():
def __init__(self, window=2, order=2):
self.window = window
self.order = order
def __repr__(self):
return "{name}(window={window}, order={order}".format(
name=self.__class__.__name__, window=self.window, order=self.order
)
def __call__(self, x):
return add_deltas(x, window=self.window, order=self.order)

@ -0,0 +1,45 @@
import numpy
class ChannelSelector():
"""Select 1ch from multi-channel signal"""
def __init__(self, train_channel="random", eval_channel=0, axis=1):
self.train_channel = train_channel
self.eval_channel = eval_channel
self.axis = axis
def __repr__(self):
return (
"{name}(train_channel={train_channel}, "
"eval_channel={eval_channel}, axis={axis})".format(
name=self.__class__.__name__,
train_channel=self.train_channel,
eval_channel=self.eval_channel,
axis=self.axis,
)
)
def __call__(self, x, train=True):
# Assuming x: [Time, Channel] by default
if x.ndim <= self.axis:
# If the dimension is insufficient, then unsqueeze
# (e.g [Time] -> [Time, 1])
ind = tuple(
slice(None) if i < x.ndim else None for i in range(self.axis + 1)
)
x = x[ind]
if train:
channel = self.train_channel
else:
channel = self.eval_channel
if channel == "random":
ch = numpy.random.randint(0, x.shape[self.axis])
else:
ch = channel
ind = tuple(slice(None) if i != self.axis else ch for i in range(x.ndim))
return x[ind]

@ -0,0 +1,158 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import h5py
import kaldiio
import numpy as np
class CMVN():
"Apply Global/Spk CMVN/iverserCMVN."
def __init__(
self,
stats,
norm_means=True,
norm_vars=False,
filetype="mat",
utt2spk=None,
spk2utt=None,
reverse=False,
std_floor=1.0e-20, ):
self.stats_file = stats
self.norm_means = norm_means
self.norm_vars = norm_vars
self.reverse = reverse
if isinstance(stats, dict):
stats_dict = dict(stats)
else:
# Use for global CMVN
if filetype == "mat":
stats_dict = {None: kaldiio.load_mat(stats)}
# Use for global CMVN
elif filetype == "npy":
stats_dict = {None: np.load(stats)}
# Use for speaker CMVN
elif filetype == "ark":
self.accept_uttid = True
stats_dict = dict(kaldiio.load_ark(stats))
# Use for speaker CMVN
elif filetype == "hdf5":
self.accept_uttid = True
stats_dict = h5py.File(stats)
else:
raise ValueError("Not supporting filetype={}".format(filetype))
if utt2spk is not None:
self.utt2spk = {}
with io.open(utt2spk, "r", encoding="utf-8") as f:
for line in f:
utt, spk = line.rstrip().split(None, 1)
self.utt2spk[utt] = spk
elif spk2utt is not None:
self.utt2spk = {}
with io.open(spk2utt, "r", encoding="utf-8") as f:
for line in f:
spk, utts = line.rstrip().split(None, 1)
for utt in utts.split():
self.utt2spk[utt] = spk
else:
self.utt2spk = None
# Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1),
# and the first vector contains the sum of feats and the second is
# the sum of squares. The last value of the first, i.e. stats[0,-1],
# is the number of samples for this statistics.
self.bias = {}
self.scale = {}
for spk, stats in stats_dict.items():
assert len(stats) == 2, stats.shape
count = stats[0, -1]
# If the feature has two or more dimensions
if not (np.isscalar(count) or isinstance(count, (int, float))):
# The first is only used
count = count.flatten()[0]
mean = stats[0, :-1] / count
# V(x) = E(x^2) - (E(x))^2
var = stats[1, :-1] / count - mean * mean
std = np.maximum(np.sqrt(var), std_floor)
self.bias[spk] = -mean
self.scale[spk] = 1 / std
def __repr__(self):
return ("{name}(stats_file={stats_file}, "
"norm_means={norm_means}, norm_vars={norm_vars}, "
"reverse={reverse})".format(
name=self.__class__.__name__,
stats_file=self.stats_file,
norm_means=self.norm_means,
norm_vars=self.norm_vars,
reverse=self.reverse, ))
def __call__(self, x, uttid=None):
if self.utt2spk is not None:
spk = self.utt2spk[uttid]
else:
spk = uttid
if not self.reverse:
# apply cmvn
if self.norm_means:
x = np.add(x, self.bias[spk])
if self.norm_vars:
x = np.multiply(x, self.scale[spk])
else:
# apply reverse cmvn
if self.norm_vars:
x = np.divide(x, self.scale[spk])
if self.norm_means:
x = np.subtract(x, self.bias[spk])
return x
class UtteranceCMVN():
"Apply Utterance CMVN"
def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20):
self.norm_means = norm_means
self.norm_vars = norm_vars
self.std_floor = std_floor
def __repr__(self):
return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format(
name=self.__class__.__name__,
norm_means=self.norm_means,
norm_vars=self.norm_vars, )
def __call__(self, x, uttid=None):
# x: [Time, Dim]
square_sums = (x**2).sum(axis=0)
mean = x.mean(axis=0)
if self.norm_means:
x = np.subtract(x, mean)
if self.norm_vars:
var = square_sums / x.shape[0] - mean**2
std = np.maximum(np.sqrt(var), self.std_floor)
x = np.divide(x, std)
return x

@ -0,0 +1,71 @@
import inspect
from deepspeech.transform.transform_interface import TransformInterface
from deepspeech.utils.check_kwargs import check_kwargs
class FuncTrans(TransformInterface):
"""Functional Transformation
WARNING:
Builtin or C/C++ functions may not work properly
because this class heavily depends on the `inspect` module.
Usage:
>>> def foo_bar(x, a=1, b=2):
... '''Foo bar
... :param x: input
... :param int a: default 1
... :param int b: default 2
... '''
... return x + a - b
>>> class FooBar(FuncTrans):
... _func = foo_bar
... __doc__ = foo_bar.__doc__
"""
_func = None
def __init__(self, **kwargs):
self.kwargs = kwargs
check_kwargs(self.func, kwargs)
def __call__(self, x):
return self.func(x, **self.kwargs)
@classmethod
def add_arguments(cls, parser):
fname = cls._func.__name__.replace("_", "-")
group = parser.add_argument_group(fname + " transformation setting")
for k, v in cls.default_params().items():
# TODO(karita): get help and choices from docstring?
attr = k.replace("_", "-")
group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
return parser
@property
def func(self):
return type(self)._func
@classmethod
def default_params(cls):
try:
d = dict(inspect.signature(cls._func).parameters)
except ValueError:
d = dict()
return {
k: v.default for k, v in d.items() if v.default != inspect.Parameter.empty
}
def __repr__(self):
params = self.default_params()
params.update(**self.kwargs)
ret = self.__class__.__name__ + "("
if len(params) == 0:
return ret + ")"
for k, v in params.items():
ret += "{}={}, ".format(k, v)
return ret[:-2] + ")"

@ -0,0 +1,343 @@
import librosa
import numpy
import scipy
import soundfile
from deepspeech.io.reader import SoundHDF5File
class SpeedPerturbation():
"""SpeedPerturbation
The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
and sox-speed just to resample the input,
i.e pitch and tempo are changed both.
"Why use speed option instead of tempo -s in SoX for speed perturbation"
https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8
Warning:
This function is very slow because of resampling.
I recommmend to apply speed-perturb outside the training using sox.
"""
def __init__(
self,
lower=0.9,
upper=1.1,
utt2ratio=None,
keep_length=True,
res_type="kaiser_best",
seed=None,
):
self.res_type = res_type
self.keep_length = keep_length
self.state = numpy.random.RandomState(seed)
if utt2ratio is not None:
self.utt2ratio = {}
# Use the scheduled ratio for each utterances
self.utt2ratio_file = utt2ratio
self.lower = None
self.upper = None
self.accept_uttid = True
with open(utt2ratio, "r") as f:
for line in f:
utt, ratio = line.rstrip().split(None, 1)
ratio = float(ratio)
self.utt2ratio[utt] = ratio
else:
self.utt2ratio = None
# The ratio is given on runtime randomly
self.lower = lower
self.upper = upper
def __repr__(self):
if self.utt2ratio is None:
return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format(
self.__class__.__name__,
self.lower,
self.upper,
self.keep_length,
self.res_type,
)
else:
return "{}({}, res_type={})".format(
self.__class__.__name__, self.utt2ratio_file, self.res_type
)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
if self.accept_uttid:
ratio = self.utt2ratio[uttid]
else:
ratio = self.state.uniform(self.lower, self.upper)
# Note1: resample requires the sampling-rate of input and output,
# but actually only the ratio is used.
y = librosa.resample(x, ratio, 1, res_type=self.res_type)
if self.keep_length:
diff = abs(len(x) - len(y))
if len(y) > len(x):
# Truncate noise
y = y[diff // 2 : -((diff + 1) // 2)]
elif len(y) < len(x):
# Assume the time-axis is the first: (Time, Channel)
pad_width = [(diff // 2, (diff + 1) // 2)] + [
(0, 0) for _ in range(y.ndim - 1)
]
y = numpy.pad(
y, pad_width=pad_width, constant_values=0, mode="constant"
)
return y
class BandpassPerturbation():
"""BandpassPerturbation
Randomly dropout along the frequency axis.
The original idea comes from the following:
"randomly-selected frequency band was cut off under the constraint of
leaving at least 1,000 Hz band within the range of less than 4,000Hz."
(The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
everyday home environments using multiple microphone arrays;
http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)
"""
def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1,)):
self.lower = lower
self.upper = upper
self.state = numpy.random.RandomState(seed)
# x_stft: (Time, Channel, Freq)
self.axes = axes
def __repr__(self):
return "{}(lower={}, upper={})".format(
self.__class__.__name__, self.lower, self.upper
)
def __call__(self, x_stft, uttid=None, train=True):
if not train:
return x_stft
if x_stft.ndim == 1:
raise RuntimeError(
"Input in time-freq domain: " "(Time, Channel, Freq) or (Time, Freq)"
)
ratio = self.state.uniform(self.lower, self.upper)
axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes]
shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]
mask = self.state.randn(*shape) > ratio
x_stft *= mask
return x_stft
class VolumePerturbation():
def __init__(self, lower=-1.6, upper=1.6, utt2ratio=None, dbunit=True, seed=None):
self.dbunit = dbunit
self.utt2ratio_file = utt2ratio
self.lower = lower
self.upper = upper
self.state = numpy.random.RandomState(seed)
if utt2ratio is not None:
# Use the scheduled ratio for each utterances
self.utt2ratio = {}
self.lower = None
self.upper = None
self.accept_uttid = True
with open(utt2ratio, "r") as f:
for line in f:
utt, ratio = line.rstrip().split(None, 1)
ratio = float(ratio)
self.utt2ratio[utt] = ratio
else:
# The ratio is given on runtime randomly
self.utt2ratio = None
def __repr__(self):
if self.utt2ratio is None:
return "{}(lower={}, upper={}, dbunit={})".format(
self.__class__.__name__, self.lower, self.upper, self.dbunit
)
else:
return '{}("{}", dbunit={})'.format(
self.__class__.__name__, self.utt2ratio_file, self.dbunit
)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
if self.accept_uttid:
ratio = self.utt2ratio[uttid]
else:
ratio = self.state.uniform(self.lower, self.upper)
if self.dbunit:
ratio = 10 ** (ratio / 20)
return x * ratio
class NoiseInjection():
"""Add isotropic noise"""
def __init__(
self,
utt2noise=None,
lower=-20,
upper=-5,
utt2ratio=None,
filetype="list",
dbunit=True,
seed=None,
):
self.utt2noise_file = utt2noise
self.utt2ratio_file = utt2ratio
self.filetype = filetype
self.dbunit = dbunit
self.lower = lower
self.upper = upper
self.state = numpy.random.RandomState(seed)
if utt2ratio is not None:
# Use the scheduled ratio for each utterances
self.utt2ratio = {}
with open(utt2noise, "r") as f:
for line in f:
utt, snr = line.rstrip().split(None, 1)
snr = float(snr)
self.utt2ratio[utt] = snr
else:
# The ratio is given on runtime randomly
self.utt2ratio = None
if utt2noise is not None:
self.utt2noise = {}
if filetype == "list":
with open(utt2noise, "r") as f:
for line in f:
utt, filename = line.rstrip().split(None, 1)
signal, rate = soundfile.read(filename, dtype="int16")
# Load all files in memory
self.utt2noise[utt] = (signal, rate)
elif filetype == "sound.hdf5":
self.utt2noise = SoundHDF5File(utt2noise, "r")
else:
raise ValueError(filetype)
else:
self.utt2noise = None
if utt2noise is not None and utt2ratio is not None:
if set(self.utt2ratio) != set(self.utt2noise):
raise RuntimeError(
"The uttids mismatch between {} and {}".format(utt2ratio, utt2noise)
)
def __repr__(self):
if self.utt2ratio is None:
return "{}(lower={}, upper={}, dbunit={})".format(
self.__class__.__name__, self.lower, self.upper, self.dbunit
)
else:
return '{}("{}", dbunit={})'.format(
self.__class__.__name__, self.utt2ratio_file, self.dbunit
)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
# 1. Get ratio of noise to signal in sound pressure level
if uttid is not None and self.utt2ratio is not None:
ratio = self.utt2ratio[uttid]
else:
ratio = self.state.uniform(self.lower, self.upper)
if self.dbunit:
ratio = 10 ** (ratio / 20)
scale = ratio * numpy.sqrt((x ** 2).mean())
# 2. Get noise
if self.utt2noise is not None:
# Get noise from the external source
if uttid is not None:
noise, rate = self.utt2noise[uttid]
else:
# Randomly select the noise source
noise = self.state.choice(list(self.utt2noise.values()))
# Normalize the level
noise /= numpy.sqrt((noise ** 2).mean())
# Adjust the noise length
diff = abs(len(x) - len(noise))
offset = self.state.randint(0, diff)
if len(noise) > len(x):
# Truncate noise
noise = noise[offset : -(diff - offset)]
else:
noise = numpy.pad(noise, pad_width=[offset, diff - offset], mode="wrap")
else:
# Generate white noise
noise = self.state.normal(0, 1, x.shape)
# 3. Add noise to signal
return x + noise * scale
class RIRConvolve():
def __init__(self, utt2rir, filetype="list"):
self.utt2rir_file = utt2rir
self.filetype = filetype
self.utt2rir = {}
if filetype == "list":
with open(utt2rir, "r") as f:
for line in f:
utt, filename = line.rstrip().split(None, 1)
signal, rate = soundfile.read(filename, dtype="int16")
self.utt2rir[utt] = (signal, rate)
elif filetype == "sound.hdf5":
self.utt2rir = SoundHDF5File(utt2rir, "r")
else:
raise NotImplementedError(filetype)
def __repr__(self):
return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
if x.ndim != 1:
# Must be single channel
raise RuntimeError(
"Input x must be one dimensional array, but got {}".format(x.shape)
)
rir, rate = self.utt2rir[uttid]
if rir.ndim == 2:
# FIXME(kamo): Use chainer.convolution_1d?
# return [Time, Channel]
return numpy.stack(
[scipy.convolve(x, r, mode="same") for r in rir], axis=-1
)
else:
return scipy.convolve(x, rir, mode="same")

@ -0,0 +1,202 @@
"""Spec Augment module for preprocessing i.e., data augmentation"""
import random
import numpy
from PIL import Image
from PIL.Image import BICUBIC
from deepspeech.transform.functional import FuncTrans
def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
"""time warp for spec augment
move random center frame by the random width ~ uniform(-window, window)
:param numpy.ndarray x: spectrogram (time, freq)
:param int max_time_warp: maximum time frames to warp
:param bool inplace: overwrite x with the result
:param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp"
(slow, differentiable)
:returns numpy.ndarray: time warped spectrogram (time, freq)
"""
window = max_time_warp
if mode == "PIL":
t = x.shape[0]
if t - window <= window:
return x
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
center = random.randrange(window, t - window)
warped = random.randrange(center - window, center + window) + 1 # 1 ... t - 1
left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), BICUBIC)
if inplace:
x[:warped] = left
x[warped:] = right
return x
return numpy.concatenate((left, right), 0)
elif mode == "sparse_image_warp":
import paddle
from espnet.utils import spec_augment
# TODO(karita): make this differentiable again
return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
else:
raise NotImplementedError(
"unknown resize mode: "
+ mode
+ ", choose one from (PIL, sparse_image_warp)."
)
class TimeWarp(FuncTrans):
_func = time_warp
__doc__ = time_warp.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)
def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
"""freq mask for spec agument
:param numpy.ndarray x: (time, freq)
:param int n_mask: the number of masks
:param bool inplace: overwrite
:param bool replace_with_zero: pad zero on mask if true else use mean
"""
if inplace:
cloned = x
else:
cloned = x.copy()
num_mel_channels = cloned.shape[1]
fs = numpy.random.randint(0, F, size=(n_mask, 2))
for f, mask_end in fs:
f_zero = random.randrange(0, num_mel_channels - f)
mask_end += f_zero
# avoids randrange error if values are equal and range is empty
if f_zero == f_zero + f:
continue
if replace_with_zero:
cloned[:, f_zero:mask_end] = 0
else:
cloned[:, f_zero:mask_end] = cloned.mean()
return cloned
class FreqMask(FuncTrans):
_func = freq_mask
__doc__ = freq_mask.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)
def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
"""freq mask for spec agument
:param numpy.ndarray spec: (time, freq)
:param int n_mask: the number of masks
:param bool inplace: overwrite
:param bool replace_with_zero: pad zero on mask if true else use mean
"""
if inplace:
cloned = spec
else:
cloned = spec.copy()
len_spectro = cloned.shape[0]
ts = numpy.random.randint(0, T, size=(n_mask, 2))
for t, mask_end in ts:
# avoid randint range error
if len_spectro - t <= 0:
continue
t_zero = random.randrange(0, len_spectro - t)
# avoids randrange error if values are equal and range is empty
if t_zero == t_zero + t:
continue
mask_end += t_zero
if replace_with_zero:
cloned[t_zero:mask_end] = 0
else:
cloned[t_zero:mask_end] = cloned.mean()
return cloned
class TimeMask(FuncTrans):
_func = time_mask
__doc__ = time_mask.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)
def spec_augment(
x,
resize_mode="PIL",
max_time_warp=80,
max_freq_width=27,
n_freq_mask=2,
max_time_width=100,
n_time_mask=2,
inplace=True,
replace_with_zero=True,
):
"""spec agument
apply random time warping and time/freq masking
default setting is based on LD (Librispeech double) in Table 2
https://arxiv.org/pdf/1904.08779.pdf
:param numpy.ndarray x: (time, freq)
:param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp"
(slow, differentiable)
:param int max_time_warp: maximum frames to warp the center frame in spectrogram (W)
:param int freq_mask_width: maximum width of the random freq mask (F)
:param int n_freq_mask: the number of the random freq mask (m_F)
:param int time_mask_width: maximum width of the random time mask (T)
:param int n_time_mask: the number of the random time mask (m_T)
:param bool inplace: overwrite intermediate array
:param bool replace_with_zero: pad zero on mask if true else use mean
"""
assert isinstance(x, numpy.ndarray)
assert x.ndim == 2
x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
x = freq_mask(
x,
max_freq_width,
n_freq_mask,
inplace=inplace,
replace_with_zero=replace_with_zero,
)
x = time_mask(
x,
max_time_width,
n_time_mask,
inplace=inplace,
replace_with_zero=replace_with_zero,
)
return x
class SpecAugment(FuncTrans):
_func = spec_augment
__doc__ = spec_augment.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)

@ -0,0 +1,307 @@
import librosa
import numpy as np
def stft(
x, n_fft, n_shift, win_length=None, window="hann", center=True, pad_mode="reflect"
):
# x: [Time, Channel]
if x.ndim == 1:
single_channel = True
# x: [Time] -> [Time, Channel]
x = x[:, None]
else:
single_channel = False
x = x.astype(np.float32)
# FIXME(kamo): librosa.stft can't use multi-channel?
# x: [Time, Channel, Freq]
x = np.stack(
[
librosa.stft(
x[:, ch],
n_fft=n_fft,
hop_length=n_shift,
win_length=win_length,
window=window,
center=center,
pad_mode=pad_mode,
).T
for ch in range(x.shape[1])
],
axis=1,
)
if single_channel:
# x: [Time, Channel, Freq] -> [Time, Freq]
x = x[:, 0]
return x
def istft(x, n_shift, win_length=None, window="hann", center=True):
# x: [Time, Channel, Freq]
if x.ndim == 2:
single_channel = True
# x: [Time, Freq] -> [Time, Channel, Freq]
x = x[:, None, :]
else:
single_channel = False
# x: [Time, Channel]
x = np.stack(
[
librosa.istft(
x[:, ch].T, # [Time, Freq] -> [Freq, Time]
hop_length=n_shift,
win_length=win_length,
window=window,
center=center,
)
for ch in range(x.shape[1])
],
axis=1,
)
if single_channel:
# x: [Time, Channel] -> [Time]
x = x[:, 0]
return x
def stft2logmelspectrogram(x_stft, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
# x_stft: (Time, Channel, Freq) or (Time, Freq)
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
# spc: (Time, Channel, Freq) or (Time, Freq)
spc = np.abs(x_stft)
# mel_basis: (Mel_freq, Freq)
mel_basis = librosa.filters.mel(fs, n_fft, n_mels, fmin, fmax)
# lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq)
lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
return lmspc
def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"):
# x: (Time, Channel) -> spc: (Time, Channel, Freq)
spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window))
return spc
def logmelspectrogram(
x,
fs,
n_mels,
n_fft,
n_shift,
win_length=None,
window="hann",
fmin=None,
fmax=None,
eps=1e-10,
pad_mode="reflect",
):
# stft: (Time, Channel, Freq) or (Time, Freq)
x_stft = stft(
x,
n_fft=n_fft,
n_shift=n_shift,
win_length=win_length,
window=window,
pad_mode=pad_mode,
)
return stft2logmelspectrogram(
x_stft, fs=fs, n_mels=n_mels, n_fft=n_fft, fmin=fmin, fmax=fmax, eps=eps
)
class Spectrogram():
def __init__(self, n_fft, n_shift, win_length=None, window="hann"):
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
def __repr__(self):
return (
"{name}(n_fft={n_fft}, n_shift={n_shift}, "
"win_length={win_length}, window={window})".format(
name=self.__class__.__name__,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
)
)
def __call__(self, x):
return spectrogram(
x,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
)
class LogMelSpectrogram():
def __init__(
self,
fs,
n_mels,
n_fft,
n_shift,
win_length=None,
window="hann",
fmin=None,
fmax=None,
eps=1e-10,
):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.fmin = fmin
self.fmax = fmax
self.eps = eps
def __repr__(self):
return (
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"n_shift={n_shift}, win_length={win_length}, window={window}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps,
)
)
def __call__(self, x):
return logmelspectrogram(
x,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
)
class Stft2LogMelSpectrogram():
def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
self.fmin = fmin
self.fmax = fmax
self.eps = eps
def __repr__(self):
return (
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps,
)
)
def __call__(self, x):
return stft2logmelspectrogram(
x,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
fmin=self.fmin,
fmax=self.fmax,
)
class Stft():
def __init__(
self,
n_fft,
n_shift,
win_length=None,
window="hann",
center=True,
pad_mode="reflect",
):
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.center = center
self.pad_mode = pad_mode
def __repr__(self):
return (
"{name}(n_fft={n_fft}, n_shift={n_shift}, "
"win_length={win_length}, window={window},"
"center={center}, pad_mode={pad_mode})".format(
name=self.__class__.__name__,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode,
)
)
def __call__(self, x):
return stft(
x,
self.n_fft,
self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode,
)
class IStft():
def __init__(self, n_shift, win_length=None, window="hann", center=True):
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.center = center
def __repr__(self):
return (
"{name}(n_shift={n_shift}, "
"win_length={win_length}, window={window},"
"center={center})".format(
name=self.__class__.__name__,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
)
)
def __call__(self, x):
return istft(
x,
self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
)

@ -0,0 +1,20 @@
# TODO(karita): add this to all the transform impl.
class TransformInterface:
"""Transform Interface"""
def __call__(self, x):
raise NotImplementedError("__call__ method is not implemented")
@classmethod
def add_arguments(cls, parser):
return parser
def __repr__(self):
return self.__class__.__name__ + "()"
class Identity(TransformInterface):
"""Identity Function"""
def __call__(self, x):
return x

@ -0,0 +1,149 @@
"""Transformation module."""
from collections.abc import Sequence
from collections import OrderedDict
import copy
from inspect import signature
import io
import logging
import yaml
from deepspeech.utils.dynamic_import import dynamic_import
# TODO(karita): inherit TransformInterface
# TODO(karita): register cmd arguments in asr_train.py
import_alias = dict(
identity="deepspeech.transform.transform_interface:Identity",
time_warp="deepspeech.transform.spec_augment:TimeWarp",
time_mask="deepspeech.transform.spec_augment:TimeMask",
freq_mask="deepspeech.transform.spec_augment:FreqMask",
spec_augment="deepspeech.transform.spec_augment:SpecAugment",
speed_perturbation="deepspeech.transform.perturb:SpeedPerturbation",
volume_perturbation="deepspeech.transform.perturb:VolumePerturbation",
noise_injection="deepspeech.transform.perturb:NoiseInjection",
bandpass_perturbation="deepspeech.transform.perturb:BandpassPerturbation",
rir_convolve="deepspeech.transform.perturb:RIRConvolve",
delta="deepspeech.transform.add_deltas:AddDeltas",
cmvn="deepspeech.transform.cmvn:CMVN",
utterance_cmvn="deepspeech.transform.cmvn:UtteranceCMVN",
fbank="deepspeech.transform.spectrogram:LogMelSpectrogram",
spectrogram="deepspeech.transform.spectrogram:Spectrogram",
stft="deepspeech.transform.spectrogram:Stft",
istft="deepspeech.transform.spectrogram:IStft",
stft2fbank="deepspeech.transform.spectrogram:Stft2LogMelSpectrogram",
wpe="deepspeech.transform.wpe:WPE",
channel_selector="deepspeech.transform.channel_selector:ChannelSelector",
)
class Transformation():
"""Apply some functions to the mini-batch
Examples:
>>> kwargs = {"process": [{"type": "fbank",
... "n_mels": 80,
... "fs": 16000},
... {"type": "cmvn",
... "stats": "data/train/cmvn.ark",
... "norm_vars": True},
... {"type": "delta", "window": 2, "order": 2}]}
>>> transform = Transformation(kwargs)
>>> bs = 10
>>> xs = [np.random.randn(100, 80).astype(np.float32)
... for _ in range(bs)]
>>> xs = transform(xs)
"""
def __init__(self, conffile=None):
if conffile is not None:
if isinstance(conffile, dict):
self.conf = copy.deepcopy(conffile)
else:
with io.open(conffile, encoding="utf-8") as f:
self.conf = yaml.safe_load(f)
assert isinstance(self.conf, dict), type(self.conf)
else:
self.conf = {"mode": "sequential", "process": []}
self.functions = OrderedDict()
if self.conf.get("mode", "sequential") == "sequential":
for idx, process in enumerate(self.conf["process"]):
assert isinstance(process, dict), type(process)
opts = dict(process)
process_type = opts.pop("type")
class_obj = dynamic_import(process_type, import_alias)
# TODO(karita): assert issubclass(class_obj, TransformInterface)
try:
self.functions[idx] = class_obj(**opts)
except TypeError:
try:
signa = signature(class_obj)
except ValueError:
# Some function, e.g. built-in function, are failed
pass
else:
logging.error(
"Expected signature: {}({})".format(
class_obj.__name__, signa
)
)
raise
else:
raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"])
)
def __repr__(self):
rep = "\n" + "\n".join(
" {}: {}".format(k, v) for k, v in self.functions.items()
)
return "{}({})".format(self.__class__.__name__, rep)
def __call__(self, xs, uttid_list=None, **kwargs):
"""Return new mini-batch
:param Union[Sequence[np.ndarray], np.ndarray] xs:
:param Union[Sequence[str], str] uttid_list:
:return: batch:
:rtype: List[np.ndarray]
"""
if not isinstance(xs, Sequence):
is_batch = False
xs = [xs]
else:
is_batch = True
if isinstance(uttid_list, str):
uttid_list = [uttid_list for _ in range(len(xs))]
if self.conf.get("mode", "sequential") == "sequential":
for idx in range(len(self.conf["process"])):
func = self.functions[idx]
# TODO(karita): use TrainingTrans and UttTrans to check __call__ args
# Derive only the args which the func has
try:
param = signature(func).parameters
except ValueError:
# Some function, e.g. built-in function, are failed
param = {}
_kwargs = {k: v for k, v in kwargs.items() if k in param}
try:
if uttid_list is not None and "uttid" in param:
xs = [func(x, u, **_kwargs) for x, u in zip(xs, uttid_list)]
else:
xs = [func(x, **_kwargs) for x in xs]
except Exception:
logging.fatal(
"Catch a exception from {}th func: {}".format(idx, func)
)
raise
else:
raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"])
)
if is_batch:
return xs
else:
return xs[0]

@ -0,0 +1,45 @@
from nara_wpe.wpe import wpe
class WPE(object):
def __init__(
self, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode="full"
):
self.taps = taps
self.delay = delay
self.iterations = iterations
self.psd_context = psd_context
self.statistics_mode = statistics_mode
def __repr__(self):
return (
"{name}(taps={taps}, delay={delay}"
"iterations={iterations}, psd_context={psd_context}, "
"statistics_mode={statistics_mode})".format(
name=self.__class__.__name__,
taps=self.taps,
delay=self.delay,
iterations=self.iterations,
psd_context=self.psd_context,
statistics_mode=self.statistics_mode,
)
)
def __call__(self, xs):
"""Return enhanced
:param np.ndarray xs: (Time, Channel, Frequency)
:return: enhanced_xs
:rtype: np.ndarray
"""
# nara_wpe.wpe: (F, C, T)
xs = wpe(
xs.transpose((2, 1, 0)),
taps=self.taps,
delay=self.delay,
iterations=self.iterations,
psd_context=self.psd_context,
statistics_mode=self.statistics_mode,
)
return xs.transpose(2, 1, 0)

@ -0,0 +1,52 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import numpy as np
__all__ = ["label_smoothing_dist"]
# TODO(takaaki-hori): add different smoothing methods
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
"""Obtain label distribution for loss smoothing.
:param odim:
:param lsm_type:
:param blank:
:param transcript:
:return:
"""
if transcript is not None:
with open(transcript, "rb") as f:
trans_json = json.load(f)["utts"]
if lsm_type == "unigram":
assert transcript is not None, (
"transcript is required for %s label smoothing" % lsm_type)
labelcount = np.zeros(odim)
for k, v in trans_json.items():
ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
# to avoid an error when there is no text in an uttrance
if len(ids) > 0:
labelcount[ids] += 1
labelcount[odim - 1] = len(transcript) # count <eos>
labelcount[labelcount == 0] = 1 # flooring
labelcount[blank] = 0 # remove counts for blank
labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
else:
logging.error("Error: unexpected label smoothing type: %s" % lsm_type)
sys.exit()
return labeldist

@ -14,17 +14,17 @@
"""This module provides functions to calculate bleu score in different level.
e.g. wer for word-level, cer for char-level.
"""
import nltk
import numpy as np
import sacrebleu
__all__ = ['bleu', 'char_bleu']
__all__ = ['bleu', 'char_bleu', "ErrorCalculator"]
def bleu(hypothesis, reference):
"""Calculate BLEU. BLEU compares reference text and
hypothesis text in word-level using scarebleu.
:param reference: The reference sentences.
:type reference: list[list[str]]
:param hypothesis: The hypothesis sentence.
@ -39,8 +39,6 @@ def char_bleu(hypothesis, reference):
"""Calculate BLEU. BLEU compares reference text and
hypothesis text in char-level using scarebleu.
:param reference: The reference sentences.
:type reference: list[list[str]]
:param hypothesis: The hypothesis sentence.
@ -52,3 +50,70 @@ def char_bleu(hypothesis, reference):
for ref in reference]
return sacrebleu.corpus_bleu(hypothesis, reference)
class ErrorCalculator():
"""Calculate BLEU for ST and MT models during training.
:param y_hats: numpy array with predicted text
:param y_pads: numpy array with true (target) text
:param char_list: vocabulary list
:param sym_space: space symbol
:param sym_pad: pad symbol
:param report_bleu: report BLUE score if True
"""
def __init__(self, char_list, sym_space, sym_pad, report_bleu=False):
"""Construct an ErrorCalculator object."""
super().__init__()
self.char_list = char_list
self.space = sym_space
self.pad = sym_pad
self.report_bleu = report_bleu
if self.space in self.char_list:
self.idx_space = self.char_list.index(self.space)
else:
self.idx_space = None
def __call__(self, ys_hat, ys_pad):
"""Calculate corpus-level BLEU score.
:param torch.Tensor ys_hat: prediction (batch, seqlen)
:param torch.Tensor ys_pad: reference (batch, seqlen)
:return: corpus-level BLEU score in a mini-batch
:rtype float
"""
bleu = None
if not self.report_bleu:
return bleu
bleu = self.calculate_corpus_bleu(ys_hat, ys_pad)
return bleu
def calculate_corpus_bleu(self, ys_hat, ys_pad):
"""Calculate corpus-level BLEU score in a mini-batch.
:param torch.Tensor seqs_hat: prediction (batch, seqlen)
:param torch.Tensor seqs_true: reference (batch, seqlen)
:return: corpus-level BLEU score
:rtype float
"""
seqs_hat, seqs_true = [], []
for i, y_hat in enumerate(ys_hat):
y_true = ys_pad[i]
eos_true = np.where(y_true == -1)[0]
ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
# NOTE: padding index (-1) in y_true is used to pad y_hat
# because y_hats is not padded with -1
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
seq_true = [
self.char_list[int(idx)] for idx in y_true if int(idx) != -1
]
seq_hat_text = "".join(seq_hat).replace(self.space, " ")
seq_hat_text = seq_hat_text.replace(self.pad, "")
seq_true_text = "".join(seq_true).replace(self.space, " ")
seqs_hat.append(seq_hat_text)
seqs_true.append(seq_true_text)
bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true],
seqs_hat)
return bleu * 100

@ -0,0 +1,20 @@
import inspect
def check_kwargs(func, kwargs, name=None):
"""check kwargs are valid for func
If kwargs are invalid, raise TypeError as same as python default
:param function func: function to be validated
:param dict kwargs: keyword arguments for func
:param str name: name used in TypeError (default is func name)
"""
try:
params = inspect.signature(func).parameters
except ValueError:
return
if name is None:
name = func.__name__
for k in kwargs.keys():
if k not in params:
raise TypeError(f"{name}() got an unexpected keyword argument '{k}'")

@ -0,0 +1,241 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import sys
import h5py
import kaldiio
import soundfile
from deepspeech.io.reader import SoundHDF5File
def file_reader_helper(
rspecifier: str,
filetype: str="mat",
return_shape: bool=False,
segments: str=None, ):
"""Read uttid and array in kaldi style
This function might be a bit confusing as "ark" is used
for HDF5 to imitate "kaldi-rspecifier".
Args:
rspecifier: Give as "ark:feats.ark" or "scp:feats.scp"
filetype: "mat" is kaldi-martix, "hdf5": HDF5
return_shape: Return the shape of the matrix,
instead of the matrix. This can reduce IO cost for HDF5.
segments (str): The file format is
"<segment-id> <recording-id> <start-time> <end-time>\n"
"e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5\n"
Returns:
Generator[Tuple[str, np.ndarray], None, None]:
Examples:
Read from kaldi-matrix ark file:
>>> for u, array in file_reader_helper('ark:feats.ark', 'mat'):
... array
Read from HDF5 file:
>>> for u, array in file_reader_helper('ark:feats.h5', 'hdf5'):
... array
"""
if filetype == "mat":
return KaldiReader(
rspecifier, return_shape=return_shape, segments=segments)
elif filetype == "hdf5":
return HDF5Reader(rspecifier, return_shape=return_shape)
elif filetype == "sound.hdf5":
return SoundHDF5Reader(rspecifier, return_shape=return_shape)
elif filetype == "sound":
return SoundReader(rspecifier, return_shape=return_shape)
else:
raise NotImplementedError(f"filetype={filetype}")
class KaldiReader:
def __init__(self, rspecifier, return_shape=False, segments=None):
self.rspecifier = rspecifier
self.return_shape = return_shape
self.segments = segments
def __iter__(self):
with kaldiio.ReadHelper(
self.rspecifier, segments=self.segments) as reader:
for key, array in reader:
if self.return_shape:
array = array.shape
yield key, array
class HDF5Reader:
def __init__(self, rspecifier, return_shape=False):
if ":" not in rspecifier:
raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'.
format(self.rspecifier))
self.rspecifier = rspecifier
self.ark_or_scp, self.filepath = self.rspecifier.split(":", 1)
if self.ark_or_scp not in ["ark", "scp"]:
raise ValueError(f"Must be scp or ark: {self.ark_or_scp}")
self.return_shape = return_shape
def __iter__(self):
if self.ark_or_scp == "scp":
hdf5_dict = {}
with open(self.filepath, "r", encoding="utf-8") as f:
for line in f:
key, value = line.rstrip().split(None, 1)
if ":" not in value:
raise RuntimeError(
"scp file for hdf5 should be like: "
'"uttid filepath.h5:key": {}({})'.format(
line, self.filepath))
path, h5_key = value.split(":", 1)
hdf5_file = hdf5_dict.get(path)
if hdf5_file is None:
try:
hdf5_file = h5py.File(path, "r")
except Exception:
logging.error("Error when loading {}".format(path))
raise
hdf5_dict[path] = hdf5_file
try:
data = hdf5_file[h5_key]
except Exception:
logging.error("Error when loading {} with key={}".
format(path, h5_key))
raise
if self.return_shape:
yield key, data.shape
else:
yield key, data[()]
# Closing all files
for k in hdf5_dict:
try:
hdf5_dict[k].close()
except Exception:
pass
else:
if self.filepath == "-":
# Required h5py>=2.9
filepath = io.BytesIO(sys.stdin.buffer.read())
else:
filepath = self.filepath
with h5py.File(filepath, "r") as f:
for key in f:
if self.return_shape:
yield key, f[key].shape
else:
yield key, f[key][()]
class SoundHDF5Reader:
def __init__(self, rspecifier, return_shape=False):
if ":" not in rspecifier:
raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'.
format(rspecifier))
self.ark_or_scp, self.filepath = rspecifier.split(":", 1)
if self.ark_or_scp not in ["ark", "scp"]:
raise ValueError(f"Must be scp or ark: {self.ark_or_scp}")
self.return_shape = return_shape
def __iter__(self):
if self.ark_or_scp == "scp":
hdf5_dict = {}
with open(self.filepath, "r", encoding="utf-8") as f:
for line in f:
key, value = line.rstrip().split(None, 1)
if ":" not in value:
raise RuntimeError(
"scp file for hdf5 should be like: "
'"uttid filepath.h5:key": {}({})'.format(
line, self.filepath))
path, h5_key = value.split(":", 1)
hdf5_file = hdf5_dict.get(path)
if hdf5_file is None:
try:
hdf5_file = SoundHDF5File(path, "r")
except Exception:
logging.error("Error when loading {}".format(path))
raise
hdf5_dict[path] = hdf5_file
try:
data = hdf5_file[h5_key]
except Exception:
logging.error("Error when loading {} with key={}".
format(path, h5_key))
raise
# Change Tuple[ndarray, int] -> Tuple[int, ndarray]
# (soundfile style -> scipy style)
array, rate = data
if self.return_shape:
array = array.shape
yield key, (rate, array)
# Closing all files
for k in hdf5_dict:
try:
hdf5_dict[k].close()
except Exception:
pass
else:
if self.filepath == "-":
# Required h5py>=2.9
filepath = io.BytesIO(sys.stdin.buffer.read())
else:
filepath = self.filepath
for key, (a, r) in SoundHDF5File(filepath, "r").items():
if self.return_shape:
a = a.shape
yield key, (r, a)
class SoundReader:
def __init__(self, rspecifier, return_shape=False):
if ":" not in rspecifier:
raise ValueError('Give "rspecifier" such as "scp:some.scp: {}"'.
format(rspecifier))
self.ark_or_scp, self.filepath = rspecifier.split(":", 1)
if self.ark_or_scp != "scp":
raise ValueError('Only supporting "scp" for sound file: {}'.format(
self.ark_or_scp))
self.return_shape = return_shape
def __iter__(self):
with open(self.filepath, "r", encoding="utf-8") as f:
for line in f:
key, sound_file_path = line.rstrip().split(None, 1)
# Assume PCM16
array, rate = soundfile.read(sound_file_path, dtype="int16")
# Change Tuple[ndarray, int] -> Tuple[int, ndarray]
# (soundfile style -> scipy style)
if self.return_shape:
array = array.shape
yield key, (rate, array)

@ -0,0 +1,70 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from collections.abc import Sequence
from distutils.util import strtobool as dist_strtobool
import numpy
def strtobool(x):
# distutils.util.strtobool returns integer, but it's confusing,
return bool(dist_strtobool(x))
def get_commandline_args():
extra_chars = [
" ",
";",
"&",
"(",
")",
"|",
"^",
"<",
">",
"?",
"*",
"[",
"]",
"$",
"`",
'"',
"\\",
"!",
"{",
"}",
]
# Escape the extra characters for shell
argv = [
arg.replace("'", "'\\''") if all(char not in arg
for char in extra_chars) else
"'" + arg.replace("'", "'\\''") + "'" for arg in sys.argv
]
return sys.executable + " " + " ".join(argv)
def is_scipy_wav_style(value):
# If Tuple[int, numpy.ndarray] or not
return (isinstance(value, Sequence) and len(value) == 2 and
isinstance(value[0], int) and isinstance(value[1], numpy.ndarray))
def assert_scipy_wav_style(value):
assert is_scipy_wav_style(
value), "Must be Tuple[int, numpy.ndarray], but got {}".format(
type(value) if not isinstance(value, Sequence) else "{}[{}]".format(
type(value), ", ".join(str(type(v)) for v in value)))

@ -0,0 +1,293 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Dict
import h5py
import kaldiio
import numpy
import soundfile
from deepspeech.io.reader import SoundHDF5File
from deepspeech.utils.cli_utils import assert_scipy_wav_style
def file_writer_helper(
wspecifier: str,
filetype: str="mat",
write_num_frames: str=None,
compress: bool=False,
compression_method: int=2,
pcm_format: str="wav", ):
"""Write matrices in kaldi style
Args:
wspecifier: e.g. ark,scp:out.ark,out.scp
filetype: "mat" is kaldi-martix, "hdf5": HDF5
write_num_frames: e.g. 'ark,t:num_frames.txt'
compress: Compress or not
compression_method: Specify compression level
Write in kaldi-matrix-ark with "kaldi-scp" file:
>>> with file_writer_helper('ark,scp:out.ark,out.scp') as f:
>>> f['uttid'] = array
This "scp" has the following format:
uttidA out.ark:1234
uttidB out.ark:2222
where, 1234 and 2222 points the strating byte address of the matrix.
(For detail, see official documentation of Kaldi)
Write in HDF5 with "scp" file:
>>> with file_writer_helper('ark,scp:out.h5,out.scp', 'hdf5') as f:
>>> f['uttid'] = array
This "scp" file is created as:
uttidA out.h5:uttidA
uttidB out.h5:uttidB
HDF5 can be, unlike "kaldi-ark", accessed to any keys,
so originally "scp" is not required for random-reading.
Nevertheless we create "scp" for HDF5 because it is useful
for some use-case. e.g. Concatenation, Splitting.
"""
if filetype == "mat":
return KaldiWriter(
wspecifier,
write_num_frames=write_num_frames,
compress=compress,
compression_method=compression_method, )
elif filetype == "hdf5":
return HDF5Writer(
wspecifier, write_num_frames=write_num_frames, compress=compress)
elif filetype == "sound.hdf5":
return SoundHDF5Writer(
wspecifier,
write_num_frames=write_num_frames,
pcm_format=pcm_format)
elif filetype == "sound":
return SoundWriter(
wspecifier,
write_num_frames=write_num_frames,
pcm_format=pcm_format)
else:
raise NotImplementedError(f"filetype={filetype}")
class BaseWriter:
def __setitem__(self, key, value):
raise NotImplementedError
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def close(self):
try:
self.writer.close()
except Exception:
pass
if self.writer_scp is not None:
try:
self.writer_scp.close()
except Exception:
pass
if self.writer_nframe is not None:
try:
self.writer_nframe.close()
except Exception:
pass
def get_num_frames_writer(write_num_frames: str):
"""get_num_frames_writer
Examples:
>>> get_num_frames_writer('ark,t:num_frames.txt')
"""
if write_num_frames is not None:
if ":" not in write_num_frames:
raise ValueError('Must include ":", write_num_frames={}'.format(
write_num_frames))
nframes_type, nframes_file = write_num_frames.split(":", 1)
if nframes_type != "ark,t":
raise ValueError("Only supporting text mode. "
"e.g. --write-num-frames=ark,t:foo.txt :"
"{}".format(nframes_type))
return open(nframes_file, "w", encoding="utf-8")
class KaldiWriter(BaseWriter):
def __init__(self,
wspecifier,
write_num_frames=None,
compress=False,
compression_method=2):
if compress:
self.writer = kaldiio.WriteHelper(
wspecifier, compression_method=compression_method)
else:
self.writer = kaldiio.WriteHelper(wspecifier)
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
self.writer[key] = value
if self.writer_nframe is not None:
self.writer_nframe.write(f"{key} {len(value)}\n")
def parse_wspecifier(wspecifier: str) -> Dict[str, str]:
"""Parse wspecifier to dict
Examples:
>>> parse_wspecifier('ark,scp:out.ark,out.scp')
{'ark': 'out.ark', 'scp': 'out.scp'}
"""
ark_scp, filepath = wspecifier.split(":", 1)
if ark_scp not in ["ark", "scp,ark", "ark,scp"]:
raise ValueError("{} is not allowed: {}".format(ark_scp, wspecifier))
ark_scps = ark_scp.split(",")
filepaths = filepath.split(",")
if len(ark_scps) != len(filepaths):
raise ValueError("Mismatch: {} and {}".format(ark_scp, filepath))
spec_dict = dict(zip(ark_scps, filepaths))
return spec_dict
class HDF5Writer(BaseWriter):
"""HDF5Writer
Examples:
>>> with HDF5Writer('ark:out.h5', compress=True) as f:
... f['key'] = array
"""
def __init__(self, wspecifier, write_num_frames=None, compress=False):
spec_dict = parse_wspecifier(wspecifier)
self.filename = spec_dict["ark"]
if compress:
self.kwargs = {"compression": "gzip"}
else:
self.kwargs = {}
self.writer = h5py.File(spec_dict["ark"], "w")
if "scp" in spec_dict:
self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8")
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
self.writer.create_dataset(key, data=value, **self.kwargs)
if self.writer_scp is not None:
self.writer_scp.write(f"{key} {self.filename}:{key}\n")
if self.writer_nframe is not None:
self.writer_nframe.write(f"{key} {len(value)}\n")
class SoundHDF5Writer(BaseWriter):
"""SoundHDF5Writer
Examples:
>>> fs = 16000
>>> with SoundHDF5Writer('ark:out.h5') as f:
... f['key'] = fs, array
"""
def __init__(self, wspecifier, write_num_frames=None, pcm_format="wav"):
self.pcm_format = pcm_format
spec_dict = parse_wspecifier(wspecifier)
self.filename = spec_dict["ark"]
self.writer = SoundHDF5File(
spec_dict["ark"], "w", format=self.pcm_format)
if "scp" in spec_dict:
self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8")
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
assert_scipy_wav_style(value)
# Change Tuple[int, ndarray] -> Tuple[ndarray, int]
# (scipy style -> soundfile style)
value = (value[1], value[0])
self.writer.create_dataset(key, data=value)
if self.writer_scp is not None:
self.writer_scp.write(f"{key} {self.filename}:{key}\n")
if self.writer_nframe is not None:
self.writer_nframe.write(f"{key} {len(value[0])}\n")
class SoundWriter(BaseWriter):
"""SoundWriter
Examples:
>>> fs = 16000
>>> with SoundWriter('ark,scp:outdir,out.scp') as f:
... f['key'] = fs, array
"""
def __init__(self, wspecifier, write_num_frames=None, pcm_format="wav"):
self.pcm_format = pcm_format
spec_dict = parse_wspecifier(wspecifier)
# e.g. ark,scp:dirname,wav.scp
# -> The wave files are found in dirname/*.wav
self.dirname = spec_dict["ark"]
Path(self.dirname).mkdir(parents=True, exist_ok=True)
self.writer = None
if "scp" in spec_dict:
self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8")
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
assert_scipy_wav_style(value)
rate, signal = value
wavfile = Path(self.dirname) / (key + "." + self.pcm_format)
soundfile.write(wavfile, signal.astype(numpy.int16), rate)
if self.writer_scp is not None:
self.writer_scp.write(f"{key} {wavfile}\n")
if self.writer_nframe is not None:
self.writer_nframe.write(f"{key} {len(signal)}\n")

@ -14,12 +14,12 @@
"""This module provides functions to calculate error rate in different level.
e.g. wer for word-level, cer for char-level.
"""
from itertools import groupby
import editdistance
import numpy as np
__all__ = ['word_errors', 'char_errors', 'wer', 'cer']
editdistance.eval("a", "b")
__all__ = ['word_errors', 'char_errors', 'wer', 'cer', "ErrorCalculator"]
def _levenshtein_distance(ref, hyp):
@ -211,3 +211,154 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False):
cer = float(edit_distance) / ref_len
return cer
class ErrorCalculator():
"""Calculate CER and WER for E2E_ASR and CTC models during training.
:param y_hats: numpy array with predicted text
:param y_pads: numpy array with true (target) text
:param char_list: List[str]
:param sym_space: <space>
:param sym_blank: <blank>
:return:
"""
def __init__(self,
char_list,
sym_space,
sym_blank,
report_cer=False,
report_wer=False):
"""Construct an ErrorCalculator object."""
super().__init__()
self.report_cer = report_cer
self.report_wer = report_wer
self.char_list = char_list
self.space = sym_space
self.blank = sym_blank
self.idx_blank = self.char_list.index(self.blank)
if self.space in self.char_list:
self.idx_space = self.char_list.index(self.space)
else:
self.idx_space = None
def __call__(self, ys_hat, ys_pad, is_ctc=False):
"""Calculate sentence-level WER/CER score.
:param paddle.Tensor ys_hat: prediction (batch, seqlen)
:param paddle.Tensor ys_pad: reference (batch, seqlen)
:param bool is_ctc: calculate CER score for CTC
:return: sentence-level WER score
:rtype float
:return: sentence-level CER score
:rtype float
"""
cer, wer = None, None
if is_ctc:
return self.calculate_cer_ctc(ys_hat, ys_pad)
elif not self.report_cer and not self.report_wer:
return cer, wer
seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
if self.report_cer:
cer = self.calculate_cer(seqs_hat, seqs_true)
if self.report_wer:
wer = self.calculate_wer(seqs_hat, seqs_true)
return cer, wer
def calculate_cer_ctc(self, ys_hat, ys_pad):
"""Calculate sentence-level CER score for CTC.
:param paddle.Tensor ys_hat: prediction (batch, seqlen)
:param paddle.Tensor ys_pad: reference (batch, seqlen)
:return: average sentence-level CER score
:rtype float
"""
cers, char_ref_lens = [], []
for i, y in enumerate(ys_hat):
y_hat = [x[0] for x in groupby(y)]
y_true = ys_pad[i]
seq_hat, seq_true = [], []
for idx in y_hat:
idx = int(idx)
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
seq_hat.append(self.char_list[int(idx)])
for idx in y_true:
idx = int(idx)
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
seq_true.append(self.char_list[int(idx)])
hyp_chars = "".join(seq_hat)
ref_chars = "".join(seq_true)
if len(ref_chars) > 0:
cers.append(editdistance.eval(hyp_chars, ref_chars))
char_ref_lens.append(len(ref_chars))
cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
return cer_ctc
def convert_to_char(self, ys_hat, ys_pad):
"""Convert index to character.
:param paddle.Tensor seqs_hat: prediction (batch, seqlen)
:param paddle.Tensor seqs_true: reference (batch, seqlen)
:return: token list of prediction
:rtype list
:return: token list of reference
:rtype list
"""
seqs_hat, seqs_true = [], []
for i, y_hat in enumerate(ys_hat):
y_true = ys_pad[i]
eos_true = np.where(y_true == -1)[0]
ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
# NOTE: padding index (-1) in y_true is used to pad y_hat
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
seq_true = [
self.char_list[int(idx)] for idx in y_true if int(idx) != -1
]
seq_hat_text = "".join(seq_hat).replace(self.space, " ")
seq_hat_text = seq_hat_text.replace(self.blank, "")
seq_true_text = "".join(seq_true).replace(self.space, " ")
seqs_hat.append(seq_hat_text)
seqs_true.append(seq_true_text)
return seqs_hat, seqs_true
def calculate_cer(self, seqs_hat, seqs_true):
"""Calculate sentence-level CER score.
:param list seqs_hat: prediction
:param list seqs_true: reference
:return: average sentence-level CER score
:rtype float
"""
char_eds, char_ref_lens = [], []
for i, seq_hat_text in enumerate(seqs_hat):
seq_true_text = seqs_true[i]
hyp_chars = seq_hat_text.replace(" ", "")
ref_chars = seq_true_text.replace(" ", "")
char_eds.append(editdistance.eval(hyp_chars, ref_chars))
char_ref_lens.append(len(ref_chars))
return float(sum(char_eds)) / sum(char_ref_lens)
def calculate_wer(self, seqs_hat, seqs_true):
"""Calculate sentence-level WER score.
:param list seqs_hat: prediction
:param list seqs_true: reference
:return: average sentence-level WER score
:rtype float
"""
word_eds, word_ref_lens = [], []
for i, seq_hat_text in enumerate(seqs_hat):
seq_true_text = seqs_true[i]
hyp_words = seq_hat_text.split()
ref_words = seq_true_text.split()
word_eds.append(editdistance.eval(hyp_words, ref_words))
word_ref_lens.append(len(ref_words))
return float(sum(word_eds)) / sum(word_ref_lens)

@ -0,0 +1,4 @@
dump
fbank
exp
data

@ -1,8 +1,11 @@
# LibriSpeech
| Model | Params | Config | Augmentation| Loss |
| --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | 6.3197922706604 |
## Transformer
| Model | Params | GPUS | Averaged Model | Config | Augmentation| Loss |
| --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | 8 Tesla V100-SXM2-32GB | 10-best val_loss | conf/transformer.yaml | spec_aug | 6.3197922706604 |
| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err |
@ -11,4 +14,14 @@
| test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 |
| test-clean | ctc_prefix_beamsearch | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 |
| test-clean | attention_rescore | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 |
### JoinCTC
| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| test-clean | join_ctc_only_att | 2620 | 52576 | 96.1 | 2.5 | 1.4 | 0.4 | 4.4 | 34.7 |
| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 |
| test-clean | join_ctc_w_lm | 2620 | 52576 | 97.9 | 1.8 | 0.2 | 0.3 | 2.4 | 27.8 |
Compare with [ESPNET](https://github.com/espnet/espnet/blob/master/egs/librispeech/asr1/RESULTS.md#pytorch-large-transformer-with-specaug-4-gpus--transformer-lm-4-gpus)
we using 8gpu, but model size (aheads4-adim256) small than it.

@ -1,122 +0,0 @@
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.5
max_input_len: 20.0
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 16
raw_wav: True # use raw_wav or kaldi feature
spectrum_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: True
use_cnn_module: True
cnn_module_kernel: 15
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
causal: True
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 240
accum_grad: 8
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 128
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: true # simulate streaming inference. Defaults to False.

@ -1,115 +0,0 @@
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
spectrum_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: transformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: true
use_dynamic_chunk: true
use_dynamic_left_chunk: false
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 120
accum_grad: 1
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 64
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: true # simulate streaming inference. Defaults to False.

@ -1,118 +0,0 @@
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
min_input_len: 0.5 # seconds
max_input_len: 20.0 # seconds
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 16
raw_wav: True # use raw_wav or kaldi feature
spectrum_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: True
use_cnn_module: True
cnn_module_kernel: 15
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 120
accum_grad: 8
global_grad_clip: 3.0
optim: adam
optim_conf:
lr: 0.004
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 64
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.

@ -0,0 +1,2 @@
--sample-frequency=16000
--num-mel-bins=80

@ -0,0 +1,13 @@
model_module: transformer
model:
n_vocab: 5002
pos_enc: null
embed_unit: 128
att_unit: 512
head: 8
unit: 2048
layer: 16
dropout_rate: 0.5
emb_dropout_rate: 0.0
att_dropout_rate: 0.0
tie_weights: False

@ -0,0 +1 @@
--sample-frequency=16000

@ -5,9 +5,9 @@ data:
test_manifest: data/manifest.test-clean
collator:
vocab_filepath: data/bpe_unigram_5000_units.txt
vocab_filepath: data/lang_char/train_960_unigram5000_units.txt
unit_type: spm
spm_model_prefix: data/bpe_unigram_5000
spm_model_prefix: data/lang_char/train_960_unigram5000
feat_dim: 83
stride_ms: 10.0
window_ms: 25.0

@ -2,19 +2,42 @@
stage=-1
stop_stage=100
nj=32
debugmode=1
dumpdir=dump # directory to dump full features
N=0 # number of minibatches to be used (mainly for debugging). "0" uses all minibatches.
verbose=0 # verbose option
resume= # Resume the training from snapshot
# feature configuration
do_delta=false
# Set this to somewhere where you want to put your data, or where
# someone else has already put it. You'll want to change this
# if you're not on the CLSP grid.
datadir=${MAIN_ROOT}/examples/dataset/
# bpemode (unigram or bpe)
nbpe=5000
bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"
source ${MAIN_ROOT}/utils/parse_options.sh
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail
train_set=train_960
train_sp=train_sp
train_dev=dev
recog_set="test_clean test_other dev_clean dev_other"
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/librispeech/librispeech.py \
@ -46,63 +69,98 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# compute mean and stddev for normalizer
num_workers=$(nproc)
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--num_samples=-1 \
--spectrum_type="fbank" \
--feat_dim=80 \
--delta_delta=false \
--sample_rate=16000 \
--stride_ms=10.0 \
--window_ms=25.0 \
--use_dB_normalization=False \
--num_workers=${num_workers} \
--output_path="data/mean_std.json"
if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated."
exit 1
fi
### Task dependent. You have to make data the following preparation part by yourself.
### But you can utilize Kaldi recipes in most cases
echo "stage 0: Data preparation"
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
# use underscore-separated names in data directories.
local/data_prep.sh ${datadir}/librispeech/${part}/LibriSpeech/${part} data/${part//-/_}
done
fi
feat_tr_dir=${dumpdir}/${train_set}/delta${do_delta}; mkdir -p ${feat_tr_dir}
feat_sp_dir=${dumpdir}/${train_sp}/delta${do_delta}; mkdir -p ${feat_sp_dir}
feat_dt_dir=${dumpdir}/${train_dev}/delta${do_delta}; mkdir -p ${feat_dt_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# build vocabulary
python3 ${MAIN_ROOT}/utils/build_vocab.py \
--unit_type "spm" \
--spm_vocab_size=${nbpe} \
--spm_mode ${bpemode} \
--spm_model_prefix ${bpeprefix} \
--vocab_path="data/vocab.txt" \
--manifest_paths="data/manifest.train.raw"
### Task dependent. You have to design training and dev sets by yourself.
### But you can utilize Kaldi recipes in most cases
echo "stage 1: Feature Generation"
fbankdir=fbank
# Generate the fbank features; by default 80-dimensional fbanks with pitch on each frame
for x in dev_clean test_clean dev_other test_other train_clean_100 train_clean_360 train_other_500; do
steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj ${nj} --write_utt2num_frames true \
data/${x} exp/make_fbank/${x} ${fbankdir}
utils/fix_data_dir.sh data/${x}
done
if [ $? -ne 0 ]; then
echo "Build vocabulary failed. Terminated."
exit 1
fi
utils/combine_data.sh --extra_files utt2num_frames data/${train_set}_org data/train_clean_100 data/train_clean_360 data/train_other_500
utils/combine_data.sh --extra_files utt2num_frames data/${train_dev}_org data/dev_clean data/dev_other
utils/perturb_data_dir_speed.sh 0.9 data/${train_set}_org data/temp1
utils/perturb_data_dir_speed.sh 1.0 data/${train_set}_org data/temp2
utils/perturb_data_dir_speed.sh 1.1 data/${train_set}_org data/temp3
utils/combine_data.sh --extra-files utt2uniq data/${train_sp}_org data/temp1 data/temp2 data/temp3
# remove utt having more than 3000 frames
# remove utt having more than 400 characters
remove_longshortdata.sh --maxframes 3000 --maxchars 400 data/${train_set}_org data/${train_set}
remove_longshortdata.sh --maxframes 3000 --maxchars 400 data/${train_sp}_org data/${train_sp}
remove_longshortdata.sh --maxframes 3000 --maxchars 400 data/${train_dev}_org data/${train_dev}
steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj $nj --write_utt2num_frames true \
data/train_sp exp/make_fbank/train_sp ${fbankdir}
utils/fix_data_dir.sh data/train_sp
# compute global CMVN
compute-cmvn-stats scp:data/${train_sp}/feats.scp data/${train_sp}/cmvn.ark
# dump features for training
dump.sh --cmd "$train_cmd" --nj ${nj} --do_delta ${do_delta} \
data/${train_sp}/feats.scp data/${train_sp}/cmvn.ark exp/dump_feats/train ${feat_sp_dir}
dump.sh --cmd "$train_cmd" --nj ${nj} --do_delta ${do_delta} \
data/${train_dev}/feats.scp data/${train_sp}/cmvn.ark exp/dump_feats/dev ${feat_dt_dir}
for rtask in ${recog_set}; do
feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}; mkdir -p ${feat_recog_dir}
dump.sh --cmd "$train_cmd" --nj ${nj} --do_delta ${do_delta} \
data/${rtask}/feats.scp data/${train_sp}/cmvn.ark exp/dump_feats/recog/${rtask} \
${feat_recog_dir}
done
fi
dict=data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
bpemodel=data/lang_char/${train_set}_${bpemode}${nbpe}
echo "dictionary: ${dict}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for set in train dev test dev-clean dev-other test-clean test-other; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type "spm" \
--spm_model_prefix ${bpeprefix} \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${set}.raw" \
--output_path="data/manifest.${set}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest failed. Terminated."
exit 1
fi
}&
### Task dependent. You have to check non-linguistic symbols used in the corpus.
echo "stage 2: Dictionary and Json Data Preparation"
mkdir -p data/lang_char/
echo "<unk> 1" > ${dict} # <unk> must be 1, 0 will be used for "blank" in CTC
cut -f 2- -d" " data/${train_set}/text > data/lang_char/input.txt
spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+1}' >> ${dict}
wc -l ${dict}
# make json labels
data2json.sh --nj ${nj} --feat ${feat_sp_dir}/feats.scp --bpecode ${bpemodel}.model \
data/${train_sp} ${dict} > ${feat_sp_dir}/data_${bpemode}${nbpe}.json
data2json.sh --nj ${nj} --feat ${feat_dt_dir}/feats.scp --bpecode ${bpemodel}.model \
data/${train_dev} ${dict} > ${feat_dt_dir}/data_${bpemode}${nbpe}.json
for rtask in ${recog_set}; do
feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}
data2json.sh --nj ${nj} --feat ${feat_recog_dir}/feats.scp --bpecode ${bpemodel}.model \
data/${rtask} ${dict} > ${feat_recog_dir}/data_${bpemode}${nbpe}.json
done
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# make json labels
python3 local/espnet_json_to_manifest.py --json-file ${feat_sp_dir}/data_${bpemode}${nbpe}.json --manifest-file data/manifest.train
python3 local/espnet_json_to_manifest.py --json-file ${feat_dt_dir}/data_${bpemode}${nbpe}.json --manifest-file data/manifest.dev
for rtask in ${recog_set}; do
feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}
python3 local/espnet_json_to_manifest.py --json-file ${feat_recog_dir}/data_${bpemode}${nbpe}.json --manifest-file data/manifest.${rtask//_/-}
done
wait
fi
echo "LibriSpeech Data preparation done."

@ -0,0 +1,85 @@
#!/usr/bin/env bash
# Copyright 2014 Vassil Panayotov
# 2014 Johns Hopkins University (author: Daniel Povey)
# Apache 2.0
if [ "$#" -ne 2 ]; then
echo "Usage: $0 <src-dir> <dst-dir>"
echo "e.g.: $0 /export/a15/vpanayotov/data/LibriSpeech/dev-clean data/dev-clean"
exit 1
fi
src=$1
dst=$2
# all utterances are FLAC compressed
if ! which flac >&/dev/null; then
echo "Please install 'flac' on ALL worker nodes!"
exit 1
fi
spk_file=$src/../SPEAKERS.TXT
mkdir -p $dst || exit 1
[ ! -d $src ] && echo "$0: no such directory $src" && exit 1
[ ! -f $spk_file ] && echo "$0: expected file $spk_file to exist" && exit 1
wav_scp=$dst/wav.scp; [[ -f "$wav_scp" ]] && rm $wav_scp
trans=$dst/text; [[ -f "$trans" ]] && rm $trans
utt2spk=$dst/utt2spk; [[ -f "$utt2spk" ]] && rm $utt2spk
spk2gender=$dst/spk2gender; [[ -f $spk2gender ]] && rm $spk2gender
for reader_dir in $(find -L $src -mindepth 1 -maxdepth 1 -type d | sort); do
reader=$(basename $reader_dir)
if ! [ $reader -eq $reader ]; then # not integer.
echo "$0: unexpected subdirectory name $reader"
exit 1
fi
reader_gender=$(egrep "^$reader[ ]+\|" $spk_file | awk -F'|' '{gsub(/[ ]+/, ""); print tolower($2)}')
if [ "$reader_gender" != 'm' ] && [ "$reader_gender" != 'f' ]; then
echo "Unexpected gender: '$reader_gender'"
exit 1
fi
for chapter_dir in $(find -L $reader_dir/ -mindepth 1 -maxdepth 1 -type d | sort); do
chapter=$(basename $chapter_dir)
if ! [ "$chapter" -eq "$chapter" ]; then
echo "$0: unexpected chapter-subdirectory name $chapter"
exit 1
fi
find -L $chapter_dir/ -iname "*.flac" | sort | xargs -I% basename % .flac | \
awk -v "dir=$chapter_dir" '{printf "%s flac -c -d -s %s/%s.flac |\n", $0, dir, $0}' >>$wav_scp|| exit 1
chapter_trans=$chapter_dir/${reader}-${chapter}.trans.txt
[ ! -f $chapter_trans ] && echo "$0: expected file $chapter_trans to exist" && exit 1
cat $chapter_trans >>$trans
# NOTE: For now we are using per-chapter utt2spk. That is each chapter is considered
# to be a different speaker. This is done for simplicity and because we want
# e.g. the CMVN to be calculated per-chapter
awk -v "reader=$reader" -v "chapter=$chapter" '{printf "%s %s-%s\n", $1, reader, chapter}' \
<$chapter_trans >>$utt2spk || exit 1
# reader -> gender map (again using per-chapter granularity)
echo "${reader}-${chapter} $reader_gender" >>$spk2gender
done
done
spk2utt=$dst/spk2utt
utils/utt2spk_to_spk2utt.pl <$utt2spk >$spk2utt || exit 1
ntrans=$(wc -l <$trans)
nutt2spk=$(wc -l <$utt2spk)
! [ "$ntrans" -eq "$nutt2spk" ] && \
echo "Inconsistent #transcripts($ntrans) and #utt2spk($nutt2spk)" && exit 1
utils/validate_data_dir.sh --no-feats $dst || exit 1
echo "$0: successfully prepared data in $dst"
exit 0

@ -11,22 +11,24 @@ tag=
decode_config=conf/decode/decode.yaml
# lm params
lang_model=rnnlm.model.best
lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/
lmtag='nolm'
lang_model=transformerLM.pdparams
lmexpdir=exp/lm/transformer
rnnlm_config_path=conf/lm/transformer.yaml
lmtag='transformer'
train_set=train_960
recog_set="test-clean test-other dev-clean dev-other"
recog_set="test-clean"
# bpemode (unigram or bpe)
nbpe=5000
bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"
bpeprefix=data/lang_char/${train_set}_${bpemode}${nbpe}
bpemodel=${bpeprefix}.model
# bin params
config_path=conf/transformer.yaml
dict=data/bpe_unigram_5000_units.txt
dict=data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
ckpt_prefix=
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@ -90,9 +92,9 @@ for dmethd in join_ctc; do
--recog-json ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} \
--result-label ${decode_dir}/data.JOB.json \
--model-conf ${config_path} \
--model ${ckpt_prefix}.pdparams
#--rnnlm ${lmexpdir}/${lang_model} \
--model ${ckpt_prefix}.pdparams \
--rnnlm-conf ${rnnlm_config_path} \
--rnnlm ${lmexpdir}/${lang_model}
score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict}

@ -8,17 +8,18 @@ nj=32
lmtag='nolm'
train_set=train_960
recog_set="test-clean test-other dev-clean dev-other"
recog_set="test-clean"
# bpemode (unigram or bpe)
nbpe=5000
bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"
bpeprefix=data/lang_char/${train_set}_${bpemode}${nbpe}
bpemodel=${bpeprefix}.model
config_path=conf/transformer.yaml
dict=data/bpe_unigram_5000_units.txt
dict=data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
ckpt_prefix=
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

@ -1,6 +1,6 @@
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${PWD}/utils:${PATH}
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${MAIN_ROOT}/utils:${PWD}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
@ -13,3 +13,16 @@ export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=u2_kaldi
export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin
# srilm
export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
export SRILM=${MAIN_ROOT}/tools/srilm
export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64
# Kaldi
export KALDI_ROOT=${MAIN_ROOT}/tools/kaldi
[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present, can not using Kaldi!"
[ -f $KALDI_ROOT/tools/config/common_path.sh ] && . $KALDI_ROOT/tools/config/common_path.sh

@ -33,16 +33,24 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
# attetion resocre decoder
./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] && ${use_lm} == true; then
# join ctc decoder, use transformerlm to score
if [ ! -f exp/lm/transformer/transformerLM.pdparams ]; then
wget https://deepspeech.bj.bcebos.com/transformer_lm/transformerLM.pdparams exp/lm/transformer/
fi
bash local/recog.sh --ckpt_prefix exp/${ckpt}/checkpoints/${avg_ckpt}
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi

@ -0,0 +1 @@
../../../tools/kaldi/egs/wsj/s5/steps/

@ -1 +1 @@
../../../utils/
../../../tools/kaldi/egs/wsj/s5/utils

@ -23,6 +23,7 @@ praatio~=4.1
pre-commit
pybind11
pypinyin
python-dateutil
pyworld
resampy==0.2.2
sacrebleu
@ -41,3 +42,4 @@ visualdl==2.2.0
webrtcvad
yacs
yq
nara_wpe

@ -65,13 +65,6 @@ def _remove(files: str):
def _post_install(install_lib_dir):
# apt
check_call("apt-get update -y")
check_call("apt-get install -y " + 'vim tig tree sox pkg-config ' +
'libsndfile1 libflac-dev libogg-dev ' +
'libvorbis-dev libboost-dev swig python3-dev ')
print("apt install.")
# tools/make
tool_dir = HERE / "tools"
_remove(tool_dir.glob("*.done"))

@ -10,7 +10,7 @@ fi
if [ -e /etc/lsb-release ];then
${SUDO} apt-get update -y
${SUDO} apt-get install -y jq vim tig tree sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
${SUDO} apt-get install -y bc jq vim tig tree sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
if [ $? != 0 ]; then
error_msg "Please using Ubuntu or install pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev by user."
exit -1

@ -10,7 +10,7 @@ WGET ?= wget --no-check-certificate
.PHONY: all clean
all: virtualenv.done kenlm.done sox.done soxbindings.done mfa.done sclite.done
all: virtualenv.done apt.done kenlm.done sox.done soxbindings.done mfa.done sclite.done
virtualenv.done:
test -d venv || virtualenv -p $(PYTHON) venv
@ -21,6 +21,13 @@ clean:
find -iname "*.pyc" -delete
rm -rf kenlm
apt.done:
apt update -y
apt install -y bc flac jq vim tig tree pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
echo "check_certificate = off" >> ~/.wgetrc
touch apt.done
kenlm.done:
# Ubuntu 16.04 透過 apt 會安裝 boost 1.58.0
# it seems that boost (1.54.0) requires higher version. After I switched to g++-5 it compiles normally.
@ -48,6 +55,13 @@ mfa.done:
tar xvf montreal-forced-aligner_linux.tar.gz
touch mfa.done
openblas.done:
bash extras/install_openblas.sh
touch openblas.done
kaldi.done: openblas.done
bash extras/install_kaldi.sh
touch kaldi.done
#== SCTK ===============================================================================
# SCTK official repo does not have version tags. Here's the mapping:

@ -16,7 +16,7 @@ else
echo "$KALDI_DIR already exists!"
fi
cd "$KALDI_DIR/tools"
pushd "$KALDI_DIR/tools"
git pull
# Prevent kaldi from switching default python version
@ -28,8 +28,12 @@ touch "python/.use_default_python"
make -j4
pushd ../src
./configure --shared --use-cuda=no --static-math --mathlib=OPENBLAS --openblas-root=${KALDI_DIR}/../OpenBLAS/install
OPENBLAS_DIR=${KALDI_DIR}/../OpenBLAS
mkdir -p ${OPENBLAS_DIR}/install
./configure --shared --use-cuda=no --static-math --mathlib=OPENBLAS --openblas-root=${OPENBLAS_DIR}/install
make clean -j && make depend -j && make -j4
popd
popd
echo "Done installing Kaldi."

@ -0,0 +1,149 @@
#!/usr/bin/env python3
import argparse
import logging
from distutils.util import strtobool
import kaldiio
import numpy
from deepspeech.transform.cmvn import CMVN
from deepspeech.utils.cli_readers import file_reader_helper
from deepspeech.utils.cli_utils import get_commandline_args
from deepspeech.utils.cli_utils import is_scipy_wav_style
from deepspeech.utils.cli_writers import file_writer_helper
def get_parser():
parser = argparse.ArgumentParser(
description="apply mean-variance normalization to files",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument(
"--in-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--stats-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "npy"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--out-filetype",
type=str,
default="mat",
choices=["mat", "hdf5"],
help="Specify the file format for the wspecifier. "
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--norm-means",
type=strtobool,
default=True,
help="Do variance normalization or not.", )
parser.add_argument(
"--norm-vars",
type=strtobool,
default=False,
help="Do variance normalization or not.", )
parser.add_argument(
"--reverse",
type=strtobool,
default=False,
help="Do reverse mode or not")
parser.add_argument(
"--spk2utt",
type=str,
help="A text file of speaker to utterance-list map. "
"(Don't give rspecifier format, such as "
'"ark:spk2utt")', )
parser.add_argument(
"--utt2spk",
type=str,
help="A text file of utterance to speaker map. "
"(Don't give rspecifier format, such as "
'"ark:utt2spk")', )
parser.add_argument(
"--write-num-frames",
type=str,
help="Specify wspecifer for utt2num_frames")
parser.add_argument(
"--compress",
type=strtobool,
default=False,
help="Save in compressed format")
parser.add_argument(
"--compression-method",
type=int,
default=2,
help="Specify the method(if mat) or "
"gzip-level(if hdf5)", )
parser.add_argument(
"stats_rspecifier_or_rxfilename",
help="Input stats. e.g. ark:stats.ark or stats.mat", )
parser.add_argument(
"rspecifier", type=str, help="Read specifier id. e.g. ark:some.ark")
parser.add_argument(
"wspecifier", type=str, help="Write specifier id. e.g. ark:some.ark")
return parser
def main():
args = get_parser().parse_args()
# logging info
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
if args.verbose > 0:
logging.basicConfig(level=logging.INFO, format=logfmt)
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
if ":" in args.stats_rspecifier_or_rxfilename:
is_rspcifier = True
if args.stats_filetype == "npy":
stats_filetype = "hdf5"
else:
stats_filetype = args.stats_filetype
stats_dict = dict(
file_reader_helper(args.stats_rspecifier_or_rxfilename,
stats_filetype))
else:
is_rspcifier = False
if args.stats_filetype == "mat":
stats = kaldiio.load_mat(args.stats_rspecifier_or_rxfilename)
else:
stats = numpy.load(args.stats_rspecifier_or_rxfilename)
stats_dict = {None: stats}
cmvn = CMVN(
stats=stats_dict,
norm_means=args.norm_means,
norm_vars=args.norm_vars,
utt2spk=args.utt2spk,
spk2utt=args.spk2utt,
reverse=args.reverse, )
with file_writer_helper(
args.wspecifier,
filetype=args.out_filetype,
write_num_frames=args.write_num_frames,
compress=args.compress,
compression_method=args.compression_method, ) as writer:
for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
if is_scipy_wav_style(mat):
# If data is sound file, then got as Tuple[int, ndarray]
rate, mat = mat
mat = cmvn(mat, utt if is_rspcifier else None)
writer[utt] = mat
if __name__ == "__main__":
main()

@ -47,8 +47,10 @@ def main(args):
beat_val_scores = sorted_val_scores[:args.num, 1]
selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
avg_val_score = np.mean(beat_val_scores)
print("selected val scores = " + str(beat_val_scores))
print("selected epochs = " + str(selected_epochs))
print("averaged val score = " + str(avg_val_score))
path_list = [
args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
@ -80,7 +82,7 @@ def main(args):
data = json.dumps({
"mode": 'val_best' if args.val_best else 'latest',
"avg_ckpt": args.dst_model,
"val_loss_mean": np.mean(beat_val_scores),
"val_loss_mean": avg_val_score,
"ckpts": path_list,
"epochs": selected_epochs.tolist(),
"val_losses": beat_val_scores.tolist(),

@ -0,0 +1,65 @@
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2021 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import glob
import os
from dateutil import parser
def get_parser():
parser = argparse.ArgumentParser(
description="calculate real time factor (RTF)")
parser.add_argument(
"--log-dir",
type=str,
default=None,
help="path to logging directory", )
return parser
def main():
args = get_parser().parse_args()
audio_sec = 0
decode_sec = 0
n_utt = 0
audio_durations = []
start_times = []
end_times = []
for x in glob.glob(os.path.join(args.log_dir, "decode.*.log")):
with codecs.open(x, "r", "utf-8") as f:
for line in f:
x = line.strip()
# 2021-10-25 08:22:04.052 | INFO | xxx:recog_v2:188 - feat: (1570, 83)
if "feat:" in x:
dur = int(x.split("(")[1].split(',')[0])
audio_durations += [dur]
start_times += [parser.parse(x.split("|")[0])]
elif "total log probability:" in x:
end_times += [parser.parse(x.split("|")[0])]
assert len(audio_durations) == len(end_times), (len(audio_durations),
len(end_times), )
assert len(start_times) == len(end_times), (len(start_times),
len(end_times))
audio_sec += sum(audio_durations) / 100 # [sec]
decode_sec += sum([(end - start).total_seconds()
for start, end in zip(start_times, end_times)])
n_utt += len(audio_durations)
print("Total audio duration: %.3f [sec]" % audio_sec)
print("Total decoding time: %.3f [sec]" % decode_sec)
rtf = decode_sec / audio_sec if audio_sec > 0 else 0
print("RTF: %.3f" % rtf)
latency = decode_sec * 1000 / n_utt if n_utt > 0 else 0
print("Latency: %.3f [ms/sentence]" % latency)
if __name__ == "__main__":
main()

@ -0,0 +1,186 @@
#!/usr/bin/env python3
import argparse
import logging
import kaldiio
import numpy as np
from deepspeech.transform.transformation import Transformation
from deepspeech.utils.cli_readers import file_reader_helper
from deepspeech.utils.cli_utils import get_commandline_args
from deepspeech.utils.cli_utils import is_scipy_wav_style
from deepspeech.utils.cli_writers import file_writer_helper
def get_parser():
parser = argparse.ArgumentParser(
description="Compute cepstral mean and "
"variance normalization statistics"
"If wspecifier provided: per-utterance by default, "
"or per-speaker if"
"spk2utt option provided; if wxfilename: global",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--spk2utt",
type=str,
help="A text file of speaker to utterance-list map. "
"(Don't give rspecifier format, such as "
'"ark:utt2spk")', )
parser.add_argument(
"--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument(
"--in-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--out-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "npy"],
help="Specify the file format for the wspecifier. "
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing", )
parser.add_argument(
"rspecifier",
type=str,
help="Read specifier for feats. e.g. ark:some.ark")
parser.add_argument(
"wspecifier_or_wxfilename",
type=str,
help="Write specifier. e.g. ark:some.ark")
return parser
def main():
args = get_parser().parse_args()
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
if args.verbose > 0:
logging.basicConfig(level=logging.INFO, format=logfmt)
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
is_wspecifier = ":" in args.wspecifier_or_wxfilename
if is_wspecifier:
if args.spk2utt is not None:
logging.info("Performing as speaker CMVN mode")
utt2spk_dict = {}
with open(args.spk2utt) as f:
for line in f:
spk, utts = line.rstrip().split(None, 1)
for utt in utts.split():
utt2spk_dict[utt] = spk
def utt2spk(x):
return utt2spk_dict[x]
else:
logging.info("Performing as utterance CMVN mode")
def utt2spk(x):
return x
if args.out_filetype == "npy":
logging.warning("--out-filetype npy is allowed only for "
"Global CMVN mode, changing to hdf5")
args.out_filetype = "hdf5"
else:
logging.info("Performing as global CMVN mode")
if args.spk2utt is not None:
logging.warning("spk2utt is not used for global CMVN mode")
def utt2spk(x):
return None
if args.out_filetype == "hdf5":
logging.warning("--out-filetype hdf5 is not allowed for "
"Global CMVN mode, changing to npy")
args.out_filetype = "npy"
if args.preprocess_conf is not None:
preprocessing = Transformation(args.preprocess_conf)
logging.info("Apply preprocessing: {}".format(preprocessing))
else:
preprocessing = None
# Calculate stats for each speaker
counts = {}
sum_feats = {}
square_sum_feats = {}
idx = 0
for idx, (utt, matrix) in enumerate(
file_reader_helper(args.rspecifier, args.in_filetype), 1):
if is_scipy_wav_style(matrix):
# If data is sound file, then got as Tuple[int, ndarray]
rate, matrix = matrix
if preprocessing is not None:
matrix = preprocessing(matrix, uttid_list=utt)
spk = utt2spk(utt)
# Init at the first seen of the spk
if spk not in counts:
counts[spk] = 0
feat_shape = matrix.shape[1:]
# Accumulate in double precision
sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64)
square_sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64)
counts[spk] += matrix.shape[0]
sum_feats[spk] += matrix.sum(axis=0)
square_sum_feats[spk] += (matrix**2).sum(axis=0)
logging.info("Processed {} utterances".format(idx))
assert idx > 0, idx
cmvn_stats = {}
for spk in counts:
feat_shape = sum_feats[spk].shape
cmvn_shape = (2, feat_shape[0] + 1) + feat_shape[1:]
_cmvn_stats = np.empty(cmvn_shape, dtype=np.float64)
_cmvn_stats[0, :-1] = sum_feats[spk]
_cmvn_stats[1, :-1] = square_sum_feats[spk]
_cmvn_stats[0, -1] = counts[spk]
_cmvn_stats[1, -1] = 0.0
# You can get the mean and std as following,
# >>> N = _cmvn_stats[0, -1]
# >>> mean = _cmvn_stats[0, :-1] / N
# >>> std = np.sqrt(_cmvn_stats[1, :-1] / N - mean ** 2)
cmvn_stats[spk] = _cmvn_stats
# Per utterance or speaker CMVN
if is_wspecifier:
with file_writer_helper(
args.wspecifier_or_wxfilename,
filetype=args.out_filetype) as writer:
for spk, mat in cmvn_stats.items():
writer[spk] = mat
# Global CMVN
else:
matrix = cmvn_stats[None]
if args.out_filetype == "npy":
np.save(args.wspecifier_or_wxfilename, matrix)
elif args.out_filetype == "mat":
# Kaldi supports only matrix or vector
kaldiio.save_mat(args.wspecifier_or_wxfilename, matrix)
else:
raise RuntimeError(
"Not supporting: --out-filetype {}".format(args.out_filetype))
if __name__ == "__main__":
main()

@ -0,0 +1,104 @@
#!/usr/bin/env python3
import argparse
import logging
from distutils.util import strtobool
from deepspeech.transform.transformation import Transformation
from deepspeech.utils.cli_readers import file_reader_helper
from deepspeech.utils.cli_utils import get_commandline_args
from deepspeech.utils.cli_utils import is_scipy_wav_style
from deepspeech.utils.cli_writers import file_writer_helper
def get_parser():
parser = argparse.ArgumentParser(
description="copy feature with preprocessing",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument(
"--in-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--out-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the wspecifier. "
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--write-num-frames",
type=str,
help="Specify wspecifer for utt2num_frames")
parser.add_argument(
"--compress",
type=strtobool,
default=False,
help="Save in compressed format")
parser.add_argument(
"--compression-method",
type=int,
default=2,
help="Specify the method(if mat) or "
"gzip-level(if hdf5)", )
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing", )
parser.add_argument(
"rspecifier",
type=str,
help="Read specifier for feats. e.g. ark:some.ark")
parser.add_argument(
"wspecifier", type=str, help="Write specifier. e.g. ark:some.ark")
return parser
def main():
parser = get_parser()
args = parser.parse_args()
# logging info
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
if args.verbose > 0:
logging.basicConfig(level=logging.INFO, format=logfmt)
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
if args.preprocess_conf is not None:
preprocessing = Transformation(args.preprocess_conf)
logging.info("Apply preprocessing: {}".format(preprocessing))
else:
preprocessing = None
with file_writer_helper(
args.wspecifier,
filetype=args.out_filetype,
write_num_frames=args.write_num_frames,
compress=args.compress,
compression_method=args.compression_method, ) as writer:
for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
if is_scipy_wav_style(mat):
# If data is sound file, then got as Tuple[int, ndarray]
rate, mat = mat
if preprocessing is not None:
mat = preprocessing(mat, uttid_list=utt)
# shape = (Time, Channel)
if args.out_filetype in ["sound.hdf5", "sound"]:
# Write Tuple[int, numpy.ndarray] (scipy style)
writer[utt] = (rate, mat)
else:
writer[utt] = mat
if __name__ == "__main__":
main()

@ -0,0 +1,170 @@
#!/usr/bin/env bash
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
echo "$0 $*" >&2 # Print the command line for logging
. ./path.sh
nj=1
cmd=run.pl
nlsyms=""
lang=""
feat="" # feat.scp
oov="<unk>"
bpecode=""
allow_one_column=false
verbose=0
trans_type=char
filetype=""
preprocess_conf=""
category=""
out="" # If omitted, write in stdout
text=""
multilingual=false
help_message=$(cat << EOF
Usage: $0 <data-dir> <dict>
e.g. $0 data/train data/lang_1char/train_units.txt
Options:
--nj <nj> # number of parallel jobs
--cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs.
--feat <feat-scp> # feat.scp or feat1.scp,feat2.scp,...
--oov <oov-word> # Default: <unk>
--out <outputfile> # If omitted, write in stdout
--filetype <mat|hdf5|sound.hdf5> # Specify the format of feats file
--preprocess-conf <json> # Apply preprocess to feats when creating shape.scp
--verbose <num> # Default: 0
EOF
)
. utils/parse_options.sh
if [ $# != 2 ]; then
echo "${help_message}" 1>&2
exit 1;
fi
set -euo pipefail
dir=$1
dic=$2
tmpdir=$(mktemp -d ${dir}/tmp-XXXXX)
trap 'rm -rf ${tmpdir}' EXIT
if [ -z ${text} ]; then
text=${dir}/text
fi
# 1. Create scp files for inputs
# These are not necessary for decoding mode, and make it as an option
input=
if [ -n "${feat}" ]; then
_feat_scps=$(echo "${feat}" | tr ',' ' ' )
read -r -a feat_scps <<< $_feat_scps
num_feats=${#feat_scps[@]}
for (( i=1; i<=num_feats; i++ )); do
feat=${feat_scps[$((i-1))]}
mkdir -p ${tmpdir}/input_${i}
input+="input_${i} "
cat ${feat} > ${tmpdir}/input_${i}/feat.scp
# Dump in the "legacy" style JSON format
if [ -n "${filetype}" ]; then
awk -v filetype=${filetype} '{print $1 " " filetype}' ${feat} \
> ${tmpdir}/input_${i}/filetype.scp
fi
feat_to_shape.sh --cmd "${cmd}" --nj ${nj} \
--filetype "${filetype}" \
--preprocess-conf "${preprocess_conf}" \
--verbose ${verbose} ${feat} ${tmpdir}/input_${i}/shape.scp
done
fi
# 2. Create scp files for outputs
mkdir -p ${tmpdir}/output
if [ -n "${bpecode}" ]; then
if [ ${multilingual} = true ]; then
# remove a space before the language ID
paste -d " " <(awk '{print $1}' ${text}) <(cut -f 2- -d" " ${text} \
| spm_encode --model=${bpecode} --output_format=piece | cut -f 2- -d" ") \
> ${tmpdir}/output/token.scp
else
paste -d " " <(awk '{print $1}' ${text}) <(cut -f 2- -d" " ${text} \
| spm_encode --model=${bpecode} --output_format=piece) \
> ${tmpdir}/output/token.scp
fi
elif [ -n "${nlsyms}" ]; then
text2token.py -s 1 -n 1 -l ${nlsyms} ${text} --trans_type ${trans_type} > ${tmpdir}/output/token.scp
else
text2token.py -s 1 -n 1 ${text} --trans_type ${trans_type} > ${tmpdir}/output/token.scp
fi
< ${tmpdir}/output/token.scp utils/sym2int.pl --map-oov ${oov} -f 2- ${dic} > ${tmpdir}/output/tokenid.scp
# +2 comes from CTC blank and EOS
vocsize=$(tail -n 1 ${dic} | awk '{print $2}')
odim=$(echo "$vocsize + 2" | bc)
< ${tmpdir}/output/tokenid.scp awk -v odim=${odim} '{print $1 " " NF-1 "," odim}' > ${tmpdir}/output/shape.scp
cat ${text} > ${tmpdir}/output/text.scp
# 3. Create scp files for the others
mkdir -p ${tmpdir}/other
if [ ${multilingual} == true ]; then
awk '{
n = split($1,S,"[-]");
lang=S[n];
print $1 " " lang
}' ${text} > ${tmpdir}/other/lang.scp
elif [ -n "${lang}" ]; then
awk -v lang=${lang} '{print $1 " " lang}' ${text} > ${tmpdir}/other/lang.scp
fi
if [ -n "${category}" ]; then
awk -v category=${category} '{print $1 " " category}' ${dir}/text \
> ${tmpdir}/other/category.scp
fi
cat ${dir}/utt2spk > ${tmpdir}/other/utt2spk.scp
# 4. Merge scp files into a JSON file
opts=""
if [ -n "${feat}" ]; then
intypes="${input} output other"
else
intypes="output other"
fi
for intype in ${intypes}; do
if [ -z "$(find "${tmpdir}/${intype}" -name "*.scp")" ]; then
continue
fi
if [ ${intype} != other ]; then
opts+="--${intype%_*}-scps "
else
opts+="--scps "
fi
for x in "${tmpdir}/${intype}"/*.scp; do
k=$(basename ${x} .scp)
if [ ${k} = shape ]; then
opts+="shape:${x}:shape "
else
opts+="${k}:${x} "
fi
done
done
if ${allow_one_column}; then
opts+="--allow-one-column true "
else
opts+="--allow-one-column false "
fi
if [ -n "${out}" ]; then
opts+="-O ${out}"
fi
merge_scp2json.py --verbose ${verbose} ${opts}
rm -fr ${tmpdir}

@ -0,0 +1,95 @@
#!/usr/bin/env bash
# Copyright 2017 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
echo "$0 $*" # Print the command line for logging
. ./path.sh
cmd=run.pl
do_delta=false
nj=1
verbose=0
compress=true
write_utt2num_frames=true
filetype='mat' # mat or hdf5
help_message="Usage: $0 <scp> <cmvnark> <logdir> <dumpdir>"
. utils/parse_options.sh
scp=$1
cvmnark=$2
logdir=$3
dumpdir=$4
if [ $# != 4 ]; then
echo "${help_message}"
exit 1;
fi
set -euo pipefail
mkdir -p ${logdir}
mkdir -p ${dumpdir}
dumpdir=$(perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' ${dumpdir} ${PWD})
for n in $(seq ${nj}); do
# the next command does nothing unless $dumpdir/storage/ exists, see
# utils/create_data_link.pl for more info.
utils/create_data_link.pl ${dumpdir}/feats.${n}.ark
done
if ${write_utt2num_frames}; then
write_num_frames_opt="--write-num-frames=ark,t:$dumpdir/utt2num_frames.JOB"
else
write_num_frames_opt=
fi
# split scp file
split_scps=""
for n in $(seq ${nj}); do
split_scps="$split_scps $logdir/feats.$n.scp"
done
utils/split_scp.pl ${scp} ${split_scps} || exit 1;
# dump features
if ${do_delta}; then
${cmd} JOB=1:${nj} ${logdir}/dump_feature.JOB.log \
apply-cmvn --norm-vars=true ${cvmnark} scp:${logdir}/feats.JOB.scp ark:- \| \
add-deltas ark:- ark:- \| \
copy-feats.py --verbose ${verbose} --out-filetype ${filetype} \
--compress=${compress} --compression-method=2 ${write_num_frames_opt} \
ark:- ark,scp:${dumpdir}/feats.JOB.ark,${dumpdir}/feats.JOB.scp \
|| exit 1
else
${cmd} JOB=1:${nj} ${logdir}/dump_feature.JOB.log \
apply-cmvn --norm-vars=true ${cvmnark} scp:${logdir}/feats.JOB.scp ark:- \| \
copy-feats.py --verbose ${verbose} --out-filetype ${filetype} \
--compress=${compress} --compression-method=2 ${write_num_frames_opt} \
ark:- ark,scp:${dumpdir}/feats.JOB.ark,${dumpdir}/feats.JOB.scp \
|| exit 1
fi
# concatenate scp files
for n in $(seq ${nj}); do
cat ${dumpdir}/feats.${n}.scp || exit 1;
done > ${dumpdir}/feats.scp || exit 1
if ${write_utt2num_frames}; then
for n in $(seq ${nj}); do
cat ${dumpdir}/utt2num_frames.${n} || exit 1;
done > ${dumpdir}/utt2num_frames || exit 1
rm ${dumpdir}/utt2num_frames.* 2>/dev/null
fi
# Write the filetype, this will be used for data2json.sh
echo ${filetype} > ${dumpdir}/filetype
# remove temp scps
rm ${logdir}/feats.*.scp 2>/dev/null
if [ ${verbose} -eq 1 ]; then
echo "Succeeded dumping features for training"
fi

@ -0,0 +1,84 @@
#!/usr/bin/env python3
import argparse
import logging
import sys
from deepspeech.transform.transformation import Transformation
from deepspeech.utils.cli_readers import file_reader_helper
from deepspeech.utils.cli_utils import get_commandline_args
from deepspeech.utils.cli_utils import is_scipy_wav_style
def get_parser():
parser = argparse.ArgumentParser(
description="convert feature to its shape",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument(
"--filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi',
)
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing",
)
parser.add_argument(
"rspecifier", type=str, help="Read specifier for feats. e.g. ark:some.ark"
)
parser.add_argument(
"out",
nargs="?",
type=argparse.FileType("w"),
default=sys.stdout,
help="The output filename. " "If omitted, then output to sys.stdout",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
# logging info
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
if args.verbose > 0:
logging.basicConfig(level=logging.INFO, format=logfmt)
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
if args.preprocess_conf is not None:
preprocessing = Transformation(args.preprocess_conf)
logging.info("Apply preprocessing: {}".format(preprocessing))
else:
preprocessing = None
# There are no necessary for matrix without preprocessing,
# so change to file_reader_helper to return shape.
# This make sense only with filetype="hdf5".
for utt, mat in file_reader_helper(
args.rspecifier, args.filetype, return_shape=preprocessing is None
):
if preprocessing is not None:
if is_scipy_wav_style(mat):
# If data is sound file, then got as Tuple[int, ndarray]
rate, mat = mat
mat = preprocessing(mat, uttid_list=utt)
shape_str = ",".join(map(str, mat.shape))
else:
if len(mat) == 2 and isinstance(mat[1], tuple):
# If data is sound file, Tuple[int, Tuple[int, ...]]
rate, mat = mat
shape_str = ",".join(map(str, mat))
args.out.write("{} {}\n".format(utt, shape_str))
if __name__ == "__main__":
main()

@ -0,0 +1,72 @@
#!/usr/bin/env bash
# Begin configuration section.
nj=4
cmd=run.pl
verbose=0
filetype=""
preprocess_conf=""
# End configuration section.
help_message=$(cat << EOF
Usage: $0 [options] <input-scp> <output-scp> [<log-dir>]
e.g.: $0 data/train/feats.scp data/train/shape.scp data/train/log
Options:
--nj <nj> # number of parallel jobs
--cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs.
--filetype <mat|hdf5|sound.hdf5> # Specify the format of feats file
--preprocess-conf <json> # Apply preprocess to feats when creating shape.scp
--verbose <num> # Default: 0
EOF
)
echo "$0 $*" 1>&2 # Print the command line for logging
. parse_options.sh || exit 1;
if [ $# -lt 2 ] || [ $# -gt 3 ]; then
echo "${help_message}" 1>&2
exit 1;
fi
set -euo pipefail
scp=$1
outscp=$2
data=$(dirname ${scp})
if [ $# -eq 3 ]; then
logdir=$3
else
logdir=${data}/log
fi
mkdir -p ${logdir}
nj=$((nj<$(<"${scp}" wc -l)?nj:$(<"${scp}" wc -l)))
split_scps=""
for n in $(seq ${nj}); do
split_scps="${split_scps} ${logdir}/feats.${n}.scp"
done
utils/split_scp.pl ${scp} ${split_scps}
if [ -n "${preprocess_conf}" ]; then
preprocess_opt="--preprocess-conf ${preprocess_conf}"
else
preprocess_opt=""
fi
if [ -n "${filetype}" ]; then
filetype_opt="--filetype ${filetype}"
else
filetype_opt=""
fi
${cmd} JOB=1:${nj} ${logdir}/feat_to_shape.JOB.log \
feat-to-shape.py --verbose ${verbose} ${preprocess_opt} ${filetype_opt} \
scp:${logdir}/feats.JOB.scp ${logdir}/shape.JOB.scp
# concatenate the .scp files together.
for n in $(seq ${nj}); do
cat ${logdir}/shape.${n}.scp
done > ${outscp}
rm -f ${logdir}/feats.*.scp 2>/dev/null

@ -0,0 +1,289 @@
#!/usr/bin/env python3
# encoding: utf-8
import argparse
import codecs
import json
import logging
import sys
from distutils.util import strtobool
from io import open
from deepspeech.utils.cli_utils import get_commandline_args
PY2 = sys.version_info[0] == 2
sys.stdin = codecs.getreader("utf-8")(sys.stdin if PY2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(sys.stdout if PY2 else sys.stdout.buffer)
# Special types:
def shape(x):
"""Change str to List[int]
>>> shape('3,5')
[3, 5]
>>> shape(' [3, 5] ')
[3, 5]
"""
# x: ' [3, 5] ' -> '3, 5'
x = x.strip()
if x[0] == "[":
x = x[1:]
if x[-1] == "]":
x = x[:-1]
return list(map(int, x.split(",")))
def get_parser():
parser = argparse.ArgumentParser(
description="Given each file paths with such format as "
"<key>:<file>:<type>. type> can be omitted and the default "
'is "str". e.g. {} '
"--input-scps feat:data/feats.scp shape:data/utt2feat_shape:shape "
"--input-scps feat:data/feats2.scp shape:data/utt2feat2_shape:shape "
"--output-scps text:data/text shape:data/utt2text_shape:shape "
"--scps utt2spk:data/utt2spk".format(sys.argv[0]),
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--input-scps",
type=str,
nargs="*",
action="append",
default=[],
help="Json files for the inputs", )
parser.add_argument(
"--output-scps",
type=str,
nargs="*",
action="append",
default=[],
help="Json files for the outputs", )
parser.add_argument(
"--scps",
type=str,
nargs="+",
default=[],
help="The json files except for the input and outputs", )
parser.add_argument(
"--verbose", "-V", default=1, type=int, help="Verbose option")
parser.add_argument(
"--allow-one-column",
type=strtobool,
default=False,
help="Allow one column in input scp files. "
"In this case, the value will be empty string.", )
parser.add_argument(
"--out",
"-O",
type=str,
help="The output filename. "
"If omitted, then output to sys.stdout", )
return parser
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
args.scps = [args.scps]
# logging info
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
if args.verbose > 0:
logging.basicConfig(level=logging.INFO, format=logfmt)
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
# List[List[Tuple[str, str, Callable[[str], Any], str, str]]]
input_infos = []
output_infos = []
infos = []
for lis_list, key_scps_list in [
(input_infos, args.input_scps),
(output_infos, args.output_scps),
(infos, args.scps),
]:
for key_scps in key_scps_list:
lis = []
for key_scp in key_scps:
sps = key_scp.split(":")
if len(sps) == 2:
key, scp = sps
type_func = None
type_func_str = "none"
elif len(sps) == 3:
key, scp, type_func_str = sps
fail = False
try:
# type_func: Callable[[str], Any]
# e.g. type_func_str = "int" -> type_func = int
type_func = eval(type_func_str)
except Exception:
raise RuntimeError(
"Unknown type: {}".format(type_func_str))
if not callable(type_func):
raise RuntimeError(
"Unknown type: {}".format(type_func_str))
else:
raise RuntimeError(
"Format <key>:<filepath> "
"or <key>:<filepath>:<type> "
"e.g. feat:data/feat.scp "
"or shape:data/feat.scp:shape: {}".format(key_scp))
for item in lis:
if key == item[0]:
raise RuntimeError('The key "{}" is duplicated: {} {}'.
format(key, item[3], key_scp))
lis.append((key, scp, type_func, key_scp, type_func_str))
lis_list.append(lis)
# Open scp files
input_fscps = [[open(i[1], "r", encoding="utf-8") for i in il]
for il in input_infos]
output_fscps = [[open(i[1], "r", encoding="utf-8") for i in il]
for il in output_infos]
fscps = [[open(i[1], "r", encoding="utf-8") for i in il] for il in infos]
# Note(kamo): What is done here?
# The final goal is creating a JSON file such as.
# {
# "utts": {
# "sample_id1": {(omitted)},
# "sample_id2": {(omitted)},
# ....
# }
# }
#
# To reduce memory usage, reading the input text files for each lines
# and writing JSON elements per samples.
if args.out is None:
out = sys.stdout
else:
out = open(args.out, "w", encoding="utf-8")
out.write('{\n "utts": {\n')
nutt = 0
while True:
nutt += 1
# List[List[str]]
input_lines = [[f.readline() for f in fl] for fl in input_fscps]
output_lines = [[f.readline() for f in fl] for fl in output_fscps]
lines = [[f.readline() for f in fl] for fl in fscps]
# Get the first line
concat = sum(input_lines + output_lines + lines, [])
if len(concat) == 0:
break
first = concat[0]
# Sanity check: Must be sorted by the first column and have same keys
count = 0
for ls_list in (input_lines, output_lines, lines):
for ls in ls_list:
for line in ls:
if line == "" or first == "":
if line != first:
concat = sum(input_infos + output_infos + infos, [])
raise RuntimeError("The number of lines mismatch "
'between: "{}" and "{}"'.format(
concat[0][1],
concat[count][1]))
elif line.split()[0] != first.split()[0]:
concat = sum(input_infos + output_infos + infos, [])
raise RuntimeError(
"The keys are mismatch at {}th line "
'between "{}" and "{}":\n>>> {}\n>>> {}'.format(
nutt,
concat[0][1],
concat[count][1],
first.rstrip(),
line.rstrip(), ))
count += 1
# The end of file
if first == "":
if nutt != 1:
out.write("\n")
break
if nutt != 1:
out.write(",\n")
entry = {}
for inout, _lines, _infos in [
("input", input_lines, input_infos),
("output", output_lines, output_infos),
("other", lines, infos),
]:
lis = []
for idx, (line_list, info_list) in enumerate(
zip(_lines, _infos), 1):
if inout == "input":
d = {"name": "input{}".format(idx)}
elif inout == "output":
d = {"name": "target{}".format(idx)}
else:
d = {}
# info_list: List[Tuple[str, str, Callable]]
# line_list: List[str]
for line, info in zip(line_list, info_list):
sps = line.split(None, 1)
if len(sps) < 2:
if not args.allow_one_column:
raise RuntimeError(
"Format error {}th line in {}: "
' Expecting "<key> <value>":\n>>> {}'.format(
nutt, info[1], line))
uttid = sps[0]
value = ""
else:
uttid, value = sps
key = info[0]
type_func = info[2]
value = value.rstrip()
if type_func is not None:
try:
# type_func: Callable[[str], Any]
value = type_func(value)
except Exception:
logging.error(
'"{}" is an invalid function '
"for the {} th line in {}: \n>>> {}".format(
info[4], nutt, info[1], line))
raise
d[key] = value
lis.append(d)
if inout != "other":
entry[inout] = lis
else:
# If key == 'other'. only has the first item
entry.update(lis[0])
entry = json.dumps(
entry,
indent=4,
ensure_ascii=False,
sort_keys=True,
separators=(",", ": "))
# Add indent
indent = " " * 2
entry = ("\n" + indent).join(entry.split("\n"))
uttid = first.split()[0]
out.write(' "{}": {}'.format(uttid, entry))
out.write(" }\n}\n")
logging.info("{} entries in {}".format(nutt, out.name))

@ -0,0 +1,59 @@
#!/usr/bin/env bash
# koried, 10/29/2012
# Reduce a data set based on a list of turn-ids
help_message="usage: $0 srcdir turnlist destdir"
if [ $1 == "--help" ]; then
echo "${help_message}"
exit 0;
fi
if [ $# != 3 ]; then
echo "${help_message}"
exit 1;
fi
srcdir=$1
reclist=$2
destdir=$3
if [ ! -f ${srcdir}/utt2spk ]; then
echo "$0: no such file $srcdir/utt2spk"
exit 1;
fi
function do_filtering {
# assumes the utt2spk and spk2utt files already exist.
[ -f ${srcdir}/feats.scp ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/feats.scp >${destdir}/feats.scp
[ -f ${srcdir}/wav.scp ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/wav.scp >${destdir}/wav.scp
[ -f ${srcdir}/text ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/text >${destdir}/text
[ -f ${srcdir}/utt2num_frames ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/utt2num_frames >${destdir}/utt2num_frames
[ -f ${srcdir}/spk2gender ] && utils/filter_scp.pl ${destdir}/spk2utt <${srcdir}/spk2gender >${destdir}/spk2gender
[ -f ${srcdir}/cmvn.scp ] && utils/filter_scp.pl ${destdir}/spk2utt <${srcdir}/cmvn.scp >${destdir}/cmvn.scp
if [ -f ${srcdir}/segments ]; then
utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/segments >${destdir}/segments
awk '{print $2;}' ${destdir}/segments | sort | uniq > ${destdir}/reco # recordings.
# The next line would override the command above for wav.scp, which would be incorrect.
[ -f ${srcdir}/wav.scp ] && utils/filter_scp.pl ${destdir}/reco <${srcdir}/wav.scp >${destdir}/wav.scp
[ -f ${srcdir}/reco2file_and_channel ] && \
utils/filter_scp.pl ${destdir}/reco <${srcdir}/reco2file_and_channel >${destdir}/reco2file_and_channel
# Filter the STM file for proper sclite scoring (this will also remove the comments lines)
[ -f ${srcdir}/stm ] && utils/filter_scp.pl ${destdir}/reco < ${srcdir}/stm > ${destdir}/stm
rm ${destdir}/reco
fi
srcutts=$(wc -l < ${srcdir}/utt2spk)
destutts=$(wc -l < ${destdir}/utt2spk)
echo "Reduced #utt from $srcutts to $destutts"
}
mkdir -p ${destdir}
# filter the utt2spk based on the set of recordings
utils/filter_scp.pl ${reclist} < ${srcdir}/utt2spk > ${destdir}/utt2spk
utils/utt2spk_to_spk2utt.pl < ${destdir}/utt2spk > ${destdir}/spk2utt
do_filtering;

@ -0,0 +1,62 @@
#!/usr/bin/env bash
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
. ./path.sh
maxframes=2000
minframes=10
maxchars=200
minchars=0
nlsyms=""
no_feat=false
trans_type=char
help_message="usage: $0 olddatadir newdatadir"
. utils/parse_options.sh || exit 1;
if [ $# != 2 ]; then
echo "${help_message}"
exit 1;
fi
sdir=$1
odir=$2
mkdir -p ${odir}/tmp
if [ ${no_feat} = true ]; then
# for machine translation
cut -d' ' -f 1 ${sdir}/text > ${odir}/tmp/reclist1
else
echo "extract utterances having less than $maxframes or more than $minframes frames"
utils/data/get_utt2num_frames.sh ${sdir}
< ${sdir}/utt2num_frames awk -v maxframes="$maxframes" '{ if ($2 < maxframes) print }' \
| awk -v minframes="$minframes" '{ if ($2 > minframes) print }' \
| awk '{print $1}' > ${odir}/tmp/reclist1
fi
echo "extract utterances having less than $maxchars or more than $minchars characters"
# counting number of chars. Use (NF - 1) instead of NF to exclude the utterance ID column
if [ -z ${nlsyms} ]; then
text2token.py -s 1 -n 1 ${sdir}/text --trans_type ${trans_type} \
| awk -v maxchars="$maxchars" '{ if (NF - 1 < maxchars) print }' \
| awk -v minchars="$minchars" '{ if (NF - 1 > minchars) print }' \
| awk '{print $1}' > ${odir}/tmp/reclist2
else
text2token.py -l ${nlsyms} -s 1 -n 1 ${sdir}/text --trans_type ${trans_type} \
| awk -v maxchars="$maxchars" '{ if (NF - 1 < maxchars) print }' \
| awk -v minchars="$minchars" '{ if (NF - 1 > minchars) print }' \
| awk '{print $1}' > ${odir}/tmp/reclist2
fi
# extract common lines
comm -12 <(sort ${odir}/tmp/reclist1) <(sort ${odir}/tmp/reclist2) > ${odir}/tmp/reclist
reduce_data_dir.sh ${sdir} ${odir}/tmp/reclist ${odir}
utils/fix_data_dir.sh ${odir}
oldnum=$(wc -l ${sdir}/feats.scp | awk '{print $1}')
newnum=$(wc -l ${odir}/feats.scp | awk '{print $1}')
echo "change from $oldnum to $newnum"

@ -0,0 +1,129 @@
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import re
import sys
is_python2 = sys.version_info[0] == 2
def exist_or_not(i, match_pos):
start_pos = None
end_pos = None
for pos in match_pos:
if pos[0] <= i < pos[1]:
start_pos = pos[0]
end_pos = pos[1]
break
return start_pos, end_pos
def get_parser():
parser = argparse.ArgumentParser(
description="convert raw text to tokenized text",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--nchar",
"-n",
default=1,
type=int,
help="number of characters to split, i.e., \
aabb -> a a b b with -n 1 and aa bb with -n 2", )
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns")
parser.add_argument(
"--space", default="<space>", type=str, help="space symbol")
parser.add_argument(
"--non-lang-syms",
"-l",
default=None,
type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.", )
parser.add_argument(
"text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"--trans_type",
"-t",
type=str,
default="char",
choices=["char", "phn"],
help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
If trans_type is char,
read from SI1279.WRD file -> "bricks are an alternative"
Else if trans_type is phn,
read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
sil t er n ih sil t ih v sil" """, )
return parser
def main():
parser = get_parser()
args = parser.parse_args()
rs = []
if args.non_lang_syms is not None:
with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
nls = [x.rstrip() for x in f.readlines()]
rs = [re.compile(re.escape(x)) for x in nls]
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(sys.stdin
if is_python2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(sys.stdout
if is_python2 else sys.stdout.buffer)
line = f.readline()
n = args.nchar
while line:
x = line.split()
print(" ".join(x[:args.skip_ncols]), end=" ")
a = " ".join(x[args.skip_ncols:])
# get all matched positions
match_pos = []
for r in rs:
i = 0
while i >= 0:
m = r.search(a, i)
if m:
match_pos.append([m.start(), m.end()])
i = m.end()
else:
break
if args.trans_type == "phn":
a = a.split(" ")
else:
if len(match_pos) > 0:
chars = []
i = 0
while i < len(a):
start_pos, end_pos = exist_or_not(i, match_pos)
if start_pos is not None:
chars.append(a[start_pos:end_pos])
i = end_pos
else:
chars.append(a[i])
i += 1
a = chars
a = [a[j:j + n] for j in range(0, len(a), n)]
a_flat = []
for z in a:
a_flat.append("".join(z))
a_chars = [z.replace(" ", args.space) for z in a_flat]
if args.trans_type == "phn":
a_chars = [z.replace("sil", args.space) for z in a_chars]
print(" ".join(a_chars))
line = f.readline()
if __name__ == "__main__":
main()
Loading…
Cancel
Save