Merge pull request #929 from PaddlePaddle/join_ctc

[lm] transformer lm & kaldi data process
Hui Zhang committed via GitHub
commit e8bc9a2a08

.gitignore

@@ -24,5 +24,7 @@ tools/montreal-forced-aligner/
tools/Montreal-Forced-Aligner/
tools/sctk
tools/sctk-20159b5/
tools/kaldi
tools/OpenBLAS/
*output/

@@ -24,6 +24,7 @@ from .utils import add_results_to_json
from deepspeech.exps import dynamic_import_tester
from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.models.asr_interface import ASRInterface
from deepspeech.models.lm.transformer import TransformerLM
from deepspeech.utils.log import Log
# from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import torch_load
@@ -48,6 +49,21 @@ def load_trained_model(args):
model = exp.model
return model, char_list, exp, confs
def get_config(config_path):
stream = open(config_path, mode='r', encoding="utf-8")
config = yaml.load(stream, Loader=yaml.FullLoader)
stream.close()
return config
def load_trained_lm(args):
lm_args = get_config(args.rnnlm_conf)
# NOTE: for a compatibility with less than 0.5.0 version models
lm_model_module = getattr(lm_args, "model_module", "default")
lm_class = dynamic_import_lm(lm_model_module)
lm = lm_class(lm_args.model)
model_dict = paddle.load(args.rnnlm)
lm.set_state_dict(model_dict)
return lm
def recog_v2(args):
"""Decode with custom models that implements ScorerInterface.
@@ -78,12 +94,7 @@ def recog_v2(args):
preprocess_args={"train": False}, )
if args.rnnlm:
-lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
-# NOTE: for a compatibility with less than 0.5.0 version models
-lm_model_module = getattr(lm_args, "model_module", "default")
-lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
-lm = lm_class(len(char_list), lm_args)
-torch_load(args.rnnlm, lm)
+lm = load_trained_lm(args)
lm.eval()
else:
lm = None

@@ -21,9 +21,6 @@ from distutils.util import strtobool
import configargparse
import numpy as np
-from deepspeech.decoders.recog import recog_v2
def get_parser():
"""Get default arguments."""
parser = configargparse.ArgumentParser(
@@ -359,7 +356,7 @@ def main(args):
if args.num_encs == 1:
# Experimental API that supports custom LMs
if args.api == "v2":
from deepspeech.decoders.recog import recog_v2
recog_v2(args)
else:
raise ValueError("Only support --api v2")

@@ -318,6 +318,18 @@ class CTCPrefixScore():
r[0, 0] = xs[0]
r[0, 1] = self.logzero
else:
# Although the code does not exactly follow Algorithm 2,
# we don't have to change it because we can assume
# r_t(h)=0 for t < |h| in CTC forward computation
# (Note: we assume here that index t starts with 0).
# The purpose of this difference is to reduce the number of for-loops;
# see https://github.com/espnet/espnet/pull/3655,
# where we start to accumulate r_t(h) from t=|h|
# and iterate r_t(h) = (r_{t-1}(h) + ...) up to T-1,
# avoiding accumulating zeros for t=1~|h|-1.
# Thus, we need to set r_{|h|-1}(h) = 0,
# i.e., r[output_length-1] = logzero, for initialization.
# This is just for reducing the computation.
r[output_length - 1] = self.logzero
# prepare forward probabilities for the last label
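A toy numpy sketch of the initialization these comments describe (illustrative only; the real CTCPrefixScore tracks separate non-blank/blank probabilities and adds emission scores inside the loop):

import numpy as np

logzero = -1e10                  # log(0) stand-in, as in the class
T, hlen = 8, 3                   # input frames T, prefix length |h|
r = np.full((T, 2), logzero)     # r[t]: (ends with label, ends with blank), log domain
r[hlen - 1] = logzero            # seed r_{|h|-1}(h) = 0, i.e. logzero in log space
for t in range(hlen, T):         # start at t = |h| instead of t = 1
    r[t] = r[t - 1]              # + emission/transition terms, omitted in this sketch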

@@ -13,6 +13,7 @@
# limitations under the License.
"""Contains the data augmentation pipeline."""
import json
import os
from collections.abc import Sequence
from inspect import signature
from pprint import pformat
@@ -90,9 +91,8 @@ class AugmentationPipeline():
effect.
Params:
-augmentation_config(str): Augmentation configuration in json string.
+preprocess_conf(str): Augmentation configuration in `json file` or `json string`.
random_seed(int): Random seed.
-train(bool): whether is train mode.
Raises:
ValueError: If the augmentation json config is in incorrect format".
@@ -100,11 +100,18 @@ class AugmentationPipeline():
SPEC_TYPES = {'specaug'}
-def __init__(self, augmentation_config: str, random_seed: int=0):
+def __init__(self, preprocess_conf: str, random_seed: int=0):
self._rng = np.random.RandomState(random_seed)
self.conf = {'mode': 'sequential', 'process': []}
-if augmentation_config:
-process = json.loads(augmentation_config)
+if preprocess_conf:
+if os.path.isfile(preprocess_conf):
# json file
with open(preprocess_conf, 'r') as fin:
json_string = fin.read()
else:
# json string
json_string = preprocess_conf
process = json.loads(json_string)
self.conf['process'] += process
self._augmentors, self._rates = self._parse_pipeline_from('all')
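With this change the same pipeline accepts either a JSON string or a path to a JSON file; a minimal sketch (the `shift` entry is a plausible config for illustration, not taken from this PR):

import tempfile
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline

conf = '[{"type": "shift", "params": {"min_shift_ms": -5, "max_shift_ms": 5}, "prob": 1.0}]'
pipeline = AugmentationPipeline(preprocess_conf=conf)        # json string branch
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as fout:
    fout.write(conf)
pipeline = AugmentationPipeline(preprocess_conf=fout.name)   # os.path.isfile() branch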

@@ -105,7 +105,7 @@ class SpeechCollatorBase():
self._local_data = TarLocalData(tar2info={}, tar2object={})
self.augmentation = AugmentationPipeline(
-augmentation_config=aug_file.read(), random_seed=random_seed)
+preprocess_conf=aug_file.read(), random_seed=random_seed)
self._normalizer = FeatureNormalizer(
mean_std_filepath) if mean_std_filepath else None

@@ -17,7 +17,7 @@ import kaldiio
import numpy as np
import soundfile
-from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
+from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline as Transformation
from deepspeech.utils.log import Log
__all__ = ["LoadInputsAndTargets"]
@@ -66,8 +66,7 @@ class LoadInputsAndTargets():
raise ValueError("Only asr are allowed: mode={}".format(mode))
if preprocess_conf is not None:
-with open(preprocess_conf, 'r') as fin:
-self.preprocessing = AugmentationPipeline(fin.read())
+self.preprocessing = Transformation(preprocess_conf)
logger.warning(
"[Experimental feature] Some preprocessing will be done "
"for the mini-batch creation using {}".format(

@@ -18,7 +18,7 @@ from deepspeech.utils.dynamic_import import dynamic_import
class ASRInterface:
-"""ASR Interface for ESPnet model implementation."""
+"""ASR Interface model implementation."""
@staticmethod
def add_arguments(parser):
@@ -103,14 +103,14 @@ class ASRInterface:
@property
def attention_plot_class(self):
"""Get attention plot class."""
-from espnet.asr.asr_utils import PlotAttentionReport
+from deepspeech.training.extensions.plot import PlotAttentionReport
return PlotAttentionReport
@property
def ctc_plot_class(self):
"""Get CTC plot class."""
-from espnet.asr.asr_utils import PlotCTCReport
+from deepspeech.training.extensions.plot import PlotCTCReport
return PlotCTCReport

@@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,263 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import List
from typing import Tuple
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
from deepspeech.models.lm_interface import LMInterface
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.modules.mask import subsequent_mask
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
def __init__(
self,
n_vocab: int,
pos_enc: str=None,
embed_unit: int=128,
att_unit: int=256,
head: int=2,
unit: int=1024,
layer: int=4,
dropout_rate: float=0.5,
emb_dropout_rate: float=0.0,
att_dropout_rate: float=0.0,
tie_weights: bool=False,
**kwargs):
nn.Layer.__init__(self)
if pos_enc == "sinusoidal":
pos_enc_layer_type = "abs_pos"
elif pos_enc is None:
pos_enc_layer_type = "no_pos"
else:
raise ValueError(f"unknown pos-enc option: {pos_enc}")
self.embed = nn.Embedding(n_vocab, embed_unit)
if emb_dropout_rate == 0.0:
self.embed_drop = None
else:
self.embed_drop = nn.Dropout(emb_dropout_rate)
self.encoder = TransformerEncoder(
input_size=embed_unit,
output_size=att_unit,
attention_heads=head,
linear_units=unit,
num_blocks=layer,
dropout_rate=dropout_rate,
attention_dropout_rate=att_dropout_rate,
input_layer="linear",
pos_enc_layer_type=pos_enc_layer_type,
concat_after=False,
static_chunk_size=1,
use_dynamic_chunk=False,
use_dynamic_left_chunk=False)
self.decoder = nn.Linear(att_unit, n_vocab)
logger.info("Tie weights set to {}".format(tie_weights))
logger.info("Dropout set to {}".format(dropout_rate))
logger.info("Emb Dropout set to {}".format(emb_dropout_rate))
logger.info("Att Dropout set to {}".format(att_dropout_rate))
if tie_weights:
assert (
att_unit == embed_unit
), "Tie Weights: True need embedding and final dimensions to match"
self.decoder.weight = self.embed.weight
def _target_mask(self, ys_in_pad):
ys_mask = ys_in_pad != 0
m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0)
return ys_mask.unsqueeze(-2) & m
def forward(self, x: paddle.Tensor, t: paddle.Tensor
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute LM loss value from buffer sequences.
Args:
x (paddle.Tensor): Input ids. (batch, len)
t (paddle.Tensor): Target ids. (batch, len)
Returns:
tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: Tuple of
loss to backward (scalar),
negative log-likelihood of t: -log p(t) (scalar) and
the number of elements in x (scalar)
Notes:
The last two return values are used
in perplexity: p(t)^{-1/n} = exp(-log p(t) / n)
"""
xm = x != 0
xlen = xm.sum(axis=1)
if self.embed_drop is not None:
emb = self.embed_drop(self.embed(x))
else:
emb = self.embed(x)
h, _ = self.encoder(emb, xlen)
y = self.decoder(h)
loss = F.cross_entropy(
y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
mask = xm.to(dtype=loss.dtype)
logp = loss * mask.view(-1)
logp = logp.sum()
count = mask.sum()
return logp / count, logp, count
# beam search API (see ScorerInterface)
def score(self, y: paddle.Tensor, state: Any,
x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]:
"""Score new token.
Args:
y (paddle.Tensor): 1D paddle.int64 prefix tokens.
state: Scorer state for prefix tokens
x (paddle.Tensor): encoder feature that generates ys.
Returns:
tuple[paddle.Tensor, Any]: Tuple of
paddle.float32 scores for next token (n_vocab)
and next state for ys
"""
y = y.unsqueeze(0)
if self.embed_drop is not None:
emb = self.embed_drop(self.embed(y))
else:
emb = self.embed(y)
h, _, cache = self.encoder.forward_one_step(
emb, self._target_mask(y), cache=state)
h = self.decoder(h[:, -1])
logp = F.log_softmax(h).squeeze(0)
return logp, cache
# batch beam search API (see BatchScorerInterface)
def batch_score(self,
ys: paddle.Tensor,
states: List[Any],
xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
"""Score new token batch (required).
Args:
ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (paddle.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[paddle.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, n_vocab)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.encoder.encoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [
paddle.stack([states[b][i] for b in range(n_batch)])
for i in range(n_layers)
]
if self.embed_drop is not None:
emb = self.embed_drop(self.embed(ys))
else:
emb = self.embed(ys)
# batch decoding
h, _, states = self.encoder.forward_one_step(
emb, self._target_mask(ys), cache=batch_state)
h = self.decoder(h[:, -1])
logp = F.log_softmax(h)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)]
for b in range(n_batch)]
return logp, state_list
if __name__ == "__main__":
tlm = TransformerLM(
n_vocab=5002,
pos_enc=None,
embed_unit=128,
att_unit=512,
head=8,
unit=2048,
layer=16,
dropout_rate=0.5, )
# n_vocab: int,
# pos_enc: str=None,
# embed_unit: int=128,
# att_unit: int=256,
# head: int=2,
# unit: int=1024,
# layer: int=4,
# dropout_rate: float=0.5,
# emb_dropout_rate: float = 0.0,
# att_dropout_rate: float = 0.0,
# tie_weights: bool = False,):
paddle.set_device("cpu")
model_dict = paddle.load("transformerLM.pdparams")
tlm.set_state_dict(model_dict)
tlm.eval()
#Test the score
input2 = np.array([5])
input2 = paddle.to_tensor(input2)
state = None
output, state = tlm.score(input2, state, None)
input3 = np.array([5, 10])
input3 = paddle.to_tensor(input3)
output, state = tlm.score(input3, state, None)
input4 = np.array([5, 10, 0])
input4 = paddle.to_tensor(input4)
output, state = tlm.score(input4, state, None)
print("output", output)
"""
#Test the batch score
batch_size = 2
inp2 = np.array([[5], [10]])
inp2 = paddle.to_tensor(inp2)
output, states = tlm.batch_score(
inp2, [(None,None,0)] * batch_size)
inp3 = np.array([[100], [30]])
inp3 = paddle.to_tensor(inp3)
output, states = tlm.batch_score(
inp3, states)
print("output", output)
#print("cache", cache)
#np.save("output_pd.npy", output)
"""

@@ -0,0 +1,82 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Language model interface."""
import argparse
from deepspeech.decoders.scorers.scorer_interface import ScorerInterface
from deepspeech.utils.dynamic_import import dynamic_import
class LMInterface(ScorerInterface):
"""LM Interface model implementation."""
@staticmethod
def add_arguments(parser):
"""Add arguments to command line argument parser."""
return parser
@classmethod
def build(cls, n_vocab: int, **kwargs):
"""Initialize this class with python-level args.
Args:
n_vocab (int): The size of the vocabulary.
Returns:
LMInterface: A new instance of LMInterface.
"""
args = argparse.Namespace(**kwargs)
return cls(n_vocab, args)
def forward(self, x, t):
"""Compute LM loss value from buffer sequences.
Args:
x (paddle.Tensor): Input ids. (batch, len)
t (paddle.Tensor): Target ids. (batch, len)
Returns:
tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: Tuple of
loss to backward (scalar),
negative log-likelihood of t: -log p(t) (scalar) and
the number of elements in x (scalar)
Notes:
The last two return values are used
in perplexity: p(t)^{-1/n} = exp(-log p(t) / n)
"""
raise NotImplementedError("forward method is not implemented")
predefined_lms = {
"transformer": "deepspeech.models.lm.transformer:TransformerLM",
}
def dynamic_import_lm(module):
"""Import LM class dynamically.
Args:
module (str): module_name:class_name or alias in `predefined_lms`
Returns:
type: LM class
"""
model_class = dynamic_import(module, predefined_lms)
assert issubclass(model_class,
LMInterface), f"{module} does not implement LMInterface"
return model_class
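A short sketch of how the `transformer` alias resolves through `dynamic_import_lm` (assuming this file lives at deepspeech/models/lm_interface.py, as the imports elsewhere in the PR suggest):

from deepspeech.models.lm_interface import dynamic_import_lm

lm_class = dynamic_import_lm("transformer")  # -> deepspeech.models.lm.transformer:TransformerLM
lm = lm_class(n_vocab=5002)                  # same n_vocab as the __main__ smoke test above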

@@ -0,0 +1,75 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ST Interface module."""
from .asr_interface import ASRInterface
from deepspeech.utils.dynamic_import import dynamic_import
class STInterface(ASRInterface):
"""ST Interface model implementation.
NOTE: This class inherits from ASRInterface to enable joint translation
and recognition when performing multi-task learning with the ASR task.
"""
def translate(self,
x,
trans_args,
char_list=None,
rnnlm=None,
ensemble_models=[]):
"""Recognize x for evaluation.
:param ndarray x: input acouctic feature (B, T, D) or (T, D)
:param namespace trans_args: argment namespace contraining options
:param list char_list: list of characters
:param paddle.nn.Layer rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("translate method is not implemented")
def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
"""Beam search implementation for batch.
:param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
:param namespace trans_args: argument namespace containing options
:param list char_list: list of characters
:param paddle.nn.Layer rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("Batch decoding is not supported yet.")
predefined_st = {
"transformer": "deepspeech.models.u2_st:U2STModel",
}
def dynamic_import_st(module):
"""Import ST models dynamically.
Args:
module (str): module_name:class_name or alias in `predefined_st`
Returns:
type: ST class
"""
model_class = dynamic_import(module, predefined_st)
assert issubclass(model_class,
STInterface), f"{module} does not implement STInterface"
return model_class

@@ -0,0 +1,15 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .u2_st import U2STInferModel
from .u2_st import U2STModel

@@ -22,10 +22,52 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
-__all__ = ["PositionalEncoding", "RelPositionalEncoding"]
+__all__ = [
"PositionalEncodingInterface", "NoPositionalEncoding", "PositionalEncoding",
"RelPositionalEncoding"
]
-class PositionalEncoding(nn.Layer):
+class PositionalEncodingInterface:
def forward(self, x: paddle.Tensor,
offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute positional encoding.
Args:
x (paddle.Tensor): Input tensor (batch, time, `*`).
Returns:
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
raise NotImplementedError("forward method is not implemented")
def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
""" For getting encoding in a streaming fashion
Args:
offset (int): start offset
size (int): required size of position encoding
Returns:
paddle.Tensor: Corresponding position encoding
"""
raise NotImplementedError("position_encoding method is not implemented")
class NoPositionalEncoding(nn.Layer, PositionalEncodingInterface):
def __init__(self,
d_model: int,
dropout_rate: float,
max_len: int=5000,
reverse: bool=False):
nn.Layer.__init__(self)
def forward(self, x: paddle.Tensor,
offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
return x, None
def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
return None
class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
def __init__(self,
d_model: int,
dropout_rate: float,
@@ -40,7 +82,7 @@ class PositionalEncoding(nn.Layer):
max_len (int, optional): maximum input length. Defaults to 5000.
reverse (bool, optional): Not used. Defaults to False.
"""
-super().__init__()
+nn.Layer.__init__(self)
self.d_model = d_model
self.max_len = max_len
self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
@@ -85,7 +127,7 @@ class PositionalEncoding(nn.Layer):
offset (int): start offset
size (int): required size of position encoding
Returns:
-paddle.Tensor: Corresponding encoding
+paddle.Tensor: Corresponding position encoding
"""
assert offset + size < self.max_len
return self.dropout(self.pe[:, offset:offset + size])
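The position_encoding(offset, size) hook exists so streaming callers can fetch the slice of encodings for a chunk that starts mid-utterance; a minimal sketch using the constructor shown above:

import paddle
from deepspeech.modules.embedding import PositionalEncoding

pos_enc = PositionalEncoding(d_model=256, dropout_rate=0.0)
# encoding for 4 frames that begin at absolute position 16, shape (1, 4, 256)
pe_chunk = pos_enc.position_encoding(offset=16, size=4)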

@@ -24,6 +24,7 @@ from deepspeech.modules.activation import get_activation
from deepspeech.modules.attention import MultiHeadedAttention
from deepspeech.modules.attention import RelPositionMultiHeadedAttention
from deepspeech.modules.conformer_convolution import ConvolutionModule
from deepspeech.modules.embedding import NoPositionalEncoding
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.embedding import RelPositionalEncoding
from deepspeech.modules.encoder_layer import ConformerEncoderLayer
@@ -76,7 +77,7 @@ class BaseEncoder(nn.Layer):
input_layer (str): input layer type.
optional [linear, conv2d, conv2d6, conv2d8]
pos_enc_layer_type (str): Encoder positional encoding layer type.
-opitonal [abs_pos, scaled_abs_pos, rel_pos]
+optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
normalize_before (bool):
True: use layer_norm before each sub-block of a layer.
False: use layer_norm after each sub-block of a layer.
@@ -101,6 +102,8 @@ class BaseEncoder(nn.Layer):
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "rel_pos":
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "no_pos":
pos_enc_class = NoPositionalEncoding
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
@@ -370,6 +373,41 @@ class TransformerEncoder(BaseEncoder):
concat_after=concat_after) for _ in range(num_blocks)
])
def forward_one_step(
self,
xs: paddle.Tensor,
masks: paddle.Tensor,
cache=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Encode input frame.
Args:
xs (paddle.Tensor): (Prefix) Input tensor. (B, T, D)
masks (paddle.Tensor): Mask tensor. (B, T, T)
cache (List[paddle.Tensor]): List of cache tensors.
Returns:
paddle.Tensor: Output tensor.
paddle.Tensor: Mask tensor.
List[paddle.Tensor]: List of new cache tensors.
"""
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)
if cache is None:
cache = [None for _ in range(len(self.encoders))]
new_cache = []
for c, e in zip(cache, self.encoders):
xs, masks, _ = e(xs, masks, output_cache=c)
new_cache.append(xs)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks, new_cache
class ConformerEncoder(BaseEncoder):
"""Conformer encoder module."""

@@ -71,7 +71,7 @@ class TransformerEncoderLayer(nn.Layer):
self,
x: paddle.Tensor,
mask: paddle.Tensor,
-pos_emb: paddle.Tensor,
+pos_emb: Optional[paddle.Tensor]=None,
mask_pad: Optional[paddle.Tensor]=None,
output_cache: Optional[paddle.Tensor]=None,
cnn_cache: Optional[paddle.Tensor]=None,
@@ -82,8 +82,8 @@ class TransformerEncoderLayer(nn.Layer):
mask (paddle.Tensor): Mask tensor for the input (#batch, time).
pos_emb (paddle.Tensor): just for interface compatibility
to ConformerEncoderLayer
-mask_pad (paddle.Tensor): does not used in transformer layer,
-just for unified api with conformer.
+mask_pad (paddle.Tensor): not used here, it's for interface
+compatibility to ConformerEncoderLayer
output_cache (paddle.Tensor): Cache tensor of the output
(#batch, time2, size), time2 < time in x.
cnn_cache (paddle.Tensor): not used here, it's for interface

@@ -60,7 +60,8 @@ class LinearNoSubsampling(BaseSubsampling):
self.out = nn.Sequential(
nn.Linear(idim, odim),
nn.LayerNorm(odim, epsilon=1e-12),
-nn.Dropout(dropout_rate), )
+nn.Dropout(dropout_rate),
nn.ReLU(), )
self.right_context = 0
self.subsampling_rate = 1
@@ -83,7 +84,12 @@ class LinearNoSubsampling(BaseSubsampling):
return x, pos_emb, x_mask
-class Conv2dSubsampling4(BaseSubsampling):
+class Conv2dSubsampling(BaseSubsampling):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class Conv2dSubsampling4(Conv2dSubsampling):
"""Convolutional 2D subsampling (to 1/4 length)."""
def __init__(self,
@@ -134,7 +140,7 @@ class Conv2dSubsampling4(BaseSubsampling):
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
-class Conv2dSubsampling6(BaseSubsampling):
+class Conv2dSubsampling6(Conv2dSubsampling):
"""Convolutional 2D subsampling (to 1/6 length)."""
def __init__(self,
@@ -187,7 +193,7 @@ class Conv2dSubsampling6(BaseSubsampling):
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]
-class Conv2dSubsampling8(BaseSubsampling):
+class Conv2dSubsampling8(Conv2dSubsampling):
"""Convolutional 2D subsampling (to 1/8 length)."""
def __init__(self,

@@ -0,0 +1,418 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import numpy as np
from . import extension
class PlotAttentionReport(extension.Extension):
"""Plot attention reporter.
Args:
att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
Function of attention visualization.
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
outdir (str): Directory to save figures.
converter (espnet.asr.*_backend.asr.CustomConverter):
Function to convert data.
device (int | torch.device): Device.
reverse (bool): If True, input and output length are reversed.
ikey (str): Key to access input
(for ASR/ST ikey="input", for MT ikey="output".)
iaxis (int): Dimension to access input
(for ASR/ST iaxis=0, for MT iaxis=1.)
okey (str): Key to access output
(for ASR/ST okey="input", for MT okey="output".)
oaxis (int): Dimension to access output
(for ASR/ST oaxis=0, for MT oaxis=0.)
subsampling_factor (int): subsampling factor in encoder
"""
def __init__(
self,
att_vis_fn,
data,
outdir,
converter,
transform,
device,
reverse=False,
ikey="input",
iaxis=0,
okey="output",
oaxis=0,
subsampling_factor=1, ):
self.att_vis_fn = att_vis_fn
self.data = copy.deepcopy(data)
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
# key is utterance ID
self.outdir = outdir
self.converter = converter
self.transform = transform
self.device = device
self.reverse = reverse
self.ikey = ikey
self.iaxis = iaxis
self.okey = okey
self.oaxis = oaxis
self.factor = subsampling_factor
if not os.path.exists(self.outdir):
os.makedirs(self.outdir)
def __call__(self, trainer):
"""Plot and save image file of att_ws matrix."""
att_ws, uttid_list = self.get_attention_weights()
if isinstance(att_ws, list): # multi-encoder case
num_encs = len(att_ws) - 1
# atts
for i in range(num_encs):
for idx, att_w in enumerate(att_ws[i]):
filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
self.outdir, uttid_list[idx], i + 1, )
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
self.outdir, uttid_list[idx], i + 1, )
np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention(att_w,
filename.format(trainer))
# han
for idx, att_w in enumerate(att_ws[num_encs]):
filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
self.outdir, uttid_list[idx], )
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
self.outdir, uttid_list[idx], )
np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention(
att_w, filename.format(trainer), han_mode=True)
else:
for idx, att_w in enumerate(att_ws):
filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
uttid_list[idx], )
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
self.outdir, uttid_list[idx], )
np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention(att_w, filename.format(trainer))
def log_attentions(self, logger, step):
"""Add image files of att_ws matrix to the tensorboard."""
att_ws, uttid_list = self.get_attention_weights()
if isinstance(att_ws, list): # multi-encoder case
num_encs = len(att_ws) - 1
# atts
for i in range(num_encs):
for idx, att_w in enumerate(att_ws[i]):
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
plot = self.draw_attention_plot(att_w)
logger.add_figure(
"%s_att%d" % (uttid_list[idx], i + 1),
plot.gcf(),
step, )
# han
for idx, att_w in enumerate(att_ws[num_encs]):
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
plot = self.draw_han_plot(att_w)
logger.add_figure(
"%s_han" % (uttid_list[idx]),
plot.gcf(),
step, )
else:
for idx, att_w in enumerate(att_ws):
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
plot = self.draw_attention_plot(att_w)
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
def get_attention_weights(self):
"""Return attention weights.
Returns:
numpy.ndarray: attention weights. float. Its shape differs
depending on the backend.
* pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2)
other case => (B, Lmax, Tmax).
* chainer-> (B, Lmax, Tmax)
"""
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
batch = self.converter([return_batch], self.device)
if isinstance(batch, tuple):
att_ws = self.att_vis_fn(*batch)
else:
att_ws = self.att_vis_fn(**batch)
return att_ws, uttid_list
def trim_attention_weight(self, uttid, att_w):
"""Transform attention matrix with regard to self.reverse."""
if self.reverse:
enc_key, enc_axis = self.okey, self.oaxis
dec_key, dec_axis = self.ikey, self.iaxis
else:
enc_key, enc_axis = self.ikey, self.iaxis
dec_key, dec_axis = self.okey, self.oaxis
dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
if self.factor > 1:
enc_len //= self.factor
if len(att_w.shape) == 3:
att_w = att_w[:, :dec_len, :enc_len]
else:
att_w = att_w[:dec_len, :enc_len]
return att_w
def draw_attention_plot(self, att_w):
"""Plot the att_w matrix.
Returns:
matplotlib.pyplot: pyplot object with attention matrix image.
"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
plt.clf()
att_w = att_w.astype(np.float32)
if len(att_w.shape) == 3:
for h, aw in enumerate(att_w, 1):
plt.subplot(1, len(att_w), h)
plt.imshow(aw, aspect="auto")
plt.xlabel("Encoder Index")
plt.ylabel("Decoder Index")
else:
plt.imshow(att_w, aspect="auto")
plt.xlabel("Encoder Index")
plt.ylabel("Decoder Index")
plt.tight_layout()
return plt
def draw_han_plot(self, att_w):
"""Plot the att_w matrix for hierarchical attention.
Returns:
matplotlib.pyplot: pyplot object with attention matrix image.
"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
plt.clf()
if len(att_w.shape) == 3:
for h, aw in enumerate(att_w, 1):
legends = []
plt.subplot(1, len(att_w), h)
for i in range(aw.shape[1]):
plt.plot(aw[:, i])
legends.append("Att{}".format(i))
plt.ylim([0, 1.0])
plt.xlim([0, aw.shape[0]])
plt.grid(True)
plt.ylabel("Attention Weight")
plt.xlabel("Decoder Index")
plt.legend(legends)
else:
legends = []
for i in range(att_w.shape[1]):
plt.plot(att_w[:, i])
legends.append("Att{}".format(i))
plt.ylim([0, 1.0])
plt.xlim([0, att_w.shape[0]])
plt.grid(True)
plt.ylabel("Attention Weight")
plt.xlabel("Decoder Index")
plt.legend(legends)
plt.tight_layout()
return plt
def _plot_and_save_attention(self, att_w, filename, han_mode=False):
if han_mode:
plt = self.draw_han_plot(att_w)
else:
plt = self.draw_attention_plot(att_w)
plt.savefig(filename)
plt.close()
class PlotCTCReport(extension.Extension):
"""Plot CTC reporter.
Args:
ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
Function of CTC visualization.
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
outdir (str): Directory to save figures.
converter (espnet.asr.*_backend.asr.CustomConverter):
Function to convert data.
device (int | torch.device): Device.
reverse (bool): If True, input and output length are reversed.
ikey (str): Key to access input
(for ASR/ST ikey="input", for MT ikey="output".)
iaxis (int): Dimension to access input
(for ASR/ST iaxis=0, for MT iaxis=1.)
okey (str): Key to access output
(for ASR/ST okey="input", for MT okey="output".)
oaxis (int): Dimension to access output
(for ASR/ST oaxis=0, for MT oaxis=0.)
subsampling_factor (int): subsampling factor in encoder
"""
def __init__(
self,
ctc_vis_fn,
data,
outdir,
converter,
transform,
device,
reverse=False,
ikey="input",
iaxis=0,
okey="output",
oaxis=0,
subsampling_factor=1, ):
self.ctc_vis_fn = ctc_vis_fn
self.data = copy.deepcopy(data)
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
# key is utterance ID
self.outdir = outdir
self.converter = converter
self.transform = transform
self.device = device
self.reverse = reverse
self.ikey = ikey
self.iaxis = iaxis
self.okey = okey
self.oaxis = oaxis
self.factor = subsampling_factor
if not os.path.exists(self.outdir):
os.makedirs(self.outdir)
def __call__(self, trainer):
"""Plot and save image file of ctc prob."""
ctc_probs, uttid_list = self.get_ctc_probs()
if isinstance(ctc_probs, list): # multi-encoder case
num_encs = len(ctc_probs) - 1
for i in range(num_encs):
for idx, ctc_prob in enumerate(ctc_probs[i]):
filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
self.outdir, uttid_list[idx], i + 1, )
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
self.outdir, uttid_list[idx], i + 1, )
np.save(np_filename.format(trainer), ctc_prob)
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
else:
for idx, ctc_prob in enumerate(ctc_probs):
filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
uttid_list[idx], )
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
self.outdir, uttid_list[idx], )
np.save(np_filename.format(trainer), ctc_prob)
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
def log_ctc_probs(self, logger, step):
"""Add image files of ctc probs to the tensorboard."""
ctc_probs, uttid_list = self.get_ctc_probs()
if isinstance(ctc_probs, list): # multi-encoder case
num_encs = len(ctc_probs) - 1
for i in range(num_encs):
for idx, ctc_prob in enumerate(ctc_probs[i]):
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
plot = self.draw_ctc_plot(ctc_prob)
logger.add_figure(
"%s_ctc%d" % (uttid_list[idx], i + 1),
plot.gcf(),
step, )
else:
for idx, ctc_prob in enumerate(ctc_probs):
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
plot = self.draw_ctc_plot(ctc_prob)
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
def get_ctc_probs(self):
"""Return CTC probs.
Returns:
numpy.ndarray: CTC probs. float. Its shape differs
depending on the backend. (B, Tmax, vocab).
"""
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
batch = self.converter([return_batch], self.device)
if isinstance(batch, tuple):
probs = self.ctc_vis_fn(*batch)
else:
probs = self.ctc_vis_fn(**batch)
return probs, uttid_list
def trim_ctc_prob(self, uttid, prob):
"""Trim CTC posteriors accoding to input lengths."""
enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
if self.factor > 1:
enc_len //= self.factor
prob = prob[:enc_len]
return prob
def draw_ctc_plot(self, ctc_prob):
"""Plot the ctc_prob matrix.
Returns:
matplotlib.pyplot: pyplot object with CTC prob matrix image.
"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
ctc_prob = ctc_prob.astype(np.float32)
plt.clf()
topk_ids = np.argsort(ctc_prob, axis=1)
n_frames, vocab = ctc_prob.shape
times_probs = np.arange(n_frames)
plt.figure(figsize=(20, 8))
# NOTE: index 0 is reserved for blank
for idx in set(topk_ids.reshape(-1).tolist()):
if idx == 0:
plt.plot(
times_probs,
ctc_prob[:, 0],
":",
label="<blank>",
color="grey")
else:
plt.plot(times_probs, ctc_prob[:, idx])
plt.xlabel(u"Input [frame]", fontsize=12)
plt.ylabel("Posteriors", fontsize=12)
plt.xticks(list(range(0, int(n_frames) + 1, 10)))
plt.yticks(list(range(0, 2, 1)))
plt.tight_layout()
return plt
def _plot_and_save_ctc(self, ctc_prob, filename):
plt = self.draw_ctc_plot(ctc_prob)
plt.savefig(filename)
plt.close()

@@ -11,18 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from .interval_trigger import IntervalTrigger
-def never_fail_trigger(trainer):
-return False
-def get_trigger(trigger):
-if trigger is None:
-return never_fail_trigger
-if callable(trigger):
-return trigger
-else:
-trigger = IntervalTrigger(*trigger)
-return trigger

@@ -0,0 +1,61 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..reporter import DictSummary
from .utils import get_trigger
class CompareValueTrigger():
"""Trigger invoked when key value getting bigger or lower than before.
Args:
key (str) : Key of value.
compare_fn ((float, float) -> bool) : Function to compare the values.
trigger (tuple(int, str)) : Trigger that decides the comparison interval.
"""
def __init__(self, key, compare_fn, trigger=(1, "epoch")):
self._key = key
self._best_value = None
self._interval_trigger = get_trigger(trigger)
self._init_summary()
self._compare_fn = compare_fn
def __call__(self, trainer):
"""Get value related to the key and compare with current value."""
observation = trainer.observation
summary = self._summary
key = self._key
if key in observation:
summary.add({key: observation[key]})
if not self._interval_trigger(trainer):
return False
stats = summary.compute_mean()
value = float(stats[key]) # copy to CPU
self._init_summary()
if self._best_value is None:
# initialize best value
self._best_value = value
return False
elif self._compare_fn(self._best_value, value):
return True
else:
self._best_value = value
return False
def _init_summary(self):
self._summary = DictSummary()
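For example, a learning-rate decay extension could be gated on validation loss getting worse; a minimal sketch (the key name is hypothetical and the module path is assumed from the file layout; the trigger fires when compare_fn(best, current) returns True):

from deepspeech.training.triggers.compare_value_trigger import CompareValueTrigger

trigger = CompareValueTrigger(
    "eval/loss",
    compare_fn=lambda best, current: best < current,  # fire when current is worse
    trigger=(1, "epoch"))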

@@ -30,3 +30,12 @@ class TimeTrigger():
return True
else:
return False
def state_dict(self):
state_dict = {
"next_time": self._next_time,
}
return state_dict
def set_state_dict(self, state_dict):
self._next_time = state_dict['next_time']
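These two methods make the trigger checkpointable; a sketch of the round-trip (assuming the existing TimeTrigger(period) constructor and a module path inferred from the file layout):

from deepspeech.training.triggers.time_trigger import TimeTrigger

t1 = TimeTrigger(period=3600.0)
state = t1.state_dict()        # {"next_time": ...}
t2 = TimeTrigger(period=3600.0)
t2.set_state_dict(state)       # resumes with the same next firing time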

@@ -0,0 +1,28 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interval_trigger import IntervalTrigger
def never_fail_trigger(trainer):
return False
def get_trigger(trigger):
if trigger is None:
return never_fail_trigger
if callable(trigger):
return trigger
else:
trigger = IntervalTrigger(*trigger)
return trigger
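get_trigger normalizes the three accepted forms; a quick sketch (module path assumed from the surrounding trigger files):

from deepspeech.training.triggers.utils import get_trigger, never_fail_trigger

assert get_trigger(None) is never_fail_trigger  # disabled: never fires
t = get_trigger((1, "epoch"))                   # tuple -> IntervalTrigger(1, "epoch")
t = get_trigger(lambda trainer: True)           # callables pass through unchanged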

@@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,41 @@
import numpy as np
def delta(feat, window):
assert window > 0
delta_feat = np.zeros_like(feat)
for i in range(1, window + 1):
delta_feat[:-i] += i * feat[i:]
delta_feat[i:] += -i * feat[:-i]
delta_feat[-i:] += i * feat[-1]
delta_feat[:i] += -i * feat[0]
delta_feat /= 2 * sum(i ** 2 for i in range(1, window + 1))
return delta_feat
def add_deltas(x, window=2, order=2):
"""
Args:
x (np.ndarray): speech feat, (T, D).
Return:
np.ndarray: (T, (1+order)*D)
"""
feats = [x]
for _ in range(order):
feats.append(delta(feats[-1], window))
return np.concatenate(feats, axis=1)
class AddDeltas():
def __init__(self, window=2, order=2):
self.window = window
self.order = order
def __repr__(self):
return "{name}(window={window}, order={order}".format(
name=self.__class__.__name__, window=self.window, order=self.order
)
def __call__(self, x):
return add_deltas(x, window=self.window, order=self.order)

@@ -0,0 +1,45 @@
import numpy
class ChannelSelector():
"""Select 1ch from multi-channel signal"""
def __init__(self, train_channel="random", eval_channel=0, axis=1):
self.train_channel = train_channel
self.eval_channel = eval_channel
self.axis = axis
def __repr__(self):
return (
"{name}(train_channel={train_channel}, "
"eval_channel={eval_channel}, axis={axis})".format(
name=self.__class__.__name__,
train_channel=self.train_channel,
eval_channel=self.eval_channel,
axis=self.axis,
)
)
def __call__(self, x, train=True):
# Assuming x: [Time, Channel] by default
if x.ndim <= self.axis:
# If the dimension is insufficient, then unsqueeze
# (e.g [Time] -> [Time, 1])
ind = tuple(
slice(None) if i < x.ndim else None for i in range(self.axis + 1)
)
x = x[ind]
if train:
channel = self.train_channel
else:
channel = self.eval_channel
if channel == "random":
ch = numpy.random.randint(0, x.shape[self.axis])
else:
ch = channel
ind = tuple(slice(None) if i != self.axis else ch for i in range(x.ndim))
return x[ind]
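A sketch of the selector on a stereo signal (train mode draws a random channel, eval mode uses the fixed one):

import numpy as np

selector = ChannelSelector(train_channel="random", eval_channel=0, axis=1)
x = np.random.randn(16000, 2)      # (Time, Channel)
y_train = selector(x, train=True)  # random channel, shape (16000,)
y_eval = selector(x, train=False)  # channel 0, shape (16000,)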

@@ -0,0 +1,158 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import h5py
import kaldiio
import numpy as np
class CMVN():
"Apply Global/Spk CMVN/iverserCMVN."
def __init__(
self,
stats,
norm_means=True,
norm_vars=False,
filetype="mat",
utt2spk=None,
spk2utt=None,
reverse=False,
std_floor=1.0e-20, ):
self.stats_file = stats
self.norm_means = norm_means
self.norm_vars = norm_vars
self.reverse = reverse
if isinstance(stats, dict):
stats_dict = dict(stats)
else:
# Use for global CMVN
if filetype == "mat":
stats_dict = {None: kaldiio.load_mat(stats)}
# Use for global CMVN
elif filetype == "npy":
stats_dict = {None: np.load(stats)}
# Use for speaker CMVN
elif filetype == "ark":
self.accept_uttid = True
stats_dict = dict(kaldiio.load_ark(stats))
# Use for speaker CMVN
elif filetype == "hdf5":
self.accept_uttid = True
stats_dict = h5py.File(stats)
else:
raise ValueError("Not supporting filetype={}".format(filetype))
if utt2spk is not None:
self.utt2spk = {}
with io.open(utt2spk, "r", encoding="utf-8") as f:
for line in f:
utt, spk = line.rstrip().split(None, 1)
self.utt2spk[utt] = spk
elif spk2utt is not None:
self.utt2spk = {}
with io.open(spk2utt, "r", encoding="utf-8") as f:
for line in f:
spk, utts = line.rstrip().split(None, 1)
for utt in utts.split():
self.utt2spk[utt] = spk
else:
self.utt2spk = None
# Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1),
# and the first vector contains the sum of feats and the second is
# the sum of squares. The last value of the first, i.e. stats[0,-1],
# is the number of samples for this statistics.
self.bias = {}
self.scale = {}
for spk, stats in stats_dict.items():
assert len(stats) == 2, stats.shape
count = stats[0, -1]
# If the feature has two or more dimensions
if not (np.isscalar(count) or isinstance(count, (int, float))):
# Only the first value is used
count = count.flatten()[0]
mean = stats[0, :-1] / count
# V(x) = E(x^2) - (E(x))^2
var = stats[1, :-1] / count - mean * mean
std = np.maximum(np.sqrt(var), std_floor)
self.bias[spk] = -mean
self.scale[spk] = 1 / std
def __repr__(self):
return ("{name}(stats_file={stats_file}, "
"norm_means={norm_means}, norm_vars={norm_vars}, "
"reverse={reverse})".format(
name=self.__class__.__name__,
stats_file=self.stats_file,
norm_means=self.norm_means,
norm_vars=self.norm_vars,
reverse=self.reverse, ))
def __call__(self, x, uttid=None):
if self.utt2spk is not None:
spk = self.utt2spk[uttid]
else:
spk = uttid
if not self.reverse:
# apply cmvn
if self.norm_means:
x = np.add(x, self.bias[spk])
if self.norm_vars:
x = np.multiply(x, self.scale[spk])
else:
# apply reverse cmvn
if self.norm_vars:
x = np.divide(x, self.scale[spk])
if self.norm_means:
x = np.subtract(x, self.bias[spk])
return x
class UtteranceCMVN():
"Apply Utterance CMVN"
def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20):
self.norm_means = norm_means
self.norm_vars = norm_vars
self.std_floor = std_floor
def __repr__(self):
return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format(
name=self.__class__.__name__,
norm_means=self.norm_means,
norm_vars=self.norm_vars, )
def __call__(self, x, uttid=None):
# x: [Time, Dim]
square_sums = (x**2).sum(axis=0)
mean = x.mean(axis=0)
if self.norm_means:
x = np.subtract(x, mean)
if self.norm_vars:
var = square_sums / x.shape[0] - mean**2
std = np.maximum(np.sqrt(var), self.std_floor)
x = np.divide(x, std)
return x
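Per-utterance CMVN needs no stats file; a quick sketch using the class above:

import numpy as np

cmvn = UtteranceCMVN(norm_means=True, norm_vars=True)
x = np.random.randn(200, 40).astype(np.float32)  # toy (Time, Dim) features
y = cmvn(x)                                      # ~zero mean, ~unit variance per dim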

@@ -0,0 +1,71 @@
import inspect
from deepspeech.transform.transform_interface import TransformInterface
from deepspeech.utils.check_kwargs import check_kwargs
class FuncTrans(TransformInterface):
"""Functional Transformation
WARNING:
Builtin or C/C++ functions may not work properly
because this class heavily depends on the `inspect` module.
Usage:
>>> def foo_bar(x, a=1, b=2):
... '''Foo bar
... :param x: input
... :param int a: default 1
... :param int b: default 2
... '''
... return x + a - b
>>> class FooBar(FuncTrans):
... _func = foo_bar
... __doc__ = foo_bar.__doc__
"""
_func = None
def __init__(self, **kwargs):
self.kwargs = kwargs
check_kwargs(self.func, kwargs)
def __call__(self, x):
return self.func(x, **self.kwargs)
@classmethod
def add_arguments(cls, parser):
fname = cls._func.__name__.replace("_", "-")
group = parser.add_argument_group(fname + " transformation setting")
for k, v in cls.default_params().items():
# TODO(karita): get help and choices from docstring?
attr = k.replace("_", "-")
group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
return parser
@property
def func(self):
return type(self)._func
@classmethod
def default_params(cls):
try:
d = dict(inspect.signature(cls._func).parameters)
except ValueError:
d = dict()
return {
k: v.default for k, v in d.items() if v.default != inspect.Parameter.empty
}
def __repr__(self):
params = self.default_params()
params.update(**self.kwargs)
ret = self.__class__.__name__ + "("
if len(params) == 0:
return ret + ")"
for k, v in params.items():
ret += "{}={}, ".format(k, v)
return ret[:-2] + ")"

@@ -0,0 +1,343 @@
import librosa
import numpy
import scipy
import soundfile
from deepspeech.io.reader import SoundHDF5File
class SpeedPerturbation():
"""SpeedPerturbation
The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
and sox-speed just resamples the input,
i.e. both pitch and tempo are changed.
"Why use speed option instead of tempo -s in SoX for speed perturbation"
https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8
Warning:
This function is very slow because of resampling.
I recommend applying speed-perturb outside training using sox.
"""
def __init__(
self,
lower=0.9,
upper=1.1,
utt2ratio=None,
keep_length=True,
res_type="kaiser_best",
seed=None,
):
self.res_type = res_type
self.keep_length = keep_length
self.state = numpy.random.RandomState(seed)
if utt2ratio is not None:
self.utt2ratio = {}
# Use the scheduled ratio for each utterance
self.utt2ratio_file = utt2ratio
self.lower = None
self.upper = None
self.accept_uttid = True
with open(utt2ratio, "r") as f:
for line in f:
utt, ratio = line.rstrip().split(None, 1)
ratio = float(ratio)
self.utt2ratio[utt] = ratio
else:
self.utt2ratio = None
# The ratio is given on runtime randomly
self.lower = lower
self.upper = upper
def __repr__(self):
if self.utt2ratio is None:
return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format(
self.__class__.__name__,
self.lower,
self.upper,
self.keep_length,
self.res_type,
)
else:
return "{}({}, res_type={})".format(
self.__class__.__name__, self.utt2ratio_file, self.res_type
)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
if self.accept_uttid:
ratio = self.utt2ratio[uttid]
else:
ratio = self.state.uniform(self.lower, self.upper)
# Note1: resample requires the sampling-rate of input and output,
# but actually only the ratio is used.
y = librosa.resample(x, ratio, 1, res_type=self.res_type)
if self.keep_length:
diff = abs(len(x) - len(y))
if len(y) > len(x):
# Truncate the resampled signal
y = y[diff // 2 : -((diff + 1) // 2)]
elif len(y) < len(x):
# Assume the time-axis is the first: (Time, Channel)
pad_width = [(diff // 2, (diff + 1) // 2)] + [
(0, 0) for _ in range(y.ndim - 1)
]
y = numpy.pad(
y, pad_width=pad_width, constant_values=0, mode="constant"
)
return y
class BandpassPerturbation():
"""BandpassPerturbation
Randomly dropout along the frequency axis.
The original idea comes from the following:
"randomly-selected frequency band was cut off under the constraint of
leaving at least 1,000 Hz band within the range of less than 4,000Hz."
(The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
everyday home environments using multiple microphone arrays;
http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)
"""
def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1,)):
self.lower = lower
self.upper = upper
self.state = numpy.random.RandomState(seed)
# x_stft: (Time, Channel, Freq)
self.axes = axes
def __repr__(self):
return "{}(lower={}, upper={})".format(
self.__class__.__name__, self.lower, self.upper
)
def __call__(self, x_stft, uttid=None, train=True):
if not train:
return x_stft
if x_stft.ndim == 1:
raise RuntimeError(
"Input in time-freq domain: " "(Time, Channel, Freq) or (Time, Freq)"
)
ratio = self.state.uniform(self.lower, self.upper)
        # Normalize negative axis indices (e.g. -1 -> ndim - 1)
        axes = [i if i >= 0 else x_stft.ndim + i for i in self.axes]
shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]
mask = self.state.randn(*shape) > ratio
x_stft *= mask
return x_stft
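# Illustrative usage sketch (not part of the original module): zero out a random
# fraction of bins along the frequency axis of a (Time, Freq) STFT-domain input.
def _demo_bandpass_perturbation():
    bp = BandpassPerturbation(lower=0.0, upper=0.75, seed=0)
    x_stft = numpy.random.randn(100, 257).astype(numpy.float32)
    y = bp(x_stft, train=True)
    assert y.shape == x_stft.shape
    return y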
class VolumePerturbation():
def __init__(self, lower=-1.6, upper=1.6, utt2ratio=None, dbunit=True, seed=None):
        self.dbunit = dbunit
        self.utt2ratio_file = utt2ratio
        self.lower = lower
        self.upper = upper
        # Default to random ratios; flipped to True when utt2ratio is given
        self.accept_uttid = False
        self.state = numpy.random.RandomState(seed)
if utt2ratio is not None:
            # Use the scheduled ratio for each utterance
self.utt2ratio = {}
self.lower = None
self.upper = None
self.accept_uttid = True
with open(utt2ratio, "r") as f:
for line in f:
utt, ratio = line.rstrip().split(None, 1)
ratio = float(ratio)
self.utt2ratio[utt] = ratio
else:
# The ratio is given on runtime randomly
self.utt2ratio = None
def __repr__(self):
if self.utt2ratio is None:
return "{}(lower={}, upper={}, dbunit={})".format(
self.__class__.__name__, self.lower, self.upper, self.dbunit
)
else:
return '{}("{}", dbunit={})'.format(
self.__class__.__name__, self.utt2ratio_file, self.dbunit
)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
if self.accept_uttid:
ratio = self.utt2ratio[uttid]
else:
ratio = self.state.uniform(self.lower, self.upper)
if self.dbunit:
ratio = 10 ** (ratio / 20)
return x * ratio
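# Illustrative usage sketch (not part of the original module). With dbunit=True
# the sampled ratio is a gain in dB, mapped to a linear amplitude factor by
# 10 ** (dB / 20): +6 dB roughly doubles the amplitude, -6 dB roughly halves it.
def _demo_volume_perturbation():
    vp = VolumePerturbation(lower=-1.6, upper=1.6, dbunit=True, seed=0)
    sig = numpy.ones(8, dtype=numpy.float32)
    return vp(sig, train=True)  # uniformly scaled copy of sig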
class NoiseInjection():
"""Add isotropic noise"""
def __init__(
self,
utt2noise=None,
lower=-20,
upper=-5,
utt2ratio=None,
filetype="list",
dbunit=True,
seed=None,
):
self.utt2noise_file = utt2noise
self.utt2ratio_file = utt2ratio
self.filetype = filetype
self.dbunit = dbunit
self.lower = lower
self.upper = upper
self.state = numpy.random.RandomState(seed)
if utt2ratio is not None:
            # Use the scheduled ratio for each utterance
self.utt2ratio = {}
            with open(utt2ratio, "r") as f:
for line in f:
utt, snr = line.rstrip().split(None, 1)
snr = float(snr)
self.utt2ratio[utt] = snr
else:
# The ratio is given on runtime randomly
self.utt2ratio = None
if utt2noise is not None:
self.utt2noise = {}
if filetype == "list":
with open(utt2noise, "r") as f:
for line in f:
utt, filename = line.rstrip().split(None, 1)
signal, rate = soundfile.read(filename, dtype="int16")
# Load all files in memory
self.utt2noise[utt] = (signal, rate)
elif filetype == "sound.hdf5":
self.utt2noise = SoundHDF5File(utt2noise, "r")
else:
raise ValueError(filetype)
else:
self.utt2noise = None
if utt2noise is not None and utt2ratio is not None:
if set(self.utt2ratio) != set(self.utt2noise):
raise RuntimeError(
"The uttids mismatch between {} and {}".format(utt2ratio, utt2noise)
)
def __repr__(self):
if self.utt2ratio is None:
return "{}(lower={}, upper={}, dbunit={})".format(
self.__class__.__name__, self.lower, self.upper, self.dbunit
)
else:
return '{}("{}", dbunit={})'.format(
self.__class__.__name__, self.utt2ratio_file, self.dbunit
)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
# 1. Get ratio of noise to signal in sound pressure level
if uttid is not None and self.utt2ratio is not None:
ratio = self.utt2ratio[uttid]
else:
ratio = self.state.uniform(self.lower, self.upper)
if self.dbunit:
ratio = 10 ** (ratio / 20)
scale = ratio * numpy.sqrt((x ** 2).mean())
# 2. Get noise
if self.utt2noise is not None:
# Get noise from the external source
            if uttid is not None:
                noise, rate = self.utt2noise[uttid]
            else:
                # Randomly select a noise source by key
                key = self.state.choice(list(self.utt2noise.keys()))
                noise, rate = self.utt2noise[key]
            # Normalize the level (cast first: squaring int16 overflows and
            # in-place division on an integer array raises)
            noise = noise.astype(numpy.float64)
            noise /= numpy.sqrt((noise ** 2).mean())
            # Adjust the noise length to the signal length
            diff = abs(len(x) - len(noise))
            offset = self.state.randint(0, diff) if diff > 0 else 0
            if len(noise) > len(x):
                # Truncate noise
                noise = noise[offset:offset + len(x)]
            elif len(noise) < len(x):
                noise = numpy.pad(
                    noise, pad_width=[offset, diff - offset], mode="wrap")
else:
# Generate white noise
noise = self.state.normal(0, 1, x.shape)
# 3. Add noise to signal
return x + noise * scale
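# Illustrative usage sketch (not part of the original module). With utt2noise
# left as None, white noise is generated and scaled so that its level relative
# to the signal RMS matches a value drawn from [lower, upper] dB.
def _demo_noise_injection():
    ni = NoiseInjection(lower=-20, upper=-5, seed=0)
    sig = numpy.sin(numpy.linspace(0, 100, 16000)).astype(numpy.float32)
    return ni(sig, train=True)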
class RIRConvolve():
def __init__(self, utt2rir, filetype="list"):
self.utt2rir_file = utt2rir
self.filetype = filetype
self.utt2rir = {}
if filetype == "list":
with open(utt2rir, "r") as f:
for line in f:
utt, filename = line.rstrip().split(None, 1)
signal, rate = soundfile.read(filename, dtype="int16")
self.utt2rir[utt] = (signal, rate)
elif filetype == "sound.hdf5":
self.utt2rir = SoundHDF5File(utt2rir, "r")
else:
raise NotImplementedError(filetype)
def __repr__(self):
return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)
def __call__(self, x, uttid=None, train=True):
if not train:
return x
x = x.astype(numpy.float32)
if x.ndim != 1:
# Must be single channel
raise RuntimeError(
"Input x must be one dimensional array, but got {}".format(x.shape)
)
rir, rate = self.utt2rir[uttid]
        if rir.ndim == 2:
            # Multi-channel RIR: convolve per channel
            # return [Time, Channel]
            return numpy.stack(
                [scipy.signal.convolve(x, r, mode="same") for r in rir], axis=-1
            )
        else:
            return scipy.signal.convolve(x, rir, mode="same")

@ -0,0 +1,202 @@
"""Spec Augment module for preprocessing i.e., data augmentation"""
import random
import numpy
from PIL import Image
from PIL.Image import BICUBIC
from deepspeech.transform.functional import FuncTrans
def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
"""time warp for spec augment
move random center frame by the random width ~ uniform(-window, window)
:param numpy.ndarray x: spectrogram (time, freq)
:param int max_time_warp: maximum time frames to warp
:param bool inplace: overwrite x with the result
:param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp"
(slow, differentiable)
:returns numpy.ndarray: time warped spectrogram (time, freq)
"""
window = max_time_warp
if mode == "PIL":
t = x.shape[0]
if t - window <= window:
return x
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
center = random.randrange(window, t - window)
warped = random.randrange(center - window, center + window) + 1 # 1 ... t - 1
left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), BICUBIC)
if inplace:
x[:warped] = left
x[warped:] = right
return x
return numpy.concatenate((left, right), 0)
elif mode == "sparse_image_warp":
import paddle
from espnet.utils import spec_augment
# TODO(karita): make this differentiable again
return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
else:
raise NotImplementedError(
"unknown resize mode: "
+ mode
+ ", choose one from (PIL, sparse_image_warp)."
)
class TimeWarp(FuncTrans):
_func = time_warp
__doc__ = time_warp.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)
def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
"""freq mask for spec agument
:param numpy.ndarray x: (time, freq)
:param int n_mask: the number of masks
:param bool inplace: overwrite
:param bool replace_with_zero: pad zero on mask if true else use mean
"""
if inplace:
cloned = x
else:
cloned = x.copy()
num_mel_channels = cloned.shape[1]
fs = numpy.random.randint(0, F, size=(n_mask, 2))
for f, mask_end in fs:
f_zero = random.randrange(0, num_mel_channels - f)
mask_end += f_zero
# avoids randrange error if values are equal and range is empty
if f_zero == f_zero + f:
continue
if replace_with_zero:
cloned[:, f_zero:mask_end] = 0
else:
cloned[:, f_zero:mask_end] = cloned.mean()
return cloned
class FreqMask(FuncTrans):
_func = freq_mask
__doc__ = freq_mask.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)
def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
"""freq mask for spec agument
:param numpy.ndarray spec: (time, freq)
:param int n_mask: the number of masks
:param bool inplace: overwrite
:param bool replace_with_zero: pad zero on mask if true else use mean
"""
if inplace:
cloned = spec
else:
cloned = spec.copy()
len_spectro = cloned.shape[0]
ts = numpy.random.randint(0, T, size=(n_mask, 2))
for t, mask_end in ts:
# avoid randint range error
if len_spectro - t <= 0:
continue
t_zero = random.randrange(0, len_spectro - t)
# avoids randrange error if values are equal and range is empty
if t_zero == t_zero + t:
continue
mask_end += t_zero
if replace_with_zero:
cloned[t_zero:mask_end] = 0
else:
cloned[t_zero:mask_end] = cloned.mean()
return cloned
class TimeMask(FuncTrans):
_func = time_mask
__doc__ = time_mask.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)
def spec_augment(
x,
resize_mode="PIL",
max_time_warp=80,
max_freq_width=27,
n_freq_mask=2,
max_time_width=100,
n_time_mask=2,
inplace=True,
replace_with_zero=True,
):
"""spec agument
apply random time warping and time/freq masking
default setting is based on LD (Librispeech double) in Table 2
https://arxiv.org/pdf/1904.08779.pdf
:param numpy.ndarray x: (time, freq)
:param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp"
(slow, differentiable)
:param int max_time_warp: maximum frames to warp the center frame in spectrogram (W)
:param int freq_mask_width: maximum width of the random freq mask (F)
:param int n_freq_mask: the number of the random freq mask (m_F)
:param int time_mask_width: maximum width of the random time mask (T)
:param int n_time_mask: the number of the random time mask (m_T)
:param bool inplace: overwrite intermediate array
:param bool replace_with_zero: pad zero on mask if true else use mean
"""
assert isinstance(x, numpy.ndarray)
assert x.ndim == 2
x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
x = freq_mask(
x,
max_freq_width,
n_freq_mask,
inplace=inplace,
replace_with_zero=replace_with_zero,
)
x = time_mask(
x,
max_time_width,
n_time_mask,
inplace=inplace,
replace_with_zero=replace_with_zero,
)
return x
class SpecAugment(FuncTrans):
_func = spec_augment
__doc__ = spec_augment.__doc__
def __call__(self, x, train):
if not train:
return x
return super().__call__(x)
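# Illustrative usage sketch (not part of the original module): run the full
# pipeline (time warp + freq masks + time masks) on a dummy log-mel spectrogram.
def _demo_spec_augment():
    x = numpy.random.randn(300, 80).astype(numpy.float32)  # (time, freq)
    aug = SpecAugment(max_time_warp=5, max_freq_width=27, n_freq_mask=2,
                      max_time_width=40, n_time_mask=2)
    y = aug(x, train=True)
    assert y.shape == x.shape  # warping and masking preserve the shape
    return y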

@ -0,0 +1,307 @@
import librosa
import numpy as np
def stft(
x, n_fft, n_shift, win_length=None, window="hann", center=True, pad_mode="reflect"
):
# x: [Time, Channel]
if x.ndim == 1:
single_channel = True
# x: [Time] -> [Time, Channel]
x = x[:, None]
else:
single_channel = False
x = x.astype(np.float32)
# FIXME(kamo): librosa.stft can't use multi-channel?
# x: [Time, Channel, Freq]
x = np.stack(
[
librosa.stft(
x[:, ch],
n_fft=n_fft,
hop_length=n_shift,
win_length=win_length,
window=window,
center=center,
pad_mode=pad_mode,
).T
for ch in range(x.shape[1])
],
axis=1,
)
if single_channel:
# x: [Time, Channel, Freq] -> [Time, Freq]
x = x[:, 0]
return x
def istft(x, n_shift, win_length=None, window="hann", center=True):
# x: [Time, Channel, Freq]
if x.ndim == 2:
single_channel = True
# x: [Time, Freq] -> [Time, Channel, Freq]
x = x[:, None, :]
else:
single_channel = False
# x: [Time, Channel]
x = np.stack(
[
librosa.istft(
x[:, ch].T, # [Time, Freq] -> [Freq, Time]
hop_length=n_shift,
win_length=win_length,
window=window,
center=center,
)
for ch in range(x.shape[1])
],
axis=1,
)
if single_channel:
# x: [Time, Channel] -> [Time]
x = x[:, 0]
return x
def stft2logmelspectrogram(x_stft, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
# x_stft: (Time, Channel, Freq) or (Time, Freq)
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
# spc: (Time, Channel, Freq) or (Time, Freq)
spc = np.abs(x_stft)
# mel_basis: (Mel_freq, Freq)
mel_basis = librosa.filters.mel(fs, n_fft, n_mels, fmin, fmax)
# lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq)
lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
return lmspc
def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"):
# x: (Time, Channel) -> spc: (Time, Channel, Freq)
spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window))
return spc
def logmelspectrogram(
x,
fs,
n_mels,
n_fft,
n_shift,
win_length=None,
window="hann",
fmin=None,
fmax=None,
eps=1e-10,
pad_mode="reflect",
):
# stft: (Time, Channel, Freq) or (Time, Freq)
x_stft = stft(
x,
n_fft=n_fft,
n_shift=n_shift,
win_length=win_length,
window=window,
pad_mode=pad_mode,
)
return stft2logmelspectrogram(
x_stft, fs=fs, n_mels=n_mels, n_fft=n_fft, fmin=fmin, fmax=fmax, eps=eps
)
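# Illustrative usage sketch (not part of the original module): an 80-dim log-mel
# spectrogram from one second of 16 kHz audio (25 ms window, 10 ms shift).
def _demo_logmelspectrogram():
    x = np.random.randn(16000).astype(np.float32)
    lmspc = logmelspectrogram(
        x, fs=16000, n_mels=80, n_fft=512, n_shift=160, win_length=400)
    return lmspc  # shape: (n_frames, 80)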
class Spectrogram():
def __init__(self, n_fft, n_shift, win_length=None, window="hann"):
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
def __repr__(self):
return (
"{name}(n_fft={n_fft}, n_shift={n_shift}, "
"win_length={win_length}, window={window})".format(
name=self.__class__.__name__,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
)
)
def __call__(self, x):
return spectrogram(
x,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
)
class LogMelSpectrogram():
def __init__(
self,
fs,
n_mels,
n_fft,
n_shift,
win_length=None,
window="hann",
fmin=None,
fmax=None,
eps=1e-10,
):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.fmin = fmin
self.fmax = fmax
self.eps = eps
def __repr__(self):
return (
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"n_shift={n_shift}, win_length={win_length}, window={window}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps,
)
)
    def __call__(self, x):
        # Pass fmin/fmax/eps through; the stored values were silently ignored before
        return logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window,
            fmin=self.fmin,
            fmax=self.fmax,
            eps=self.eps,
        )
class Stft2LogMelSpectrogram():
def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
self.fmin = fmin
self.fmax = fmax
self.eps = eps
def __repr__(self):
return (
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps,
)
)
    def __call__(self, x):
        return stft2logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            fmin=self.fmin,
            fmax=self.fmax,
            eps=self.eps,
        )
class Stft():
def __init__(
self,
n_fft,
n_shift,
win_length=None,
window="hann",
center=True,
pad_mode="reflect",
):
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.center = center
self.pad_mode = pad_mode
def __repr__(self):
return (
"{name}(n_fft={n_fft}, n_shift={n_shift}, "
"win_length={win_length}, window={window},"
"center={center}, pad_mode={pad_mode})".format(
name=self.__class__.__name__,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode,
)
)
def __call__(self, x):
return stft(
x,
self.n_fft,
self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode,
)
class IStft():
def __init__(self, n_shift, win_length=None, window="hann", center=True):
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.center = center
def __repr__(self):
return (
"{name}(n_shift={n_shift}, "
"win_length={win_length}, window={window},"
"center={center})".format(
name=self.__class__.__name__,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
)
)
def __call__(self, x):
return istft(
x,
self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
)

@ -0,0 +1,20 @@
# TODO(karita): add this to all the transform impl.
class TransformInterface:
"""Transform Interface"""
def __call__(self, x):
raise NotImplementedError("__call__ method is not implemented")
@classmethod
def add_arguments(cls, parser):
return parser
def __repr__(self):
return self.__class__.__name__ + "()"
class Identity(TransformInterface):
"""Identity Function"""
def __call__(self, x):
return x

@ -0,0 +1,149 @@
"""Transformation module."""
from collections.abc import Sequence
from collections import OrderedDict
import copy
from inspect import signature
import io
import logging
import yaml
from deepspeech.utils.dynamic_import import dynamic_import
# TODO(karita): inherit TransformInterface
# TODO(karita): register cmd arguments in asr_train.py
import_alias = dict(
identity="deepspeech.transform.transform_interface:Identity",
time_warp="deepspeech.transform.spec_augment:TimeWarp",
time_mask="deepspeech.transform.spec_augment:TimeMask",
freq_mask="deepspeech.transform.spec_augment:FreqMask",
spec_augment="deepspeech.transform.spec_augment:SpecAugment",
speed_perturbation="deepspeech.transform.perturb:SpeedPerturbation",
volume_perturbation="deepspeech.transform.perturb:VolumePerturbation",
noise_injection="deepspeech.transform.perturb:NoiseInjection",
bandpass_perturbation="deepspeech.transform.perturb:BandpassPerturbation",
rir_convolve="deepspeech.transform.perturb:RIRConvolve",
delta="deepspeech.transform.add_deltas:AddDeltas",
cmvn="deepspeech.transform.cmvn:CMVN",
utterance_cmvn="deepspeech.transform.cmvn:UtteranceCMVN",
fbank="deepspeech.transform.spectrogram:LogMelSpectrogram",
spectrogram="deepspeech.transform.spectrogram:Spectrogram",
stft="deepspeech.transform.spectrogram:Stft",
istft="deepspeech.transform.spectrogram:IStft",
stft2fbank="deepspeech.transform.spectrogram:Stft2LogMelSpectrogram",
wpe="deepspeech.transform.wpe:WPE",
channel_selector="deepspeech.transform.channel_selector:ChannelSelector",
)
class Transformation():
"""Apply some functions to the mini-batch
Examples:
>>> kwargs = {"process": [{"type": "fbank",
... "n_mels": 80,
... "fs": 16000},
... {"type": "cmvn",
... "stats": "data/train/cmvn.ark",
... "norm_vars": True},
... {"type": "delta", "window": 2, "order": 2}]}
>>> transform = Transformation(kwargs)
>>> bs = 10
>>> xs = [np.random.randn(100, 80).astype(np.float32)
... for _ in range(bs)]
>>> xs = transform(xs)
"""
def __init__(self, conffile=None):
if conffile is not None:
if isinstance(conffile, dict):
self.conf = copy.deepcopy(conffile)
else:
with io.open(conffile, encoding="utf-8") as f:
self.conf = yaml.safe_load(f)
assert isinstance(self.conf, dict), type(self.conf)
else:
self.conf = {"mode": "sequential", "process": []}
self.functions = OrderedDict()
if self.conf.get("mode", "sequential") == "sequential":
for idx, process in enumerate(self.conf["process"]):
assert isinstance(process, dict), type(process)
opts = dict(process)
process_type = opts.pop("type")
class_obj = dynamic_import(process_type, import_alias)
# TODO(karita): assert issubclass(class_obj, TransformInterface)
try:
self.functions[idx] = class_obj(**opts)
except TypeError:
try:
signa = signature(class_obj)
except ValueError:
                        # Some callables, e.g. built-in functions, provide no signature
pass
else:
logging.error(
"Expected signature: {}({})".format(
class_obj.__name__, signa
)
)
raise
else:
raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"])
)
def __repr__(self):
rep = "\n" + "\n".join(
" {}: {}".format(k, v) for k, v in self.functions.items()
)
return "{}({})".format(self.__class__.__name__, rep)
def __call__(self, xs, uttid_list=None, **kwargs):
"""Return new mini-batch
:param Union[Sequence[np.ndarray], np.ndarray] xs:
:param Union[Sequence[str], str] uttid_list:
:return: batch:
:rtype: List[np.ndarray]
"""
if not isinstance(xs, Sequence):
is_batch = False
xs = [xs]
else:
is_batch = True
if isinstance(uttid_list, str):
uttid_list = [uttid_list for _ in range(len(xs))]
if self.conf.get("mode", "sequential") == "sequential":
for idx in range(len(self.conf["process"])):
func = self.functions[idx]
# TODO(karita): use TrainingTrans and UttTrans to check __call__ args
# Derive only the args which the func has
try:
param = signature(func).parameters
except ValueError:
                # Some callables, e.g. built-in functions, provide no signature
param = {}
_kwargs = {k: v for k, v in kwargs.items() if k in param}
try:
if uttid_list is not None and "uttid" in param:
xs = [func(x, u, **_kwargs) for x, u in zip(xs, uttid_list)]
else:
xs = [func(x, **_kwargs) for x in xs]
except Exception:
logging.fatal(
"Catch a exception from {}th func: {}".format(idx, func)
)
raise
else:
raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"])
)
if is_batch:
return xs
else:
return xs[0]
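# Illustrative usage sketch (not part of the original module): a pipeline that
# needs no external resource files, chaining fbank extraction and SpecAugment.
def _demo_transformation():
    import numpy as np
    conf = {"mode": "sequential",
            "process": [{"type": "fbank", "fs": 16000, "n_mels": 80,
                         "n_fft": 512, "n_shift": 160},
                        {"type": "spec_augment"}]}
    transform = Transformation(conf)
    xs = [np.random.randn(16000).astype(np.float32) for _ in range(2)]
    return transform(xs, train=True)  # list of (n_frames, 80) arrays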

@ -0,0 +1,45 @@
from nara_wpe.wpe import wpe
class WPE(object):
def __init__(
self, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode="full"
):
self.taps = taps
self.delay = delay
self.iterations = iterations
self.psd_context = psd_context
self.statistics_mode = statistics_mode
def __repr__(self):
return (
"{name}(taps={taps}, delay={delay}"
"iterations={iterations}, psd_context={psd_context}, "
"statistics_mode={statistics_mode})".format(
name=self.__class__.__name__,
taps=self.taps,
delay=self.delay,
iterations=self.iterations,
psd_context=self.psd_context,
statistics_mode=self.statistics_mode,
)
)
def __call__(self, xs):
"""Return enhanced
:param np.ndarray xs: (Time, Channel, Frequency)
:return: enhanced_xs
:rtype: np.ndarray
"""
# nara_wpe.wpe: (F, C, T)
xs = wpe(
xs.transpose((2, 1, 0)),
taps=self.taps,
delay=self.delay,
iterations=self.iterations,
psd_context=self.psd_context,
statistics_mode=self.statistics_mode,
)
return xs.transpose(2, 1, 0)

@ -0,0 +1,52 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import sys
import numpy as np
__all__ = ["label_smoothing_dist"]
# TODO(takaaki-hori): add different smoothing methods
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
"""Obtain label distribution for loss smoothing.
:param odim:
:param lsm_type:
:param blank:
:param transcript:
:return:
"""
if transcript is not None:
with open(transcript, "rb") as f:
trans_json = json.load(f)["utts"]
if lsm_type == "unigram":
assert transcript is not None, (
"transcript is required for %s label smoothing" % lsm_type)
labelcount = np.zeros(odim)
for k, v in trans_json.items():
ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
            # to avoid an error when there is no text in an utterance
if len(ids) > 0:
labelcount[ids] += 1
        labelcount[odim - 1] = len(trans_json)  # count <eos> once per utterance
labelcount[labelcount == 0] = 1 # flooring
labelcount[blank] = 0 # remove counts for blank
labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
else:
logging.error("Error: unexpected label smoothing type: %s" % lsm_type)
sys.exit()
return labeldist
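# Illustrative usage sketch (not part of the original function): build a unigram
# smoothing distribution from a tiny hand-made ESPnet-style transcript json.
def _demo_label_smoothing_dist():
    import tempfile
    utts = {"utt1": {"output": [{"tokenid": "1 2 2 3"}]},
            "utt2": {"output": [{"tokenid": "2 4"}]}}
    with tempfile.NamedTemporaryFile(
            "w", suffix=".json", delete=False) as f:
        json.dump({"utts": utts}, f)
        path = f.name
    labeldist = label_smoothing_dist(odim=6, lsm_type="unigram", transcript=path)
    return labeldist  # sums to 1.0; the blank index (0) is zeroed out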

@ -14,17 +14,17 @@
"""This module provides functions to calculate bleu score in different level. """This module provides functions to calculate bleu score in different level.
e.g. wer for word-level, cer for char-level. e.g. wer for word-level, cer for char-level.
""" """
import nltk
import numpy as np
import sacrebleu import sacrebleu
__all__ = ['bleu', 'char_bleu'] __all__ = ['bleu', 'char_bleu', "ErrorCalculator"]
def bleu(hypothesis, reference): def bleu(hypothesis, reference):
"""Calculate BLEU. BLEU compares reference text and """Calculate BLEU. BLEU compares reference text and
hypothesis text in word-level using scarebleu. hypothesis text in word-level using scarebleu.
:param reference: The reference sentences. :param reference: The reference sentences.
:type reference: list[list[str]] :type reference: list[list[str]]
:param hypothesis: The hypothesis sentence. :param hypothesis: The hypothesis sentence.
@ -39,8 +39,6 @@ def char_bleu(hypothesis, reference):
"""Calculate BLEU. BLEU compares reference text and """Calculate BLEU. BLEU compares reference text and
hypothesis text in char-level using scarebleu. hypothesis text in char-level using scarebleu.
:param reference: The reference sentences. :param reference: The reference sentences.
:type reference: list[list[str]] :type reference: list[list[str]]
:param hypothesis: The hypothesis sentence. :param hypothesis: The hypothesis sentence.
@ -52,3 +50,70 @@ def char_bleu(hypothesis, reference):
                 for ref in reference]
    return sacrebleu.corpus_bleu(hypothesis, reference)
class ErrorCalculator():
"""Calculate BLEU for ST and MT models during training.
:param y_hats: numpy array with predicted text
:param y_pads: numpy array with true (target) text
:param char_list: vocabulary list
:param sym_space: space symbol
:param sym_pad: pad symbol
    :param report_bleu: report BLEU score if True
"""
def __init__(self, char_list, sym_space, sym_pad, report_bleu=False):
"""Construct an ErrorCalculator object."""
super().__init__()
self.char_list = char_list
self.space = sym_space
self.pad = sym_pad
self.report_bleu = report_bleu
if self.space in self.char_list:
self.idx_space = self.char_list.index(self.space)
else:
self.idx_space = None
def __call__(self, ys_hat, ys_pad):
"""Calculate corpus-level BLEU score.
        :param paddle.Tensor ys_hat: prediction (batch, seqlen)
        :param paddle.Tensor ys_pad: reference (batch, seqlen)
        :return: corpus-level BLEU score in a mini-batch
        :rtype: float
"""
bleu = None
if not self.report_bleu:
return bleu
bleu = self.calculate_corpus_bleu(ys_hat, ys_pad)
return bleu
def calculate_corpus_bleu(self, ys_hat, ys_pad):
"""Calculate corpus-level BLEU score in a mini-batch.
        :param paddle.Tensor seqs_hat: prediction (batch, seqlen)
        :param paddle.Tensor seqs_true: reference (batch, seqlen)
        :return: corpus-level BLEU score
        :rtype: float
"""
seqs_hat, seqs_true = [], []
for i, y_hat in enumerate(ys_hat):
y_true = ys_pad[i]
eos_true = np.where(y_true == -1)[0]
ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
# NOTE: padding index (-1) in y_true is used to pad y_hat
# because y_hats is not padded with -1
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
seq_true = [
self.char_list[int(idx)] for idx in y_true if int(idx) != -1
]
seq_hat_text = "".join(seq_hat).replace(self.space, " ")
seq_hat_text = seq_hat_text.replace(self.pad, "")
seq_true_text = "".join(seq_true).replace(self.space, " ")
seqs_hat.append(seq_hat_text)
seqs_true.append(seq_true_text)
        bleu = nltk.translate.bleu_score.corpus_bleu(
            [[ref] for ref in seqs_true], seqs_hat)
return bleu * 100

@ -0,0 +1,20 @@
import inspect
def check_kwargs(func, kwargs, name=None):
"""check kwargs are valid for func
    If kwargs are invalid, raise TypeError just like Python's default behavior
:param function func: function to be validated
:param dict kwargs: keyword arguments for func
:param str name: name used in TypeError (default is func name)
"""
try:
params = inspect.signature(func).parameters
except ValueError:
return
if name is None:
name = func.__name__
for k in kwargs.keys():
if k not in params:
raise TypeError(f"{name}() got an unexpected keyword argument '{k}'")
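# Illustrative usage sketch (not part of the original module): check_kwargs
# fails fast with the same TypeError Python would raise on the eventual call.
def _demo_check_kwargs():
    def f(a, b=1):
        return a + b
    check_kwargs(f, {"a": 0, "b": 2})  # valid, returns silently
    try:
        check_kwargs(f, {"a": 0, "c": 3})  # 'c' is not a parameter of f
    except TypeError as e:
        return str(e)  # "f() got an unexpected keyword argument 'c'"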

@ -0,0 +1,241 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import sys
import h5py
import kaldiio
import soundfile
from deepspeech.io.reader import SoundHDF5File
def file_reader_helper(
rspecifier: str,
filetype: str="mat",
return_shape: bool=False,
segments: str=None, ):
"""Read uttid and array in kaldi style
This function might be a bit confusing as "ark" is used
for HDF5 to imitate "kaldi-rspecifier".
Args:
rspecifier: Give as "ark:feats.ark" or "scp:feats.scp"
        filetype: "mat" is kaldi-matrix, "hdf5": HDF5
return_shape: Return the shape of the matrix,
instead of the matrix. This can reduce IO cost for HDF5.
segments (str): The file format is
"<segment-id> <recording-id> <start-time> <end-time>\n"
"e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5\n"
Returns:
Generator[Tuple[str, np.ndarray], None, None]:
Examples:
Read from kaldi-matrix ark file:
>>> for u, array in file_reader_helper('ark:feats.ark', 'mat'):
... array
Read from HDF5 file:
>>> for u, array in file_reader_helper('ark:feats.h5', 'hdf5'):
... array
"""
if filetype == "mat":
return KaldiReader(
rspecifier, return_shape=return_shape, segments=segments)
elif filetype == "hdf5":
return HDF5Reader(rspecifier, return_shape=return_shape)
elif filetype == "sound.hdf5":
return SoundHDF5Reader(rspecifier, return_shape=return_shape)
elif filetype == "sound":
return SoundReader(rspecifier, return_shape=return_shape)
else:
raise NotImplementedError(f"filetype={filetype}")
class KaldiReader:
def __init__(self, rspecifier, return_shape=False, segments=None):
self.rspecifier = rspecifier
self.return_shape = return_shape
self.segments = segments
def __iter__(self):
with kaldiio.ReadHelper(
self.rspecifier, segments=self.segments) as reader:
for key, array in reader:
if self.return_shape:
array = array.shape
yield key, array
class HDF5Reader:
def __init__(self, rspecifier, return_shape=False):
if ":" not in rspecifier:
raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'.
format(self.rspecifier))
self.rspecifier = rspecifier
self.ark_or_scp, self.filepath = self.rspecifier.split(":", 1)
if self.ark_or_scp not in ["ark", "scp"]:
raise ValueError(f"Must be scp or ark: {self.ark_or_scp}")
self.return_shape = return_shape
def __iter__(self):
if self.ark_or_scp == "scp":
hdf5_dict = {}
with open(self.filepath, "r", encoding="utf-8") as f:
for line in f:
key, value = line.rstrip().split(None, 1)
if ":" not in value:
raise RuntimeError(
"scp file for hdf5 should be like: "
'"uttid filepath.h5:key": {}({})'.format(
line, self.filepath))
path, h5_key = value.split(":", 1)
hdf5_file = hdf5_dict.get(path)
if hdf5_file is None:
try:
hdf5_file = h5py.File(path, "r")
except Exception:
logging.error("Error when loading {}".format(path))
raise
hdf5_dict[path] = hdf5_file
try:
data = hdf5_file[h5_key]
except Exception:
logging.error("Error when loading {} with key={}".
format(path, h5_key))
raise
if self.return_shape:
yield key, data.shape
else:
yield key, data[()]
# Closing all files
for k in hdf5_dict:
try:
hdf5_dict[k].close()
except Exception:
pass
else:
if self.filepath == "-":
# Required h5py>=2.9
filepath = io.BytesIO(sys.stdin.buffer.read())
else:
filepath = self.filepath
with h5py.File(filepath, "r") as f:
for key in f:
if self.return_shape:
yield key, f[key].shape
else:
yield key, f[key][()]
class SoundHDF5Reader:
def __init__(self, rspecifier, return_shape=False):
if ":" not in rspecifier:
raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'.
format(rspecifier))
self.ark_or_scp, self.filepath = rspecifier.split(":", 1)
if self.ark_or_scp not in ["ark", "scp"]:
raise ValueError(f"Must be scp or ark: {self.ark_or_scp}")
self.return_shape = return_shape
def __iter__(self):
if self.ark_or_scp == "scp":
hdf5_dict = {}
with open(self.filepath, "r", encoding="utf-8") as f:
for line in f:
key, value = line.rstrip().split(None, 1)
if ":" not in value:
raise RuntimeError(
"scp file for hdf5 should be like: "
'"uttid filepath.h5:key": {}({})'.format(
line, self.filepath))
path, h5_key = value.split(":", 1)
hdf5_file = hdf5_dict.get(path)
if hdf5_file is None:
try:
hdf5_file = SoundHDF5File(path, "r")
except Exception:
logging.error("Error when loading {}".format(path))
raise
hdf5_dict[path] = hdf5_file
try:
data = hdf5_file[h5_key]
except Exception:
logging.error("Error when loading {} with key={}".
format(path, h5_key))
raise
# Change Tuple[ndarray, int] -> Tuple[int, ndarray]
# (soundfile style -> scipy style)
array, rate = data
if self.return_shape:
array = array.shape
yield key, (rate, array)
# Closing all files
for k in hdf5_dict:
try:
hdf5_dict[k].close()
except Exception:
pass
else:
if self.filepath == "-":
# Required h5py>=2.9
filepath = io.BytesIO(sys.stdin.buffer.read())
else:
filepath = self.filepath
for key, (a, r) in SoundHDF5File(filepath, "r").items():
if self.return_shape:
a = a.shape
yield key, (r, a)
class SoundReader:
def __init__(self, rspecifier, return_shape=False):
if ":" not in rspecifier:
raise ValueError('Give "rspecifier" such as "scp:some.scp: {}"'.
format(rspecifier))
self.ark_or_scp, self.filepath = rspecifier.split(":", 1)
if self.ark_or_scp != "scp":
raise ValueError('Only supporting "scp" for sound file: {}'.format(
self.ark_or_scp))
self.return_shape = return_shape
def __iter__(self):
with open(self.filepath, "r", encoding="utf-8") as f:
for line in f:
key, sound_file_path = line.rstrip().split(None, 1)
# Assume PCM16
array, rate = soundfile.read(sound_file_path, dtype="int16")
# Change Tuple[ndarray, int] -> Tuple[int, ndarray]
# (soundfile style -> scipy style)
if self.return_shape:
array = array.shape
yield key, (rate, array)
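# Illustrative usage sketch (not part of the original module): iterate waveforms
# listed in a hypothetical Kaldi-style "wav.scp" ("<uttid> <path>" per line),
# yielding scipy-style (rate, array) pairs.
def _demo_sound_reader(scp_path="wav.scp"):
    for uttid, (rate, array) in file_reader_helper("scp:" + scp_path, "sound"):
        print(uttid, rate, array.shape)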

@ -0,0 +1,70 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from collections.abc import Sequence
from distutils.util import strtobool as dist_strtobool
import numpy
def strtobool(x):
    # distutils.util.strtobool returns an integer, which is confusing;
    # wrap it so a real bool is returned.
return bool(dist_strtobool(x))
def get_commandline_args():
extra_chars = [
" ",
";",
"&",
"(",
")",
"|",
"^",
"<",
">",
"?",
"*",
"[",
"]",
"$",
"`",
'"',
"\\",
"!",
"{",
"}",
]
# Escape the extra characters for shell
argv = [
arg.replace("'", "'\\''") if all(char not in arg
for char in extra_chars) else
"'" + arg.replace("'", "'\\''") + "'" for arg in sys.argv
]
return sys.executable + " " + " ".join(argv)
def is_scipy_wav_style(value):
# If Tuple[int, numpy.ndarray] or not
return (isinstance(value, Sequence) and len(value) == 2 and
isinstance(value[0], int) and isinstance(value[1], numpy.ndarray))
def assert_scipy_wav_style(value):
assert is_scipy_wav_style(
value), "Must be Tuple[int, numpy.ndarray], but got {}".format(
type(value) if not isinstance(value, Sequence) else "{}[{}]".format(
type(value), ", ".join(str(type(v)) for v in value)))

@ -0,0 +1,293 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Dict
import h5py
import kaldiio
import numpy
import soundfile
from deepspeech.io.reader import SoundHDF5File
from deepspeech.utils.cli_utils import assert_scipy_wav_style
def file_writer_helper(
wspecifier: str,
filetype: str="mat",
write_num_frames: str=None,
compress: bool=False,
compression_method: int=2,
pcm_format: str="wav", ):
"""Write matrices in kaldi style
Args:
wspecifier: e.g. ark,scp:out.ark,out.scp
        filetype: "mat" is kaldi-matrix, "hdf5": HDF5
write_num_frames: e.g. 'ark,t:num_frames.txt'
compress: Compress or not
compression_method: Specify compression level
Write in kaldi-matrix-ark with "kaldi-scp" file:
>>> with file_writer_helper('ark,scp:out.ark,out.scp') as f:
>>> f['uttid'] = array
This "scp" has the following format:
uttidA out.ark:1234
uttidB out.ark:2222
    where 1234 and 2222 point to the starting byte addresses of the matrices.
    (For details, see the official Kaldi documentation.)
Write in HDF5 with "scp" file:
>>> with file_writer_helper('ark,scp:out.h5,out.scp', 'hdf5') as f:
>>> f['uttid'] = array
This "scp" file is created as:
uttidA out.h5:uttidA
uttidB out.h5:uttidB
    Unlike "kaldi-ark", HDF5 supports random access by key,
    so an "scp" file is not strictly required for random reading.
    Nevertheless we create an "scp" for HDF5 because it is useful
    for some use cases, e.g. concatenation and splitting.
"""
if filetype == "mat":
return KaldiWriter(
wspecifier,
write_num_frames=write_num_frames,
compress=compress,
compression_method=compression_method, )
elif filetype == "hdf5":
return HDF5Writer(
wspecifier, write_num_frames=write_num_frames, compress=compress)
elif filetype == "sound.hdf5":
return SoundHDF5Writer(
wspecifier,
write_num_frames=write_num_frames,
pcm_format=pcm_format)
elif filetype == "sound":
return SoundWriter(
wspecifier,
write_num_frames=write_num_frames,
pcm_format=pcm_format)
else:
raise NotImplementedError(f"filetype={filetype}")
class BaseWriter:
def __setitem__(self, key, value):
raise NotImplementedError
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def close(self):
try:
self.writer.close()
except Exception:
pass
if self.writer_scp is not None:
try:
self.writer_scp.close()
except Exception:
pass
if self.writer_nframe is not None:
try:
self.writer_nframe.close()
except Exception:
pass
def get_num_frames_writer(write_num_frames: str):
"""get_num_frames_writer
Examples:
>>> get_num_frames_writer('ark,t:num_frames.txt')
"""
if write_num_frames is not None:
if ":" not in write_num_frames:
raise ValueError('Must include ":", write_num_frames={}'.format(
write_num_frames))
nframes_type, nframes_file = write_num_frames.split(":", 1)
if nframes_type != "ark,t":
raise ValueError("Only supporting text mode. "
"e.g. --write-num-frames=ark,t:foo.txt :"
"{}".format(nframes_type))
return open(nframes_file, "w", encoding="utf-8")
class KaldiWriter(BaseWriter):
def __init__(self,
wspecifier,
write_num_frames=None,
compress=False,
compression_method=2):
if compress:
self.writer = kaldiio.WriteHelper(
wspecifier, compression_method=compression_method)
else:
self.writer = kaldiio.WriteHelper(wspecifier)
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
self.writer[key] = value
if self.writer_nframe is not None:
self.writer_nframe.write(f"{key} {len(value)}\n")
def parse_wspecifier(wspecifier: str) -> Dict[str, str]:
"""Parse wspecifier to dict
Examples:
>>> parse_wspecifier('ark,scp:out.ark,out.scp')
{'ark': 'out.ark', 'scp': 'out.scp'}
"""
ark_scp, filepath = wspecifier.split(":", 1)
if ark_scp not in ["ark", "scp,ark", "ark,scp"]:
raise ValueError("{} is not allowed: {}".format(ark_scp, wspecifier))
ark_scps = ark_scp.split(",")
filepaths = filepath.split(",")
if len(ark_scps) != len(filepaths):
raise ValueError("Mismatch: {} and {}".format(ark_scp, filepath))
spec_dict = dict(zip(ark_scps, filepaths))
return spec_dict
class HDF5Writer(BaseWriter):
"""HDF5Writer
Examples:
>>> with HDF5Writer('ark:out.h5', compress=True) as f:
... f['key'] = array
"""
def __init__(self, wspecifier, write_num_frames=None, compress=False):
spec_dict = parse_wspecifier(wspecifier)
self.filename = spec_dict["ark"]
if compress:
self.kwargs = {"compression": "gzip"}
else:
self.kwargs = {}
self.writer = h5py.File(spec_dict["ark"], "w")
if "scp" in spec_dict:
self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8")
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
self.writer.create_dataset(key, data=value, **self.kwargs)
if self.writer_scp is not None:
self.writer_scp.write(f"{key} {self.filename}:{key}\n")
if self.writer_nframe is not None:
self.writer_nframe.write(f"{key} {len(value)}\n")
class SoundHDF5Writer(BaseWriter):
"""SoundHDF5Writer
Examples:
>>> fs = 16000
>>> with SoundHDF5Writer('ark:out.h5') as f:
... f['key'] = fs, array
"""
def __init__(self, wspecifier, write_num_frames=None, pcm_format="wav"):
self.pcm_format = pcm_format
spec_dict = parse_wspecifier(wspecifier)
self.filename = spec_dict["ark"]
self.writer = SoundHDF5File(
spec_dict["ark"], "w", format=self.pcm_format)
if "scp" in spec_dict:
self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8")
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
assert_scipy_wav_style(value)
# Change Tuple[int, ndarray] -> Tuple[ndarray, int]
# (scipy style -> soundfile style)
value = (value[1], value[0])
self.writer.create_dataset(key, data=value)
if self.writer_scp is not None:
self.writer_scp.write(f"{key} {self.filename}:{key}\n")
if self.writer_nframe is not None:
self.writer_nframe.write(f"{key} {len(value[0])}\n")
class SoundWriter(BaseWriter):
"""SoundWriter
Examples:
>>> fs = 16000
>>> with SoundWriter('ark,scp:outdir,out.scp') as f:
... f['key'] = fs, array
"""
def __init__(self, wspecifier, write_num_frames=None, pcm_format="wav"):
self.pcm_format = pcm_format
spec_dict = parse_wspecifier(wspecifier)
# e.g. ark,scp:dirname,wav.scp
# -> The wave files are found in dirname/*.wav
self.dirname = spec_dict["ark"]
Path(self.dirname).mkdir(parents=True, exist_ok=True)
self.writer = None
if "scp" in spec_dict:
self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8")
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
assert_scipy_wav_style(value)
rate, signal = value
wavfile = Path(self.dirname) / (key + "." + self.pcm_format)
soundfile.write(wavfile, signal.astype(numpy.int16), rate)
if self.writer_scp is not None:
self.writer_scp.write(f"{key} {wavfile}\n")
if self.writer_nframe is not None:
self.writer_nframe.write(f"{key} {len(signal)}\n")
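# Illustrative usage sketch (not part of the original module): dump features to
# an HDF5 "ark" plus a companion scp that the reader side can consume. The file
# names out.h5/out.scp are placeholders.
def _demo_hdf5_writer():
    import numpy as np
    feats = np.random.randn(100, 80).astype(np.float32)
    with file_writer_helper("ark,scp:out.h5,out.scp", filetype="hdf5") as f:
        f["uttid1"] = feats  # out.scp records "uttid1 out.h5:uttid1"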

@ -14,12 +14,12 @@
"""This module provides functions to calculate error rate in different level. """This module provides functions to calculate error rate in different level.
e.g. wer for word-level, cer for char-level. e.g. wer for word-level, cer for char-level.
""" """
from itertools import groupby
import editdistance import editdistance
import numpy as np import numpy as np
__all__ = ['word_errors', 'char_errors', 'wer', 'cer'] __all__ = ['word_errors', 'char_errors', 'wer', 'cer', "ErrorCalculator"]
editdistance.eval("a", "b")
def _levenshtein_distance(ref, hyp): def _levenshtein_distance(ref, hyp):
@ -211,3 +211,154 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False):
    cer = float(edit_distance) / ref_len
    return cer
class ErrorCalculator():
"""Calculate CER and WER for E2E_ASR and CTC models during training.
:param y_hats: numpy array with predicted text
:param y_pads: numpy array with true (target) text
:param char_list: List[str]
:param sym_space: <space>
:param sym_blank: <blank>
:return:
"""
def __init__(self,
char_list,
sym_space,
sym_blank,
report_cer=False,
report_wer=False):
"""Construct an ErrorCalculator object."""
super().__init__()
self.report_cer = report_cer
self.report_wer = report_wer
self.char_list = char_list
self.space = sym_space
self.blank = sym_blank
self.idx_blank = self.char_list.index(self.blank)
if self.space in self.char_list:
self.idx_space = self.char_list.index(self.space)
else:
self.idx_space = None
def __call__(self, ys_hat, ys_pad, is_ctc=False):
"""Calculate sentence-level WER/CER score.
:param paddle.Tensor ys_hat: prediction (batch, seqlen)
:param paddle.Tensor ys_pad: reference (batch, seqlen)
:param bool is_ctc: calculate CER score for CTC
        :return: sentence-level CER score
        :rtype: float
        :return: sentence-level WER score
        :rtype: float
"""
cer, wer = None, None
if is_ctc:
return self.calculate_cer_ctc(ys_hat, ys_pad)
elif not self.report_cer and not self.report_wer:
return cer, wer
seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
if self.report_cer:
cer = self.calculate_cer(seqs_hat, seqs_true)
if self.report_wer:
wer = self.calculate_wer(seqs_hat, seqs_true)
return cer, wer
def calculate_cer_ctc(self, ys_hat, ys_pad):
"""Calculate sentence-level CER score for CTC.
:param paddle.Tensor ys_hat: prediction (batch, seqlen)
:param paddle.Tensor ys_pad: reference (batch, seqlen)
:return: average sentence-level CER score
        :rtype: float
"""
cers, char_ref_lens = [], []
for i, y in enumerate(ys_hat):
y_hat = [x[0] for x in groupby(y)]
y_true = ys_pad[i]
seq_hat, seq_true = [], []
for idx in y_hat:
idx = int(idx)
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
seq_hat.append(self.char_list[int(idx)])
for idx in y_true:
idx = int(idx)
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
seq_true.append(self.char_list[int(idx)])
hyp_chars = "".join(seq_hat)
ref_chars = "".join(seq_true)
if len(ref_chars) > 0:
cers.append(editdistance.eval(hyp_chars, ref_chars))
char_ref_lens.append(len(ref_chars))
cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
return cer_ctc
def convert_to_char(self, ys_hat, ys_pad):
"""Convert index to character.
:param paddle.Tensor seqs_hat: prediction (batch, seqlen)
:param paddle.Tensor seqs_true: reference (batch, seqlen)
        :return: token list of prediction
        :rtype: list
        :return: token list of reference
        :rtype: list
"""
seqs_hat, seqs_true = [], []
for i, y_hat in enumerate(ys_hat):
y_true = ys_pad[i]
eos_true = np.where(y_true == -1)[0]
ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
# NOTE: padding index (-1) in y_true is used to pad y_hat
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
seq_true = [
self.char_list[int(idx)] for idx in y_true if int(idx) != -1
]
seq_hat_text = "".join(seq_hat).replace(self.space, " ")
seq_hat_text = seq_hat_text.replace(self.blank, "")
seq_true_text = "".join(seq_true).replace(self.space, " ")
seqs_hat.append(seq_hat_text)
seqs_true.append(seq_true_text)
return seqs_hat, seqs_true
def calculate_cer(self, seqs_hat, seqs_true):
"""Calculate sentence-level CER score.
:param list seqs_hat: prediction
:param list seqs_true: reference
:return: average sentence-level CER score
        :rtype: float
"""
char_eds, char_ref_lens = [], []
for i, seq_hat_text in enumerate(seqs_hat):
seq_true_text = seqs_true[i]
hyp_chars = seq_hat_text.replace(" ", "")
ref_chars = seq_true_text.replace(" ", "")
char_eds.append(editdistance.eval(hyp_chars, ref_chars))
char_ref_lens.append(len(ref_chars))
return float(sum(char_eds)) / sum(char_ref_lens)
def calculate_wer(self, seqs_hat, seqs_true):
"""Calculate sentence-level WER score.
:param list seqs_hat: prediction
:param list seqs_true: reference
:return: average sentence-level WER score
        :rtype: float
"""
word_eds, word_ref_lens = [], []
for i, seq_hat_text in enumerate(seqs_hat):
seq_true_text = seqs_true[i]
hyp_words = seq_hat_text.split()
ref_words = seq_true_text.split()
word_eds.append(editdistance.eval(hyp_words, ref_words))
word_ref_lens.append(len(ref_words))
return float(sum(word_eds)) / sum(word_ref_lens)
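# Illustrative usage sketch (not part of the original module): score a single
# hypothesis against a -1-padded reference using a toy character vocabulary.
def _demo_error_calculator():
    char_list = ["<blank>", "<space>", "a", "b", "c"]
    calc = ErrorCalculator(char_list, sym_space="<space>", sym_blank="<blank>",
                           report_cer=True, report_wer=True)
    ys_hat = np.array([[2, 3, 1, 4]])  # "ab c"
    ys_pad = np.array([[2, 2, 1, 4]])  # "aa c"
    cer, wer = calc(ys_hat, ys_pad)    # cer = 1/3, wer = 1/2
    return cer, wer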

@ -0,0 +1,4 @@
dump
fbank
exp
data

@ -1,8 +1,11 @@
# LibriSpeech

## Transformer

| Model | Params | GPUS | Averaged Model | Config | Augmentation | Loss |
| --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | 8 Tesla V100-SXM2-32GB | 10-best val_loss | conf/transformer.yaml | spec_aug | 6.3197922706604 |
| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err |
@ -11,4 +14,14 @@
| test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 |
| test-clean | ctc_prefix_beamsearch | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 |
| test-clean | attention_rescore | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 |
### JoinCTC
| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| test-clean | join_ctc_only_att | 2620 | 52576 | 96.1 | 2.5 | 1.4 | 0.4 | 4.4 | 34.7 |
| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 |
| test-clean | join_ctc_w_lm | 2620 | 52576 | 97.9 | 1.8 | 0.2 | 0.3 | 2.4 | 27.8 |
Compared with [ESPNET](https://github.com/espnet/espnet/blob/master/egs/librispeech/asr1/RESULTS.md#pytorch-large-transformer-with-specaug-4-gpus--transformer-lm-4-gpus),
we use 8 GPUs, but our model (aheads4-adim256) is smaller than theirs.

@ -1,122 +0,0 @@
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.5
max_input_len: 20.0
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 16
raw_wav: True # use raw_wav or kaldi feature
spectrum_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: True
use_cnn_module: True
cnn_module_kernel: 15
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
causal: True
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 240
accum_grad: 8
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 128
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: true # simulate streaming inference. Defaults to False.

@ -1,115 +0,0 @@
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
spectrum_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: transformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: true
use_dynamic_chunk: true
use_dynamic_left_chunk: false
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 120
accum_grad: 1
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 64
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: true # simulate streaming inference. Defaults to False.

@ -1,118 +0,0 @@
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
min_input_len: 0.5 # seconds
max_input_len: 20.0 # seconds
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 16
raw_wav: True # use raw_wav or kaldi feature
spectrum_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 or conv2d8
normalize_before: True
use_cnn_module: True
cnn_module_kernel: 15
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 120
accum_grad: 8
global_grad_clip: 3.0
optim: adam
optim_conf:
lr: 0.004
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 64
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.

@ -0,0 +1,2 @@
--sample-frequency=16000
--num-mel-bins=80

@ -0,0 +1,13 @@
model_module: transformer
model:
n_vocab: 5002
pos_enc: null
embed_unit: 128
att_unit: 512
head: 8
unit: 2048
layer: 16
dropout_rate: 0.5
emb_dropout_rate: 0.0
att_dropout_rate: 0.0
tie_weights: False

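The transformer LM config above can be sanity-checked with plain PyYAML; a minimal sketch, assuming the file lands at conf/lm/transformer.yaml as referenced by the recog script below:

import yaml

with open("conf/lm/transformer.yaml", encoding="utf-8") as f:
    lm_cfg = yaml.safe_load(f)

assert lm_cfg["model_module"] == "transformer"
# 5002 = 5000 BPE units plus two reserved symbols (CTC blank and eos),
# matching the "+2 comes from CTC blank and EOS" note in data2json.sh below.
assert lm_cfg["model"]["n_vocab"] == 5002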
@ -0,0 +1 @@
--sample-frequency=16000

@ -5,9 +5,9 @@ data:
test_manifest: data/manifest.test-clean test_manifest: data/manifest.test-clean
collator: collator:
vocab_filepath: data/bpe_unigram_5000_units.txt vocab_filepath: data/lang_char/train_960_unigram5000_units.txt
unit_type: spm unit_type: spm
spm_model_prefix: data/bpe_unigram_5000 spm_model_prefix: data/lang_char/train_960_unigram5000
feat_dim: 83 feat_dim: 83
stride_ms: 10.0 stride_ms: 10.0
window_ms: 25.0 window_ms: 25.0

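Note that the 83-dimensional features here are the 80 mel bins from fbank.conf plus the 3 pitch features appended by make_fbank_pitch.sh. At the 10 ms frame stride above, a 15.7 s utterance therefore yields a (1570, 83) feature matrix, which is exactly the shape the decoder logs (see the RTF script later in this diff).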
@ -2,19 +2,42 @@
stage=-1 stage=-1
stop_stage=100 stop_stage=100
nj=32
debugmode=1
dumpdir=dump # directory to dump full features
N=0 # number of minibatches to be used (mainly for debugging). "0" uses all minibatches.
verbose=0 # verbose option
resume= # Resume the training from snapshot
# feature configuration
do_delta=false
# Set this to somewhere where you want to put your data, or where
# someone else has already put it. You'll want to change this
# if you're not on the CLSP grid.
datadir=${MAIN_ROOT}/examples/dataset/
# bpemode (unigram or bpe) # bpemode (unigram or bpe)
nbpe=5000 nbpe=5000
bpemode=unigram bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"
source ${MAIN_ROOT}/utils/parse_options.sh source ${MAIN_ROOT}/utils/parse_options.sh
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail
train_set=train_960
train_sp=train_sp
train_dev=dev
recog_set="test_clean test_other dev_clean dev_other"
mkdir -p data mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR} mkdir -p ${TARGET_DIR}
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests # download data, generate manifests
python3 ${TARGET_DIR}/librispeech/librispeech.py \ python3 ${TARGET_DIR}/librispeech/librispeech.py \
@ -46,63 +69,98 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
fi fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# compute mean and stddev for normalizer ### Task dependent. You have to make the following data preparation part by yourself.
num_workers=$(nproc) ### But you can utilize Kaldi recipes in most cases
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ echo "stage 0: Data preparation"
--manifest_path="data/manifest.train.raw" \ for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
--num_samples=-1 \ # use underscore-separated names in data directories.
--spectrum_type="fbank" \ local/data_prep.sh ${datadir}/librispeech/${part}/LibriSpeech/${part} data/${part//-/_}
--feat_dim=80 \ done
--delta_delta=false \
--sample_rate=16000 \
--stride_ms=10.0 \
--window_ms=25.0 \
--use_dB_normalization=False \
--num_workers=${num_workers} \
--output_path="data/mean_std.json"
if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated."
exit 1
fi
fi fi
feat_tr_dir=${dumpdir}/${train_set}/delta${do_delta}; mkdir -p ${feat_tr_dir}
feat_sp_dir=${dumpdir}/${train_sp}/delta${do_delta}; mkdir -p ${feat_sp_dir}
feat_dt_dir=${dumpdir}/${train_dev}/delta${do_delta}; mkdir -p ${feat_dt_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# build vocabulary ### Task dependent. You have to design training and dev sets by yourself.
python3 ${MAIN_ROOT}/utils/build_vocab.py \ ### But you can utilize Kaldi recipes in most cases
--unit_type "spm" \ echo "stage 1: Feature Generation"
--spm_vocab_size=${nbpe} \ fbankdir=fbank
--spm_mode ${bpemode} \ # Generate the fbank features; by default 80-dimensional fbanks with pitch on each frame
--spm_model_prefix ${bpeprefix} \ for x in dev_clean test_clean dev_other test_other train_clean_100 train_clean_360 train_other_500; do
--vocab_path="data/vocab.txt" \ steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj ${nj} --write_utt2num_frames true \
--manifest_paths="data/manifest.train.raw" data/${x} exp/make_fbank/${x} ${fbankdir}
utils/fix_data_dir.sh data/${x}
done
if [ $? -ne 0 ]; then utils/combine_data.sh --extra_files utt2num_frames data/${train_set}_org data/train_clean_100 data/train_clean_360 data/train_other_500
echo "Build vocabulary failed. Terminated." utils/combine_data.sh --extra_files utt2num_frames data/${train_dev}_org data/dev_clean data/dev_other
exit 1 utils/perturb_data_dir_speed.sh 0.9 data/${train_set}_org data/temp1
fi utils/perturb_data_dir_speed.sh 1.0 data/${train_set}_org data/temp2
utils/perturb_data_dir_speed.sh 1.1 data/${train_set}_org data/temp3
utils/combine_data.sh --extra-files utt2uniq data/${train_sp}_org data/temp1 data/temp2 data/temp3
# remove utt having more than 3000 frames
# remove utt having more than 400 characters
remove_longshortdata.sh --maxframes 3000 --maxchars 400 data/${train_set}_org data/${train_set}
remove_longshortdata.sh --maxframes 3000 --maxchars 400 data/${train_sp}_org data/${train_sp}
remove_longshortdata.sh --maxframes 3000 --maxchars 400 data/${train_dev}_org data/${train_dev}
steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj $nj --write_utt2num_frames true \
data/train_sp exp/make_fbank/train_sp ${fbankdir}
utils/fix_data_dir.sh data/train_sp
# compute global CMVN
compute-cmvn-stats scp:data/${train_sp}/feats.scp data/${train_sp}/cmvn.ark
# dump features for training
dump.sh --cmd "$train_cmd" --nj ${nj} --do_delta ${do_delta} \
data/${train_sp}/feats.scp data/${train_sp}/cmvn.ark exp/dump_feats/train ${feat_sp_dir}
dump.sh --cmd "$train_cmd" --nj ${nj} --do_delta ${do_delta} \
data/${train_dev}/feats.scp data/${train_sp}/cmvn.ark exp/dump_feats/dev ${feat_dt_dir}
for rtask in ${recog_set}; do
feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}; mkdir -p ${feat_recog_dir}
dump.sh --cmd "$train_cmd" --nj ${nj} --do_delta ${do_delta} \
data/${rtask}/feats.scp data/${train_sp}/cmvn.ark exp/dump_feats/recog/${rtask} \
${feat_recog_dir}
done
fi fi
dict=data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
bpemodel=data/lang_char/${train_set}_${bpemode}${nbpe}
echo "dictionary: ${dict}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size ### Task dependent. You have to check non-linguistic symbols used in the corpus.
for set in train dev test dev-clean dev-other test-clean test-other; do echo "stage 2: Dictionary and Json Data Preparation"
{ mkdir -p data/lang_char/
python3 ${MAIN_ROOT}/utils/format_data.py \ echo "<unk> 1" > ${dict} # <unk> must be 1, 0 will be used for "blank" in CTC
--feat_type "raw" \ cut -f 2- -d" " data/${train_set}/text > data/lang_char/input.txt
--cmvn_path "data/mean_std.json" \ spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
--unit_type "spm" \ spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+1}' >> ${dict}
--spm_model_prefix ${bpeprefix} \ wc -l ${dict}
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${set}.raw" \ # make json labels
--output_path="data/manifest.${set}" data2json.sh --nj ${nj} --feat ${feat_sp_dir}/feats.scp --bpecode ${bpemodel}.model \
data/${train_sp} ${dict} > ${feat_sp_dir}/data_${bpemode}${nbpe}.json
if [ $? -ne 0 ]; then data2json.sh --nj ${nj} --feat ${feat_dt_dir}/feats.scp --bpecode ${bpemodel}.model \
echo "Formt mnaifest failed. Terminated." data/${train_dev} ${dict} > ${feat_dt_dir}/data_${bpemode}${nbpe}.json
exit 1
fi for rtask in ${recog_set}; do
}& feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}
data2json.sh --nj ${nj} --feat ${feat_recog_dir}/feats.scp --bpecode ${bpemodel}.model \
data/${rtask} ${dict} > ${feat_recog_dir}/data_${bpemode}${nbpe}.json
done
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# make json labels
python3 local/espnet_json_to_manifest.py --json-file ${feat_sp_dir}/data_${bpemode}${nbpe}.json --manifest-file data/manifest.train
python3 local/espnet_json_to_manifest.py --json-file ${feat_dt_dir}/data_${bpemode}${nbpe}.json --manifest-file data/manifest.dev
for rtask in ${recog_set}; do
feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}
python3 local/espnet_json_to_manifest.py --json-file ${feat_recog_dir}/data_${bpemode}${nbpe}.json --manifest-file data/manifest.${rtask//_/-}
done done
wait
fi fi
echo "LibriSpeech Data preparation done." echo "LibriSpeech Data preparation done."

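For scale: train_960 is the full 960 h LibriSpeech training set, so the 0.9x/1.0x/1.1x speed perturbation above yields roughly 960/0.9 + 960 + 960/1.1 ≈ 2900 h of audio in train_sp before the 3000-frame/400-character filtering.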
@ -0,0 +1,85 @@
#!/usr/bin/env bash
# Copyright 2014 Vassil Panayotov
# 2014 Johns Hopkins University (author: Daniel Povey)
# Apache 2.0
if [ "$#" -ne 2 ]; then
echo "Usage: $0 <src-dir> <dst-dir>"
echo "e.g.: $0 /export/a15/vpanayotov/data/LibriSpeech/dev-clean data/dev-clean"
exit 1
fi
src=$1
dst=$2
# all utterances are FLAC compressed
if ! which flac >&/dev/null; then
echo "Please install 'flac' on ALL worker nodes!"
exit 1
fi
spk_file=$src/../SPEAKERS.TXT
mkdir -p $dst || exit 1
[ ! -d $src ] && echo "$0: no such directory $src" && exit 1
[ ! -f $spk_file ] && echo "$0: expected file $spk_file to exist" && exit 1
wav_scp=$dst/wav.scp; [[ -f "$wav_scp" ]] && rm $wav_scp
trans=$dst/text; [[ -f "$trans" ]] && rm $trans
utt2spk=$dst/utt2spk; [[ -f "$utt2spk" ]] && rm $utt2spk
spk2gender=$dst/spk2gender; [[ -f $spk2gender ]] && rm $spk2gender
for reader_dir in $(find -L $src -mindepth 1 -maxdepth 1 -type d | sort); do
reader=$(basename $reader_dir)
if ! [ $reader -eq $reader ]; then # not integer.
echo "$0: unexpected subdirectory name $reader"
exit 1
fi
reader_gender=$(egrep "^$reader[ ]+\|" $spk_file | awk -F'|' '{gsub(/[ ]+/, ""); print tolower($2)}')
if [ "$reader_gender" != 'm' ] && [ "$reader_gender" != 'f' ]; then
echo "Unexpected gender: '$reader_gender'"
exit 1
fi
for chapter_dir in $(find -L $reader_dir/ -mindepth 1 -maxdepth 1 -type d | sort); do
chapter=$(basename $chapter_dir)
if ! [ "$chapter" -eq "$chapter" ]; then
echo "$0: unexpected chapter-subdirectory name $chapter"
exit 1
fi
find -L $chapter_dir/ -iname "*.flac" | sort | xargs -I% basename % .flac | \
awk -v "dir=$chapter_dir" '{printf "%s flac -c -d -s %s/%s.flac |\n", $0, dir, $0}' >>$wav_scp|| exit 1
chapter_trans=$chapter_dir/${reader}-${chapter}.trans.txt
[ ! -f $chapter_trans ] && echo "$0: expected file $chapter_trans to exist" && exit 1
cat $chapter_trans >>$trans
# NOTE: For now we are using per-chapter utt2spk. That is each chapter is considered
# to be a different speaker. This is done for simplicity and because we want
# e.g. the CMVN to be calculated per-chapter
awk -v "reader=$reader" -v "chapter=$chapter" '{printf "%s %s-%s\n", $1, reader, chapter}' \
<$chapter_trans >>$utt2spk || exit 1
# reader -> gender map (again using per-chapter granularity)
echo "${reader}-${chapter} $reader_gender" >>$spk2gender
done
done
spk2utt=$dst/spk2utt
utils/utt2spk_to_spk2utt.pl <$utt2spk >$spk2utt || exit 1
ntrans=$(wc -l <$trans)
nutt2spk=$(wc -l <$utt2spk)
! [ "$ntrans" -eq "$nutt2spk" ] && \
echo "Inconsistent #transcripts($ntrans) and #utt2spk($nutt2spk)" && exit 1
utils/validate_data_dir.sh --no-feats $dst || exit 1
echo "$0: successfully prepared data in $dst"
exit 0

@ -11,22 +11,24 @@ tag=
decode_config=conf/decode/decode.yaml decode_config=conf/decode/decode.yaml
# lm params # lm params
lang_model=rnnlm.model.best lang_model=transformerLM.pdparams
lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/ lmexpdir=exp/lm/transformer
lmtag='nolm' rnnlm_config_path=conf/lm/transformer.yaml
lmtag='transformer'
train_set=train_960
recog_set="test-clean test-other dev-clean dev-other" recog_set="test-clean test-other dev-clean dev-other"
recog_set="test-clean" recog_set="test-clean"
# bpemode (unigram or bpe) # bpemode (unigram or bpe)
nbpe=5000 nbpe=5000
bpemode=unigram bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}" bpeprefix=data/lang_char/${train_set}_${bpemode}${nbpe}
bpemodel=${bpeprefix}.model bpemodel=${bpeprefix}.model
# bin params # bin params
config_path=conf/transformer.yaml config_path=conf/transformer.yaml
dict=data/bpe_unigram_5000_units.txt dict=data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
ckpt_prefix= ckpt_prefix=
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@ -90,9 +92,9 @@ for dmethd in join_ctc; do
--recog-json ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} \ --recog-json ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} \
--result-label ${decode_dir}/data.JOB.json \ --result-label ${decode_dir}/data.JOB.json \
--model-conf ${config_path} \ --model-conf ${config_path} \
--model ${ckpt_prefix}.pdparams --model ${ckpt_prefix}.pdparams \
--rnnlm-conf ${rnnlm_config_path} \
#--rnnlm ${lmexpdir}/${lang_model} \ --rnnlm ${lmexpdir}/${lang_model}
score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict} score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict}

@ -8,17 +8,18 @@ nj=32
lmtag='nolm' lmtag='nolm'
train_set=train_960
recog_set="test-clean test-other dev-clean dev-other" recog_set="test-clean test-other dev-clean dev-other"
recog_set="test-clean" recog_set="test-clean"
# bpemode (unigram or bpe) # bpemode (unigram or bpe)
nbpe=5000 nbpe=5000
bpemode=unigram bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}" bpeprefix=data/lang_char/${train_set}_${bpemode}${nbpe}
bpemodel=${bpeprefix}.model bpemodel=${bpeprefix}.model
config_path=conf/transformer.yaml config_path=conf/transformer.yaml
dict=data/bpe_unigram_5000_units.txt dict=data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
ckpt_prefix= ckpt_prefix=
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

@ -1,6 +1,6 @@
export MAIN_ROOT=`realpath ${PWD}/../../../` export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${PWD}/utils:${PATH} export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${MAIN_ROOT}/utils:${PWD}/utils:${PATH}
export LC_ALL=C export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1 export PYTHONDONTWRITEBYTECODE=1
@ -13,3 +13,16 @@ export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=u2_kaldi MODEL=u2_kaldi
export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin
# srilm
export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
export SRILM=${MAIN_ROOT}/tools/srilm
export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64
# Kaldi
export KALDI_ROOT=${MAIN_ROOT}/tools/kaldi
[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present, cannot use Kaldi!"
[ -f $KALDI_ROOT/tools/config/common_path.sh ] && . $KALDI_ROOT/tools/config/common_path.sh

@ -33,16 +33,24 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n # attention rescoring decoder
./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] && [ ${use_lm} == true ]; then
# join ctc decoder, use transformerlm to score
if [ ! -f exp/lm/transformer/transformerLM.pdparams ]; then
mkdir -p exp/lm/transformer
wget -P exp/lm/transformer/ https://deepspeech.bj.bcebos.com/transformer_lm/transformerLM.pdparams
fi
bash local/recog.sh --ckpt_prefix exp/${ckpt}/checkpoints/${avg_ckpt}
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# ctc alignment of test data # ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
# export ckpt avg_n # export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi fi

@ -0,0 +1 @@
../../../tools/kaldi/egs/wsj/s5/steps/

@ -1 +1 @@
../../../utils/ ../../../tools/kaldi/egs/wsj/s5/utils

@ -23,6 +23,7 @@ praatio~=4.1
pre-commit pre-commit
pybind11 pybind11
pypinyin pypinyin
python-dateutil
pyworld pyworld
resampy==0.2.2 resampy==0.2.2
sacrebleu sacrebleu
@ -41,3 +42,4 @@ visualdl==2.2.0
webrtcvad webrtcvad
yacs yacs
yq yq
nara_wpe

@ -65,13 +65,6 @@ def _remove(files: str):
def _post_install(install_lib_dir): def _post_install(install_lib_dir):
# apt
check_call("apt-get update -y")
check_call("apt-get install -y " + 'vim tig tree sox pkg-config ' +
'libsndfile1 libflac-dev libogg-dev ' +
'libvorbis-dev libboost-dev swig python3-dev ')
print("apt install.")
# tools/make # tools/make
tool_dir = HERE / "tools" tool_dir = HERE / "tools"
_remove(tool_dir.glob("*.done")) _remove(tool_dir.glob("*.done"))

@ -10,7 +10,7 @@ fi
if [ -e /etc/lsb-release ];then if [ -e /etc/lsb-release ];then
${SUDO} apt-get update -y ${SUDO} apt-get update -y
${SUDO} apt-get install -y jq vim tig tree sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev ${SUDO} apt-get install -y bc jq vim tig tree sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
if [ $? != 0 ]; then if [ $? != 0 ]; then
error_msg "Please using Ubuntu or install pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev by user." error_msg "Please using Ubuntu or install pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev by user."
exit -1 exit -1

@ -10,7 +10,7 @@ WGET ?= wget --no-check-certificate
.PHONY: all clean .PHONY: all clean
all: virtualenv.done kenlm.done sox.done soxbindings.done mfa.done sclite.done all: virtualenv.done apt.done kenlm.done sox.done soxbindings.done mfa.done sclite.done
virtualenv.done: virtualenv.done:
test -d venv || virtualenv -p $(PYTHON) venv test -d venv || virtualenv -p $(PYTHON) venv
@ -21,6 +21,13 @@ clean:
find -iname "*.pyc" -delete find -iname "*.pyc" -delete
rm -rf kenlm rm -rf kenlm
apt.done:
apt update -y
apt install -y bc flac jq vim tig tree pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
echo "check_certificate = off" >> ~/.wgetrc
touch apt.done
kenlm.done: kenlm.done:
# Ubuntu 16.04 installs boost 1.58.0 via apt # Ubuntu 16.04 installs boost 1.58.0 via apt
# it seems that boost (1.54.0) requires higher version. After I switched to g++-5 it compiles normally. # it seems that boost (1.54.0) requires higher version. After I switched to g++-5 it compiles normally.
@ -48,6 +55,13 @@ mfa.done:
tar xvf montreal-forced-aligner_linux.tar.gz tar xvf montreal-forced-aligner_linux.tar.gz
touch mfa.done touch mfa.done
openblas.done:
bash extras/install_openblas.sh
touch openblas.done
kaldi.done: openblas.done
bash extras/install_kaldi.sh
touch kaldi.done
#== SCTK =============================================================================== #== SCTK ===============================================================================
# SCTK official repo does not have version tags. Here's the mapping: # SCTK official repo does not have version tags. Here's the mapping:

@ -16,7 +16,7 @@ else
echo "$KALDI_DIR already exists!" echo "$KALDI_DIR already exists!"
fi fi
cd "$KALDI_DIR/tools" pushd "$KALDI_DIR/tools"
git pull git pull
# Prevent kaldi from switching default python version # Prevent kaldi from switching default python version
@ -28,8 +28,12 @@ touch "python/.use_default_python"
make -j4 make -j4
pushd ../src pushd ../src
./configure --shared --use-cuda=no --static-math --mathlib=OPENBLAS --openblas-root=${KALDI_DIR}/../OpenBLAS/install OPENBLAS_DIR=${KALDI_DIR}/../OpenBLAS
mkdir -p ${OPENBLAS_DIR}/install
./configure --shared --use-cuda=no --static-math --mathlib=OPENBLAS --openblas-root=${OPENBLAS_DIR}/install
make clean -j && make depend -j && make -j4 make clean -j && make depend -j && make -j4
popd popd
popd
echo "Done installing Kaldi." echo "Done installing Kaldi."

@ -0,0 +1,149 @@
#!/usr/bin/env python3
import argparse
import logging
from distutils.util import strtobool
import kaldiio
import numpy
from deepspeech.transform.cmvn import CMVN
from deepspeech.utils.cli_readers import file_reader_helper
from deepspeech.utils.cli_utils import get_commandline_args
from deepspeech.utils.cli_utils import is_scipy_wav_style
from deepspeech.utils.cli_writers import file_writer_helper
def get_parser():
parser = argparse.ArgumentParser(
description="apply mean-variance normalization to files",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument(
"--in-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--stats-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "npy"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--out-filetype",
type=str,
default="mat",
choices=["mat", "hdf5"],
help="Specify the file format for the wspecifier. "
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--norm-means",
type=strtobool,
default=True,
help="Do variance normalization or not.", )
parser.add_argument(
"--norm-vars",
type=strtobool,
default=False,
help="Do variance normalization or not.", )
parser.add_argument(
"--reverse",
type=strtobool,
default=False,
help="Do reverse mode or not")
parser.add_argument(
"--spk2utt",
type=str,
help="A text file of speaker to utterance-list map. "
"(Don't give rspecifier format, such as "
'"ark:spk2utt")', )
parser.add_argument(
"--utt2spk",
type=str,
help="A text file of utterance to speaker map. "
"(Don't give rspecifier format, such as "
'"ark:utt2spk")', )
parser.add_argument(
"--write-num-frames",
type=str,
help="Specify wspecifer for utt2num_frames")
parser.add_argument(
"--compress",
type=strtobool,
default=False,
help="Save in compressed format")
parser.add_argument(
"--compression-method",
type=int,
default=2,
help="Specify the method(if mat) or "
"gzip-level(if hdf5)", )
parser.add_argument(
"stats_rspecifier_or_rxfilename",
help="Input stats. e.g. ark:stats.ark or stats.mat", )
parser.add_argument(
"rspecifier", type=str, help="Read specifier id. e.g. ark:some.ark")
parser.add_argument(
"wspecifier", type=str, help="Write specifier id. e.g. ark:some.ark")
return parser
def main():
args = get_parser().parse_args()
# logging info
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
if args.verbose > 0:
logging.basicConfig(level=logging.INFO, format=logfmt)
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
if ":" in args.stats_rspecifier_or_rxfilename:
is_rspecifier = True
if args.stats_filetype == "npy":
stats_filetype = "hdf5"
else:
stats_filetype = args.stats_filetype
stats_dict = dict(
file_reader_helper(args.stats_rspecifier_or_rxfilename,
stats_filetype))
else:
is_rspecifier = False
if args.stats_filetype == "mat":
stats = kaldiio.load_mat(args.stats_rspecifier_or_rxfilename)
else:
stats = numpy.load(args.stats_rspecifier_or_rxfilename)
stats_dict = {None: stats}
cmvn = CMVN(
stats=stats_dict,
norm_means=args.norm_means,
norm_vars=args.norm_vars,
utt2spk=args.utt2spk,
spk2utt=args.spk2utt,
reverse=args.reverse, )
with file_writer_helper(
args.wspecifier,
filetype=args.out_filetype,
write_num_frames=args.write_num_frames,
compress=args.compress,
compression_method=args.compression_method, ) as writer:
for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
if is_scipy_wav_style(mat):
# If data is sound file, then got as Tuple[int, ndarray]
rate, mat = mat
mat = cmvn(mat, utt if is_rspecifier else None)
writer[utt] = mat
if __name__ == "__main__":
main()

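Per-speaker normalization in the script above hinges on the --utt2spk map; a minimal sketch of the lookup it implies (the "<utt> <spk>" format matches the utt2spk files written by data_prep.sh, and the path is illustrative):

# Build the utterance -> speaker map that per-speaker CMVN needs.
utt2spk = {}
with open("data/train_sp/utt2spk", encoding="utf-8") as f:
    for line in f:
        utt, spk = line.split()
        utt2spk[utt] = spk
# Per-speaker mode then normalizes each utterance with its speaker's stats;
# with per-utterance stats the utterance id itself is the lookup key.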
@ -47,8 +47,10 @@ def main(args):
beat_val_scores = sorted_val_scores[:args.num, 1] beat_val_scores = sorted_val_scores[:args.num, 1]
selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64) selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
avg_val_score = np.mean(beat_val_scores)
print("selected val scores = " + str(beat_val_scores)) print("selected val scores = " + str(beat_val_scores))
print("selected epochs = " + str(selected_epochs)) print("selected epochs = " + str(selected_epochs))
print("averaged val score = " + str(avg_val_score))
path_list = [ path_list = [
args.ckpt_dir + '/{}.pdparams'.format(int(epoch)) args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
@ -80,7 +82,7 @@ def main(args):
data = json.dumps({ data = json.dumps({
"mode": 'val_best' if args.val_best else 'latest', "mode": 'val_best' if args.val_best else 'latest',
"avg_ckpt": args.dst_model, "avg_ckpt": args.dst_model,
"val_loss_mean": np.mean(beat_val_scores), "val_loss_mean": avg_val_score,
"ckpts": path_list, "ckpts": path_list,
"epochs": selected_epochs.tolist(), "epochs": selected_epochs.tolist(),
"val_losses": beat_val_scores.tolist(), "val_losses": beat_val_scores.tolist(),

@ -0,0 +1,65 @@
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2021 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import glob
import os
from dateutil import parser
def get_parser():
parser = argparse.ArgumentParser(
description="calculate real time factor (RTF)")
parser.add_argument(
"--log-dir",
type=str,
default=None,
help="path to logging directory", )
return parser
def main():
args = get_parser().parse_args()
audio_sec = 0
decode_sec = 0
n_utt = 0
audio_durations = []
start_times = []
end_times = []
for x in glob.glob(os.path.join(args.log_dir, "decode.*.log")):
with codecs.open(x, "r", "utf-8") as f:
for line in f:
x = line.strip()
# 2021-10-25 08:22:04.052 | INFO | xxx:recog_v2:188 - feat: (1570, 83)
if "feat:" in x:
dur = int(x.split("(")[1].split(',')[0])
audio_durations += [dur]
start_times += [parser.parse(x.split("|")[0])]
elif "total log probability:" in x:
end_times += [parser.parse(x.split("|")[0])]
assert len(audio_durations) == len(end_times), (len(audio_durations),
len(end_times), )
assert len(start_times) == len(end_times), (len(start_times),
len(end_times))
audio_sec += sum(audio_durations) / 100 # [sec]
decode_sec += sum([(end - start).total_seconds()
for start, end in zip(start_times, end_times)])
n_utt += len(audio_durations)
print("Total audio duration: %.3f [sec]" % audio_sec)
print("Total decoding time: %.3f [sec]" % decode_sec)
rtf = decode_sec / audio_sec if audio_sec > 0 else 0
print("RTF: %.3f" % rtf)
latency = decode_sec * 1000 / n_utt if n_utt > 0 else 0
print("Latency: %.3f [ms/sentence]" % latency)
if __name__ == "__main__":
main()

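As a worked example of the arithmetic above: if the decode logs cover 1000 utterances totalling 3600 s of audio (durations are logged in frames at 100 frames per second, hence the division by 100) and the paired start/end timestamps sum to 900 s of wall-clock decoding, the script prints RTF: 0.250 (900 / 3600) and Latency: 900.000 ms/sentence (900 * 1000 / 1000).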
@ -0,0 +1,186 @@
#!/usr/bin/env python3
import argparse
import logging
import kaldiio
import numpy as np
from deepspeech.transform.transformation import Transformation
from deepspeech.utils.cli_readers import file_reader_helper
from deepspeech.utils.cli_utils import get_commandline_args
from deepspeech.utils.cli_utils import is_scipy_wav_style
from deepspeech.utils.cli_writers import file_writer_helper
def get_parser():
parser = argparse.ArgumentParser(
description="Compute cepstral mean and "
"variance normalization statistics"
"If wspecifier provided: per-utterance by default, "
"or per-speaker if"
"spk2utt option provided; if wxfilename: global",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--spk2utt",
type=str,
help="A text file of speaker to utterance-list map. "
"(Don't give rspecifier format, such as "
'"ark:utt2spk")', )
parser.add_argument(
"--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument(
"--in-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--out-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "npy"],
help="Specify the file format for the wspecifier. "
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing", )
parser.add_argument(
"rspecifier",
type=str,
help="Read specifier for feats. e.g. ark:some.ark")
parser.add_argument(
"wspecifier_or_wxfilename",
type=str,
help="Write specifier. e.g. ark:some.ark")
return parser
def main():
args = get_parser().parse_args()
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
if args.verbose > 0:
logging.basicConfig(level=logging.INFO, format=logfmt)
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
is_wspecifier = ":" in args.wspecifier_or_wxfilename
if is_wspecifier:
if args.spk2utt is not None:
logging.info("Performing as speaker CMVN mode")
utt2spk_dict = {}
with open(args.spk2utt) as f:
for line in f:
spk, utts = line.rstrip().split(None, 1)
for utt in utts.split():
utt2spk_dict[utt] = spk
def utt2spk(x):
return utt2spk_dict[x]
else:
logging.info("Performing as utterance CMVN mode")
def utt2spk(x):
return x
if args.out_filetype == "npy":
logging.warning("--out-filetype npy is allowed only for "
"Global CMVN mode, changing to hdf5")
args.out_filetype = "hdf5"
else:
logging.info("Performing as global CMVN mode")
if args.spk2utt is not None:
logging.warning("spk2utt is not used for global CMVN mode")
def utt2spk(x):
return None
if args.out_filetype == "hdf5":
logging.warning("--out-filetype hdf5 is not allowed for "
"Global CMVN mode, changing to npy")
args.out_filetype = "npy"
if args.preprocess_conf is not None:
preprocessing = Transformation(args.preprocess_conf)
logging.info("Apply preprocessing: {}".format(preprocessing))
else:
preprocessing = None
# Calculate stats for each speaker
counts = {}
sum_feats = {}
square_sum_feats = {}
idx = 0
for idx, (utt, matrix) in enumerate(
file_reader_helper(args.rspecifier, args.in_filetype), 1):
if is_scipy_wav_style(matrix):
# If data is sound file, then got as Tuple[int, ndarray]
rate, matrix = matrix
if preprocessing is not None:
matrix = preprocessing(matrix, uttid_list=utt)
spk = utt2spk(utt)
# Init at the first seen of the spk
if spk not in counts:
counts[spk] = 0
feat_shape = matrix.shape[1:]
# Accumulate in double precision
sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64)
square_sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64)
counts[spk] += matrix.shape[0]
sum_feats[spk] += matrix.sum(axis=0)
square_sum_feats[spk] += (matrix**2).sum(axis=0)
logging.info("Processed {} utterances".format(idx))
assert idx > 0, idx
cmvn_stats = {}
for spk in counts:
feat_shape = sum_feats[spk].shape
cmvn_shape = (2, feat_shape[0] + 1) + feat_shape[1:]
_cmvn_stats = np.empty(cmvn_shape, dtype=np.float64)
_cmvn_stats[0, :-1] = sum_feats[spk]
_cmvn_stats[1, :-1] = square_sum_feats[spk]
_cmvn_stats[0, -1] = counts[spk]
_cmvn_stats[1, -1] = 0.0
# You can get the mean and std as following,
# >>> N = _cmvn_stats[0, -1]
# >>> mean = _cmvn_stats[0, :-1] / N
# >>> std = np.sqrt(_cmvn_stats[1, :-1] / N - mean ** 2)
cmvn_stats[spk] = _cmvn_stats
# Per utterance or speaker CMVN
if is_wspecifier:
with file_writer_helper(
args.wspecifier_or_wxfilename,
filetype=args.out_filetype) as writer:
for spk, mat in cmvn_stats.items():
writer[spk] = mat
# Global CMVN
else:
matrix = cmvn_stats[None]
if args.out_filetype == "npy":
np.save(args.wspecifier_or_wxfilename, matrix)
elif args.out_filetype == "mat":
# Kaldi supports only matrix or vector
kaldiio.save_mat(args.wspecifier_or_wxfilename, matrix)
else:
raise RuntimeError(
"Not supporting: --out-filetype {}".format(args.out_filetype))
if __name__ == "__main__":
main()

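The stats layout documented in the comment above makes the moments easy to recover downstream; a minimal sketch for the global-CMVN case (stats saved as .npy; file name and feature sizes illustrative):

import numpy as np

stats = np.load("cmvn.npy")        # shape (2, feat_dim + 1), float64
n = stats[0, -1]                   # total frame count
mean = stats[0, :-1] / n
std = np.sqrt(stats[1, :-1] / n - mean ** 2)

feats = np.zeros((100, 80))        # (frames, feat_dim)
normalized = (feats - mean) / std  # what mean/variance normalization applies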
@ -0,0 +1,104 @@
#!/usr/bin/env python3
import argparse
import logging
from distutils.util import strtobool
from deepspeech.transform.transformation import Transformation
from deepspeech.utils.cli_readers import file_reader_helper
from deepspeech.utils.cli_utils import get_commandline_args
from deepspeech.utils.cli_utils import is_scipy_wav_style
from deepspeech.utils.cli_writers import file_writer_helper
def get_parser():
parser = argparse.ArgumentParser(
description="copy feature with preprocessing",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument(
"--in-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--out-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the wspecifier. "
'"mat" is the matrix format in kaldi', )
parser.add_argument(
"--write-num-frames",
type=str,
help="Specify wspecifer for utt2num_frames")
parser.add_argument(
"--compress",
type=strtobool,
default=False,
help="Save in compressed format")
parser.add_argument(
"--compression-method",
type=int,
default=2,
help="Specify the method(if mat) or "
"gzip-level(if hdf5)", )
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing", )
parser.add_argument(
"rspecifier",
type=str,
help="Read specifier for feats. e.g. ark:some.ark")
parser.add_argument(
"wspecifier", type=str, help="Write specifier. e.g. ark:some.ark")
return parser
def main():
parser = get_parser()
args = parser.parse_args()
# logging info
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
if args.verbose > 0:
logging.basicConfig(level=logging.INFO, format=logfmt)
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
if args.preprocess_conf is not None:
preprocessing = Transformation(args.preprocess_conf)
logging.info("Apply preprocessing: {}".format(preprocessing))
else:
preprocessing = None
with file_writer_helper(
args.wspecifier,
filetype=args.out_filetype,
write_num_frames=args.write_num_frames,
compress=args.compress,
compression_method=args.compression_method, ) as writer:
for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
if is_scipy_wav_style(mat):
# If data is sound file, then got as Tuple[int, ndarray]
rate, mat = mat
if preprocessing is not None:
mat = preprocessing(mat, uttid_list=utt)
# shape = (Time, Channel)
if args.out_filetype in ["sound.hdf5", "sound"]:
# Write Tuple[int, numpy.ndarray] (scipy style)
writer[utt] = (rate, mat)
else:
writer[utt] = mat
if __name__ == "__main__":
main()

@ -0,0 +1,170 @@
#!/usr/bin/env bash
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
echo "$0 $*" >&2 # Print the command line for logging
. ./path.sh
nj=1
cmd=run.pl
nlsyms=""
lang=""
feat="" # feat.scp
oov="<unk>"
bpecode=""
allow_one_column=false
verbose=0
trans_type=char
filetype=""
preprocess_conf=""
category=""
out="" # If omitted, write in stdout
text=""
multilingual=false
help_message=$(cat << EOF
Usage: $0 <data-dir> <dict>
e.g. $0 data/train data/lang_1char/train_units.txt
Options:
--nj <nj> # number of parallel jobs
--cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs.
--feat <feat-scp> # feat.scp or feat1.scp,feat2.scp,...
--oov <oov-word> # Default: <unk>
--out <outputfile> # If omitted, write in stdout
--filetype <mat|hdf5|sound.hdf5> # Specify the format of feats file
--preprocess-conf <json> # Apply preprocess to feats when creating shape.scp
--verbose <num> # Default: 0
EOF
)
. utils/parse_options.sh
if [ $# != 2 ]; then
echo "${help_message}" 1>&2
exit 1;
fi
set -euo pipefail
dir=$1
dic=$2
tmpdir=$(mktemp -d ${dir}/tmp-XXXXX)
trap 'rm -rf ${tmpdir}' EXIT
if [ -z ${text} ]; then
text=${dir}/text
fi
# 1. Create scp files for inputs
# These are not necessary in decoding mode, so make them optional
input=
if [ -n "${feat}" ]; then
_feat_scps=$(echo "${feat}" | tr ',' ' ' )
read -r -a feat_scps <<< $_feat_scps
num_feats=${#feat_scps[@]}
for (( i=1; i<=num_feats; i++ )); do
feat=${feat_scps[$((i-1))]}
mkdir -p ${tmpdir}/input_${i}
input+="input_${i} "
cat ${feat} > ${tmpdir}/input_${i}/feat.scp
# Dump in the "legacy" style JSON format
if [ -n "${filetype}" ]; then
awk -v filetype=${filetype} '{print $1 " " filetype}' ${feat} \
> ${tmpdir}/input_${i}/filetype.scp
fi
feat_to_shape.sh --cmd "${cmd}" --nj ${nj} \
--filetype "${filetype}" \
--preprocess-conf "${preprocess_conf}" \
--verbose ${verbose} ${feat} ${tmpdir}/input_${i}/shape.scp
done
fi
# 2. Create scp files for outputs
mkdir -p ${tmpdir}/output
if [ -n "${bpecode}" ]; then
if [ ${multilingual} = true ]; then
# remove a space before the language ID
paste -d " " <(awk '{print $1}' ${text}) <(cut -f 2- -d" " ${text} \
| spm_encode --model=${bpecode} --output_format=piece | cut -f 2- -d" ") \
> ${tmpdir}/output/token.scp
else
paste -d " " <(awk '{print $1}' ${text}) <(cut -f 2- -d" " ${text} \
| spm_encode --model=${bpecode} --output_format=piece) \
> ${tmpdir}/output/token.scp
fi
elif [ -n "${nlsyms}" ]; then
text2token.py -s 1 -n 1 -l ${nlsyms} ${text} --trans_type ${trans_type} > ${tmpdir}/output/token.scp
else
text2token.py -s 1 -n 1 ${text} --trans_type ${trans_type} > ${tmpdir}/output/token.scp
fi
< ${tmpdir}/output/token.scp utils/sym2int.pl --map-oov ${oov} -f 2- ${dic} > ${tmpdir}/output/tokenid.scp
# +2 comes from CTC blank and EOS
vocsize=$(tail -n 1 ${dic} | awk '{print $2}')
odim=$(echo "$vocsize + 2" | bc)
< ${tmpdir}/output/tokenid.scp awk -v odim=${odim} '{print $1 " " NF-1 "," odim}' > ${tmpdir}/output/shape.scp
cat ${text} > ${tmpdir}/output/text.scp
# 3. Create scp files for the others
mkdir -p ${tmpdir}/other
if [ ${multilingual} == true ]; then
awk '{
n = split($1,S,"[-]");
lang=S[n];
print $1 " " lang
}' ${text} > ${tmpdir}/other/lang.scp
elif [ -n "${lang}" ]; then
awk -v lang=${lang} '{print $1 " " lang}' ${text} > ${tmpdir}/other/lang.scp
fi
if [ -n "${category}" ]; then
awk -v category=${category} '{print $1 " " category}' ${dir}/text \
> ${tmpdir}/other/category.scp
fi
cat ${dir}/utt2spk > ${tmpdir}/other/utt2spk.scp
# 4. Merge scp files into a JSON file
opts=""
if [ -n "${feat}" ]; then
intypes="${input} output other"
else
intypes="output other"
fi
for intype in ${intypes}; do
if [ -z "$(find "${tmpdir}/${intype}" -name "*.scp")" ]; then
continue
fi
if [ ${intype} != other ]; then
opts+="--${intype%_*}-scps "
else
opts+="--scps "
fi
for x in "${tmpdir}/${intype}"/*.scp; do
k=$(basename ${x} .scp)
if [ ${k} = shape ]; then
opts+="shape:${x}:shape "
else
opts+="${k}:${x} "
fi
done
done
if ${allow_one_column}; then
opts+="--allow-one-column true "
else
opts+="--allow-one-column false "
fi
if [ -n "${out}" ]; then
opts+="-O ${out}"
fi
merge_scp2json.py --verbose ${verbose} ${opts}
rm -fr ${tmpdir}

@ -0,0 +1,95 @@
#!/usr/bin/env bash
# Copyright 2017 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
echo "$0 $*" # Print the command line for logging
. ./path.sh
cmd=run.pl
do_delta=false
nj=1
verbose=0
compress=true
write_utt2num_frames=true
filetype='mat' # mat or hdf5
help_message="Usage: $0 <scp> <cmvnark> <logdir> <dumpdir>"
. utils/parse_options.sh
scp=$1
cmvnark=$2
logdir=$3
dumpdir=$4
if [ $# != 4 ]; then
echo "${help_message}"
exit 1;
fi
set -euo pipefail
mkdir -p ${logdir}
mkdir -p ${dumpdir}
dumpdir=$(perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' ${dumpdir} ${PWD})
for n in $(seq ${nj}); do
# the next command does nothing unless $dumpdir/storage/ exists, see
# utils/create_data_link.pl for more info.
utils/create_data_link.pl ${dumpdir}/feats.${n}.ark
done
if ${write_utt2num_frames}; then
write_num_frames_opt="--write-num-frames=ark,t:$dumpdir/utt2num_frames.JOB"
else
write_num_frames_opt=
fi
# split scp file
split_scps=""
for n in $(seq ${nj}); do
split_scps="$split_scps $logdir/feats.$n.scp"
done
utils/split_scp.pl ${scp} ${split_scps} || exit 1;
# dump features
if ${do_delta}; then
${cmd} JOB=1:${nj} ${logdir}/dump_feature.JOB.log \
apply-cmvn --norm-vars=true ${cmvnark} scp:${logdir}/feats.JOB.scp ark:- \| \
add-deltas ark:- ark:- \| \
copy-feats.py --verbose ${verbose} --out-filetype ${filetype} \
--compress=${compress} --compression-method=2 ${write_num_frames_opt} \
ark:- ark,scp:${dumpdir}/feats.JOB.ark,${dumpdir}/feats.JOB.scp \
|| exit 1
else
${cmd} JOB=1:${nj} ${logdir}/dump_feature.JOB.log \
apply-cmvn --norm-vars=true ${cmvnark} scp:${logdir}/feats.JOB.scp ark:- \| \
copy-feats.py --verbose ${verbose} --out-filetype ${filetype} \
--compress=${compress} --compression-method=2 ${write_num_frames_opt} \
ark:- ark,scp:${dumpdir}/feats.JOB.ark,${dumpdir}/feats.JOB.scp \
|| exit 1
fi
# concatenate scp files
for n in $(seq ${nj}); do
cat ${dumpdir}/feats.${n}.scp || exit 1;
done > ${dumpdir}/feats.scp || exit 1
if ${write_utt2num_frames}; then
for n in $(seq ${nj}); do
cat ${dumpdir}/utt2num_frames.${n} || exit 1;
done > ${dumpdir}/utt2num_frames || exit 1
rm ${dumpdir}/utt2num_frames.* 2>/dev/null
fi
# Write the filetype, this will be used for data2json.sh
echo ${filetype} > ${dumpdir}/filetype
# remove temp scps
rm ${logdir}/feats.*.scp 2>/dev/null
if [ ${verbose} -eq 1 ]; then
echo "Succeeded dumping features for training"
fi

@ -0,0 +1,84 @@
#!/usr/bin/env python3
import argparse
import logging
import sys
from deepspeech.transform.transformation import Transformation
from deepspeech.utils.cli_readers import file_reader_helper
from deepspeech.utils.cli_utils import get_commandline_args
from deepspeech.utils.cli_utils import is_scipy_wav_style
def get_parser():
parser = argparse.ArgumentParser(
description="convert feature to its shape",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument(
"--filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi',
)
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing",
)
parser.add_argument(
"rspecifier", type=str, help="Read specifier for feats. e.g. ark:some.ark"
)
parser.add_argument(
"out",
nargs="?",
type=argparse.FileType("w"),
default=sys.stdout,
help="The output filename. " "If omitted, then output to sys.stdout",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
# logging info
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
if args.verbose > 0:
logging.basicConfig(level=logging.INFO, format=logfmt)
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
if args.preprocess_conf is not None:
preprocessing = Transformation(args.preprocess_conf)
logging.info("Apply preprocessing: {}".format(preprocessing))
else:
preprocessing = None
# The matrix itself is not needed when there is no preprocessing,
# so ask file_reader_helper to return only the shape.
# This makes sense only with filetype="hdf5".
for utt, mat in file_reader_helper(
args.rspecifier, args.filetype, return_shape=preprocessing is None
):
if preprocessing is not None:
if is_scipy_wav_style(mat):
# If data is sound file, then got as Tuple[int, ndarray]
rate, mat = mat
mat = preprocessing(mat, uttid_list=utt)
shape_str = ",".join(map(str, mat.shape))
else:
if len(mat) == 2 and isinstance(mat[1], tuple):
# If data is sound file, Tuple[int, Tuple[int, ...]]
rate, mat = mat
shape_str = ",".join(map(str, mat))
args.out.write("{} {}\n".format(utt, shape_str))
if __name__ == "__main__":
main()

@ -0,0 +1,72 @@
#!/usr/bin/env bash
# Begin configuration section.
nj=4
cmd=run.pl
verbose=0
filetype=""
preprocess_conf=""
# End configuration section.
help_message=$(cat << EOF
Usage: $0 [options] <input-scp> <output-scp> [<log-dir>]
e.g.: $0 data/train/feats.scp data/train/shape.scp data/train/log
Options:
--nj <nj> # number of parallel jobs
--cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs.
--filetype <mat|hdf5|sound.hdf5> # Specify the format of feats file
--preprocess-conf <json> # Apply preprocess to feats when creating shape.scp
--verbose <num> # Default: 0
EOF
)
echo "$0 $*" 1>&2 # Print the command line for logging
. parse_options.sh || exit 1;
if [ $# -lt 2 ] || [ $# -gt 3 ]; then
echo "${help_message}" 1>&2
exit 1;
fi
set -euo pipefail
scp=$1
outscp=$2
data=$(dirname ${scp})
if [ $# -eq 3 ]; then
logdir=$3
else
logdir=${data}/log
fi
mkdir -p ${logdir}
nj=$((nj<$(<"${scp}" wc -l)?nj:$(<"${scp}" wc -l)))
split_scps=""
for n in $(seq ${nj}); do
split_scps="${split_scps} ${logdir}/feats.${n}.scp"
done
utils/split_scp.pl ${scp} ${split_scps}
if [ -n "${preprocess_conf}" ]; then
preprocess_opt="--preprocess-conf ${preprocess_conf}"
else
preprocess_opt=""
fi
if [ -n "${filetype}" ]; then
filetype_opt="--filetype ${filetype}"
else
filetype_opt=""
fi
${cmd} JOB=1:${nj} ${logdir}/feat_to_shape.JOB.log \
feat-to-shape.py --verbose ${verbose} ${preprocess_opt} ${filetype_opt} \
scp:${logdir}/feats.JOB.scp ${logdir}/shape.JOB.scp
# concatenate the .scp files together.
for n in $(seq ${nj}); do
cat ${logdir}/shape.${n}.scp
done > ${outscp}
rm -f ${logdir}/feats.*.scp 2>/dev/null

@ -0,0 +1,289 @@
#!/usr/bin/env python3
# encoding: utf-8
import argparse
import codecs
import json
import logging
import sys
from distutils.util import strtobool
from io import open
from deepspeech.utils.cli_utils import get_commandline_args
PY2 = sys.version_info[0] == 2
sys.stdin = codecs.getreader("utf-8")(sys.stdin if PY2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(sys.stdout if PY2 else sys.stdout.buffer)
# Special types:
def shape(x):
"""Change str to List[int]
>>> shape('3,5')
[3, 5]
>>> shape(' [3, 5] ')
[3, 5]
"""
# x: ' [3, 5] ' -> '3, 5'
x = x.strip()
if x[0] == "[":
x = x[1:]
if x[-1] == "]":
x = x[:-1]
return list(map(int, x.split(",")))
def get_parser():
parser = argparse.ArgumentParser(
description="Given each file paths with such format as "
"<key>:<file>:<type>. type> can be omitted and the default "
'is "str". e.g. {} '
"--input-scps feat:data/feats.scp shape:data/utt2feat_shape:shape "
"--input-scps feat:data/feats2.scp shape:data/utt2feat2_shape:shape "
"--output-scps text:data/text shape:data/utt2text_shape:shape "
"--scps utt2spk:data/utt2spk".format(sys.argv[0]),
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--input-scps",
type=str,
nargs="*",
action="append",
default=[],
help="Json files for the inputs", )
parser.add_argument(
"--output-scps",
type=str,
nargs="*",
action="append",
default=[],
help="Json files for the outputs", )
parser.add_argument(
"--scps",
type=str,
nargs="+",
default=[],
help="The json files except for the input and outputs", )
parser.add_argument(
"--verbose", "-V", default=1, type=int, help="Verbose option")
parser.add_argument(
"--allow-one-column",
type=strtobool,
default=False,
help="Allow one column in input scp files. "
"In this case, the value will be empty string.", )
parser.add_argument(
"--out",
"-O",
type=str,
help="The output filename. "
"If omitted, then output to sys.stdout", )
return parser
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
args.scps = [args.scps]
# logging info
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
if args.verbose > 0:
logging.basicConfig(level=logging.INFO, format=logfmt)
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
# List[List[Tuple[str, str, Callable[[str], Any], str, str]]]
input_infos = []
output_infos = []
infos = []
for lis_list, key_scps_list in [
(input_infos, args.input_scps),
(output_infos, args.output_scps),
(infos, args.scps),
]:
for key_scps in key_scps_list:
lis = []
for key_scp in key_scps:
sps = key_scp.split(":")
if len(sps) == 2:
key, scp = sps
type_func = None
type_func_str = "none"
elif len(sps) == 3:
key, scp, type_func_str = sps
fail = False
try:
# type_func: Callable[[str], Any]
# e.g. type_func_str = "int" -> type_func = int
type_func = eval(type_func_str)
except Exception:
raise RuntimeError(
"Unknown type: {}".format(type_func_str))
if not callable(type_func):
raise RuntimeError(
"Unknown type: {}".format(type_func_str))
else:
raise RuntimeError(
"Format <key>:<filepath> "
"or <key>:<filepath>:<type> "
"e.g. feat:data/feat.scp "
"or shape:data/feat.scp:shape: {}".format(key_scp))
for item in lis:
if key == item[0]:
raise RuntimeError('The key "{}" is duplicated: {} {}'.
format(key, item[3], key_scp))
lis.append((key, scp, type_func, key_scp, type_func_str))
lis_list.append(lis)
# Open scp files
input_fscps = [[open(i[1], "r", encoding="utf-8") for i in il]
for il in input_infos]
output_fscps = [[open(i[1], "r", encoding="utf-8") for i in il]
for il in output_infos]
fscps = [[open(i[1], "r", encoding="utf-8") for i in il] for il in infos]
# Note(kamo): What is done here?
# The final goal is creating a JSON file such as:
# {
# "utts": {
# "sample_id1": {(omitted)},
# "sample_id2": {(omitted)},
# ....
# }
# }
#
# To reduce memory usage, the input text files are read line by line
# and JSON elements are written per sample.
if args.out is None:
out = sys.stdout
else:
out = open(args.out, "w", encoding="utf-8")
out.write('{\n "utts": {\n')
nutt = 0
while True:
nutt += 1
# List[List[str]]
input_lines = [[f.readline() for f in fl] for fl in input_fscps]
output_lines = [[f.readline() for f in fl] for fl in output_fscps]
lines = [[f.readline() for f in fl] for fl in fscps]
# Get the first line
concat = sum(input_lines + output_lines + lines, [])
if len(concat) == 0:
break
first = concat[0]
# Sanity check: Must be sorted by the first column and have same keys
count = 0
for ls_list in (input_lines, output_lines, lines):
for ls in ls_list:
for line in ls:
if line == "" or first == "":
if line != first:
concat = sum(input_infos + output_infos + infos, [])
raise RuntimeError("The number of lines mismatch "
'between: "{}" and "{}"'.format(
concat[0][1],
concat[count][1]))
elif line.split()[0] != first.split()[0]:
concat = sum(input_infos + output_infos + infos, [])
raise RuntimeError(
"The keys are mismatch at {}th line "
'between "{}" and "{}":\n>>> {}\n>>> {}'.format(
nutt,
concat[0][1],
concat[count][1],
first.rstrip(),
line.rstrip(), ))
count += 1
# The end of file
if first == "":
if nutt != 1:
out.write("\n")
break
if nutt != 1:
out.write(",\n")
entry = {}
for inout, _lines, _infos in [
("input", input_lines, input_infos),
("output", output_lines, output_infos),
("other", lines, infos),
]:
lis = []
for idx, (line_list, info_list) in enumerate(
zip(_lines, _infos), 1):
if inout == "input":
d = {"name": "input{}".format(idx)}
elif inout == "output":
d = {"name": "target{}".format(idx)}
else:
d = {}
# info_list: List[Tuple[str, str, Callable]]
# line_list: List[str]
for line, info in zip(line_list, info_list):
sps = line.split(None, 1)
if len(sps) < 2:
if not args.allow_one_column:
raise RuntimeError(
"Format error {}th line in {}: "
' Expecting "<key> <value>":\n>>> {}'.format(
nutt, info[1], line))
uttid = sps[0]
value = ""
else:
uttid, value = sps
key = info[0]
type_func = info[2]
value = value.rstrip()
if type_func is not None:
try:
# type_func: Callable[[str], Any]
value = type_func(value)
except Exception:
logging.error(
'"{}" is an invalid function '
"for the {} th line in {}: \n>>> {}".format(
info[4], nutt, info[1], line))
raise
d[key] = value
lis.append(d)
if inout != "other":
entry[inout] = lis
else:
# For "other", only the first item is used
entry.update(lis[0])
entry = json.dumps(
entry,
indent=4,
ensure_ascii=False,
sort_keys=True,
separators=(",", ": "))
# Add indent
indent = " " * 2
entry = ("\n" + indent).join(entry.split("\n"))
uttid = first.split()[0]
out.write(' "{}": {}'.format(uttid, entry))
out.write(" }\n}\n")
logging.info("{} entries in {}".format(nutt, out.name))

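For reference, one entry of the JSON assembled above looks roughly like the following (a hedged illustration as a Python literal: the exact keys come from whichever scp files data2json.sh supplies, the utterance id follows the reader-chapter naming from data_prep.sh, and the shapes follow feat_dim=83 and odim=5002 from this recipe):

example = {
    "utts": {
        "1034-121119-0049": {
            "input": [
                {"name": "input1", "feat": "dump/.../feats.1.ark:42", "shape": [1570, 83]}
            ],
            "output": [
                {"name": "target1", "token": "...", "tokenid": "...", "shape": [24, 5002]}
            ],
            "utt2spk": "1034-121119",
        }
    }
}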
@ -0,0 +1,59 @@
#!/usr/bin/env bash
# koried, 10/29/2012
# Reduce a data set based on a list of turn-ids
help_message="usage: $0 srcdir turnlist destdir"
if [ $1 == "--help" ]; then
echo "${help_message}"
exit 0;
fi
if [ $# != 3 ]; then
echo "${help_message}"
exit 1;
fi
srcdir=$1
reclist=$2
destdir=$3
if [ ! -f ${srcdir}/utt2spk ]; then
echo "$0: no such file $srcdir/utt2spk"
exit 1;
fi
function do_filtering {
# assumes the utt2spk and spk2utt files already exist.
[ -f ${srcdir}/feats.scp ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/feats.scp >${destdir}/feats.scp
[ -f ${srcdir}/wav.scp ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/wav.scp >${destdir}/wav.scp
[ -f ${srcdir}/text ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/text >${destdir}/text
[ -f ${srcdir}/utt2num_frames ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/utt2num_frames >${destdir}/utt2num_frames
[ -f ${srcdir}/spk2gender ] && utils/filter_scp.pl ${destdir}/spk2utt <${srcdir}/spk2gender >${destdir}/spk2gender
[ -f ${srcdir}/cmvn.scp ] && utils/filter_scp.pl ${destdir}/spk2utt <${srcdir}/cmvn.scp >${destdir}/cmvn.scp
if [ -f ${srcdir}/segments ]; then
utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/segments >${destdir}/segments
awk '{print $2;}' ${destdir}/segments | sort | uniq > ${destdir}/reco # recordings.
# The next line would override the command above for wav.scp, which would be incorrect.
[ -f ${srcdir}/wav.scp ] && utils/filter_scp.pl ${destdir}/reco <${srcdir}/wav.scp >${destdir}/wav.scp
[ -f ${srcdir}/reco2file_and_channel ] && \
utils/filter_scp.pl ${destdir}/reco <${srcdir}/reco2file_and_channel >${destdir}/reco2file_and_channel
# Filter the STM file for proper sclite scoring (this will also remove the comments lines)
[ -f ${srcdir}/stm ] && utils/filter_scp.pl ${destdir}/reco < ${srcdir}/stm > ${destdir}/stm
rm ${destdir}/reco
fi
srcutts=$(wc -l < ${srcdir}/utt2spk)
destutts=$(wc -l < ${destdir}/utt2spk)
echo "Reduced #utt from $srcutts to $destutts"
}
mkdir -p ${destdir}
# filter the utt2spk based on the set of recordings
utils/filter_scp.pl ${reclist} < ${srcdir}/utt2spk > ${destdir}/utt2spk
utils/utt2spk_to_spk2utt.pl < ${destdir}/utt2spk > ${destdir}/spk2utt
do_filtering;
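A minimal usage sketch (the data paths and keep_list file are hypothetical; the filtering script in the next file invokes this one as reduce_data_dir.sh, so it is assumed to be on the PATH):

# Keep only the utterances whose turn-ids are listed in keep_list,
# writing the reduced Kaldi data directory to data/train_subset.
reduce_data_dir.sh data/train keep_list data/train_subset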

@ -0,0 +1,62 @@
#!/usr/bin/env bash
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
. ./path.sh
maxframes=2000
minframes=10
maxchars=200
minchars=0
nlsyms=""
no_feat=false
trans_type=char
help_message="usage: $0 olddatadir newdatadir"
. utils/parse_options.sh || exit 1;
if [ $# != 2 ]; then
echo "${help_message}"
exit 1;
fi
sdir=$1
odir=$2
mkdir -p ${odir}/tmp
if [ ${no_feat} = true ]; then
# for machine translation
cut -d' ' -f 1 ${sdir}/text > ${odir}/tmp/reclist1
else
echo "extract utterances having less than $maxframes or more than $minframes frames"
utils/data/get_utt2num_frames.sh ${sdir}
< ${sdir}/utt2num_frames awk -v maxframes="$maxframes" '{ if ($2 < maxframes) print }' \
| awk -v minframes="$minframes" '{ if ($2 > minframes) print }' \
| awk '{print $1}' > ${odir}/tmp/reclist1
fi
echo "extract utterances having less than $maxchars or more than $minchars characters"
# counting number of chars. Use (NF - 1) instead of NF to exclude the utterance ID column
if [ -z "${nlsyms}" ]; then
text2token.py -s 1 -n 1 ${sdir}/text --trans_type ${trans_type} \
| awk -v maxchars="$maxchars" '{ if (NF - 1 < maxchars) print }' \
| awk -v minchars="$minchars" '{ if (NF - 1 > minchars) print }' \
| awk '{print $1}' > ${odir}/tmp/reclist2
else
text2token.py -l ${nlsyms} -s 1 -n 1 ${sdir}/text --trans_type ${trans_type} \
| awk -v maxchars="$maxchars" '{ if (NF - 1 < maxchars) print }' \
| awk -v minchars="$minchars" '{ if (NF - 1 > minchars) print }' \
| awk '{print $1}' > ${odir}/tmp/reclist2
fi
# extract common lines
comm -12 <(sort ${odir}/tmp/reclist1) <(sort ${odir}/tmp/reclist2) > ${odir}/tmp/reclist
reduce_data_dir.sh ${sdir} ${odir}/tmp/reclist ${odir}
utils/fix_data_dir.sh ${odir}
oldnum=$(wc -l ${sdir}/feats.scp | awk '{print $1}')
newnum=$(wc -l ${odir}/feats.scp | awk '{print $1}')
echo "change from $oldnum to $newnum"

@ -0,0 +1,129 @@
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import re
import sys
is_python2 = sys.version_info[0] == 2
def exist_or_not(i, match_pos):
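    # Return the (start, end) span from match_pos that covers position i,
    # or (None, None) if i lies outside every matched span.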
start_pos = None
end_pos = None
for pos in match_pos:
if pos[0] <= i < pos[1]:
start_pos = pos[0]
end_pos = pos[1]
break
return start_pos, end_pos
def get_parser():
parser = argparse.ArgumentParser(
description="convert raw text to tokenized text",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--nchar",
"-n",
default=1,
type=int,
help="number of characters to split, i.e., \
aabb -> a a b b with -n 1 and aa bb with -n 2", )
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns")
parser.add_argument(
"--space", default="<space>", type=str, help="space symbol")
parser.add_argument(
"--non-lang-syms",
"-l",
default=None,
type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.", )
parser.add_argument(
"text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"--trans_type",
"-t",
type=str,
default="char",
choices=["char", "phn"],
help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
If trans_type is char,
read from SI1279.WRD file -> "bricks are an alternative"
Else if trans_type is phn,
read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
sil t er n ih sil t ih v sil" """, )
return parser
def main():
parser = get_parser()
args = parser.parse_args()
rs = []
if args.non_lang_syms is not None:
with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
nls = [x.rstrip() for x in f.readlines()]
rs = [re.compile(re.escape(x)) for x in nls]
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(sys.stdin
if is_python2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(sys.stdout
if is_python2 else sys.stdout.buffer)
line = f.readline()
n = args.nchar
while line:
x = line.split()
print(" ".join(x[:args.skip_ncols]), end=" ")
a = " ".join(x[args.skip_ncols:])
# get all matched positions
match_pos = []
for r in rs:
i = 0
while i >= 0:
m = r.search(a, i)
if m:
match_pos.append([m.start(), m.end()])
i = m.end()
else:
break
if args.trans_type == "phn":
a = a.split(" ")
else:
if len(match_pos) > 0:
chars = []
i = 0
while i < len(a):
start_pos, end_pos = exist_or_not(i, match_pos)
if start_pos is not None:
chars.append(a[start_pos:end_pos])
i = end_pos
else:
chars.append(a[i])
i += 1
a = chars
a = [a[j:j + n] for j in range(0, len(a), n)]
a_flat = []
for z in a:
a_flat.append("".join(z))
a_chars = [z.replace(" ", args.space) for z in a_flat]
if args.trans_type == "phn":
a_chars = [z.replace("sil", args.space) for z in a_chars]
print(" ".join(a_chars))
line = f.readline()
if __name__ == "__main__":
main()
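A quick usage sketch (file names are hypothetical; the -l list keeps non-linguistic symbols such as <NOISE> from being split into characters):

# Tokenize everything after the first column into characters,
# mapping spaces to <space> and keeping <NOISE> intact.
echo "utt1 ab <NOISE> cd" > text
echo "<NOISE>" > nlsyms.txt
python3 text2token.py -s 1 -n 1 -l nlsyms.txt text
# -> utt1 a b <space> <NOISE> <space> c d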